crnn.py

import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init


class BidirectionalLSTM(nn.Module):
    # nIn: input feature size, nHidden: LSTM hidden units, nOut: output (class) size
    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        recurrent, _ = self.rnn(input)
        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)

        output = self.embedding(t_rec)  # [T * b, nOut]
        output = output.view(T, b, -1)
        return output
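
# Usage sketch (illustrative, assumes `import torch`): the layer takes a
# (T, b, nIn) sequence and returns a (T, b, nOut) sequence, e.g.
#   seq = torch.randn(26, 4, 512)               # T=26 steps, batch of 4, 512 features
#   out = BidirectionalLSTM(512, 256, 37)(seq)  # -> torch.Size([26, 4, 37])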


class CRNN(nn.Module):
    def __init__(self, imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False):
        super(CRNN, self).__init__()
        assert imgH % 16 == 0, 'imgH has to be a multiple of 16'

        # per-layer kernel sizes, paddings, strides and output channels
        ks = [3, 3, 3, 3, 3, 3, 2]
        ps = [1, 1, 1, 1, 1, 1, 0]
        ss = [1, 1, 1, 1, 1, 1, 1]
        nm = [64, 128, 256, 256, 512, 512, 512]

        cnn = nn.Sequential()

        def convRelu(i, batchNormalization=False):
            nIn = nc if i == 0 else nm[i - 1]
            nOut = nm[i]
            cnn.add_module('conv{0}'.format(i),
                           nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
            if batchNormalization:
                cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
            if leakyRelu:
                cnn.add_module('relu{0}'.format(i),
                               nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module('relu{0}'.format(i), nn.ReLU(True))

        convRelu(0)
        cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))  # 64x16x64
        convRelu(1)
        cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))  # 128x8x32
        convRelu(2, True)
        convRelu(3)
        cnn.add_module('pooling{0}'.format(2),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256x4x16
        convRelu(4, True)
        convRelu(5)
        cnn.add_module('pooling{0}'.format(3),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512x2x16
        convRelu(6, True)  # 512x1x16

        self.cnn = cnn
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh),
            BidirectionalLSTM(nh, nh, nclass))

    def forward(self, input):
        # conv features
        conv = self.cnn(input)
        print('size', conv.size())
        b, c, h, w = conv.size()
        print("b, c, h, w :", b, c, h, w)
        assert h == 1, "the height of conv must be 1"
        conv = conv.squeeze(2)       # b * 512 * width
        print('size', conv.size())
        conv = conv.permute(2, 0, 1)  # [w, b, c]
        print('size', conv.size())

        # rnn features
        output = F.log_softmax(self.rnn(conv), dim=2)
        print('size', output.size())
        return output
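
# Usage sketch (illustrative, assumes `import torch`): with the pooling layout
# above, a (b, 1, 32, 100) grayscale image should produce a (26, b, nclass)
# log-probability sequence, e.g.
#   crnn = CRNN(imgH=32, nc=1, nclass=37, nh=256)
#   crnn(torch.randn(2, 1, 32, 100)).shape      # -> torch.Size([26, 2, 37])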


class SeModule(nn.Module):
    def __init__(self, in_size, reduction=4):
        super(SeModule, self).__init__()
        expand_size = max(in_size // reduction, 8)
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_size, expand_size, kernel_size=1, bias=False),
            nn.BatchNorm2d(expand_size),
            nn.ReLU(inplace=True),
            nn.Conv2d(expand_size, in_size, kernel_size=1, bias=False),
            nn.Hardsigmoid()
        )

    def forward(self, x):
        return x * self.se(x)
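
# Shape note (illustrative, assumes `import torch`): the SE branch only rescales
# channels, so input and output shapes match, e.g.
#   SeModule(64)(torch.randn(1, 64, 8, 32)).shape   # -> torch.Size([1, 64, 8, 32])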


class Block(nn.Module):
    '''expand + depthwise + pointwise'''

    def __init__(self, kernel_size, in_size, expand_size, out_size, act, se, stride):
        super(Block, self).__init__()
        self.stride = stride

        self.conv1 = nn.Conv2d(in_size, expand_size, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(expand_size)
        self.act1 = act(inplace=True)

        self.conv2 = nn.Conv2d(expand_size, expand_size, kernel_size=kernel_size, stride=stride,
                               padding=kernel_size // 2, groups=expand_size, bias=False)
        self.bn2 = nn.BatchNorm2d(expand_size)
        self.act2 = act(inplace=True)
        self.se = SeModule(expand_size) if se else nn.Identity()

        self.conv3 = nn.Conv2d(expand_size, out_size, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_size)
        self.act3 = act(inplace=True)

        self.skip = None
        if stride == 1 and in_size != out_size:
            self.skip = nn.Sequential(
                nn.Conv2d(in_size, out_size, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_size)
            )

        if stride == 2 and in_size != out_size:
            self.skip = nn.Sequential(
                nn.Conv2d(in_channels=in_size, out_channels=in_size, kernel_size=3, groups=in_size,
                          stride=2, padding=1, bias=False),
                nn.BatchNorm2d(in_size),
                nn.Conv2d(in_size, out_size, kernel_size=1, bias=True),
                nn.BatchNorm2d(out_size)
            )

        if stride == 2 and in_size == out_size:
            self.skip = nn.Sequential(
                nn.Conv2d(in_channels=in_size, out_channels=out_size, kernel_size=3, groups=in_size,
                          stride=2, padding=1, bias=False),
                nn.BatchNorm2d(out_size)
            )

    def forward(self, x):
        skip = x

        out = self.act1(self.bn1(self.conv1(x)))
        out = self.act2(self.bn2(self.conv2(out)))
        out = self.se(out)
        out = self.bn3(self.conv3(out))

        if self.skip is not None:
            skip = self.skip(skip)
        return self.act3(out + skip)
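
# Usage sketch (illustrative values, assumes `import torch`): a stride-2 block
# halves both spatial dimensions and maps in_size channels to out_size, e.g.
#   blk = Block(3, 16, 64, 24, nn.ReLU, se=False, stride=2)
#   blk(torch.randn(1, 16, 32, 320)).shape          # -> torch.Size([1, 24, 16, 160])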


class MobileNetV3_Large(nn.Module):
    def __init__(self, num_classes=1000, act=nn.Hardswish):
        super(MobileNetV3_Large, self).__init__()
        self.num_classes = num_classes

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.hs1 = act(inplace=True)

        self.bneck = nn.Sequential(
            Block(3, 16, 16, 16, nn.ReLU, False, 1),
            Block(3, 16, 64, 24, nn.ReLU, False, 2),
            Block(3, 24, 72, 24, nn.ReLU, False, 1),
            Block(5, 24, 72, 40, nn.ReLU, True, 2),
            Block(5, 40, 120, 40, nn.ReLU, True, 1),
            Block(5, 40, 120, 40, nn.ReLU, True, 1),
            Block(3, 40, 240, 80, act, False, 2),
            Block(3, 80, 200, 80, act, False, 1),
            Block(3, 80, 184, 80, act, False, 1),
            Block(3, 80, 184, 80, act, False, 1),
            Block(3, 80, 480, 112, act, True, 1),
            Block(3, 112, 672, 112, act, True, 1),
            Block(5, 112, 672, 160, act, True, 2),
            Block(5, 160, 672, 160, act, True, 1),
            Block(5, 160, 960, 160, act, True, 1)
            # Block(3, 160, 960, 160, act, True, 1),
        )

        self.conv2 = nn.Conv2d(160, 960, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(960)

        # Alternative (smaller) bneck configuration, kept for reference:
        # self.bneck = nn.Sequential(
        #     Block(3, 16, 16, 16, nn.ReLU, True, 2),
        #     Block(3, 16, 72, 24, nn.ReLU, False, 2),
        #     Block(3, 24, 88, 24, nn.ReLU, False, 1),
        #     Block(5, 24, 96, 40, act, True, 2),
        #     Block(5, 40, 240, 40, act, True, 1),
        #     Block(5, 40, 240, 40, act, True, 1),
        #     Block(5, 40, 120, 48, act, True, 1),
        #     Block(5, 48, 144, 48, act, True, 1),
        #     Block(5, 48, 288, 96, act, True, 2),
        #     Block(5, 96, 576, 96, act, True, 1),
        #     Block(5, 96, 576, 96, act, True, 1),
        # )
        # self.conv2 = nn.Conv2d(96, 576, kernel_size=1, stride=1, padding=0, bias=False)
        # self.bn2 = nn.BatchNorm2d(576)

        self.hs2 = act(inplace=True)

        # Original classification head, replaced by the recurrent head below:
        # self.gap = nn.AdaptiveAvgPool2d(1)
        # self.gap = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        # self.linear3 = nn.Linear(576, 1280, bias=False)
        # # self.linear3 = nn.Linear(1920, num_classes, bias=False)
        # self.bn3 = nn.BatchNorm1d(1280)
        # # self.bn3 = nn.BatchNorm1d(10)
        # self.hs3 = act(inplace=True)
        # self.drop = nn.Dropout(0.2)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

        self.rnn = nn.Sequential(
            BidirectionalLSTM(960, 256, 256),
            BidirectionalLSTM(256, 256, self.num_classes))
        # self.linear4 = nn.Linear(1280, num_classes)

        self.init_params()

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        out = self.hs1(self.bn1(self.conv1(x)))
        out = self.bneck(out)
        out = self.hs2(self.bn2(self.conv2(out)))

        B, C, H, W = out.shape
        assert H == 1, "the height of the feature map must be 1"
        out = out.squeeze(2)        # [b, c, w]
        out = out.permute(2, 0, 1)  # [w, b, c]
        out = self.rnn(out)

        # Alternatives tried during development, kept for reference:
        # out = out.permute(0, 2, 1)  # [b, w, c]
        # out = self.gap(out).flatten(2)
        # out = self.drop(self.hs3(self.bn3(self.linear3(out))))
        # return self.linear4(out)

        # return F.log_softmax(out, dim=2)
        out = F.softmax(out, dim=2)
        return out
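
# Shape sketch (illustrative, assumes `import torch`): conv1 plus the four
# stride-2 blocks downsample by 32 in each dimension, so the assert above only
# holds for 32-pixel-high RGB inputs, e.g.
#   net = MobileNetV3_Large(num_classes=37)
#   net(torch.randn(2, 3, 32, 320)).shape           # -> torch.Size([10, 2, 37])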


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


# def get_crnn(config):
#     model = CRNN(config.MODEL.IMAGE_SIZE.H, 1, config.MODEL.NUM_CLASSES + 1, config.MODEL.NUM_HIDDEN)
#     model.apply(weights_init)
#     return model


def get_crnn(config):
    # model = CRNN(config.MODEL.IMAGE_SIZE.H, 1, config.MODEL.NUM_CLASSES + 1, config.MODEL.NUM_HIDDEN)
    model = MobileNetV3_Large(config.MODEL.NUM_CLASSES + 1)
    model.apply(weights_init)
    return model
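
# Sketch of the config object get_crnn expects (field names inferred from the
# calls above; the values here are placeholders, not the project's real config):
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(MODEL=SimpleNamespace(NUM_CLASSES=36, NUM_HIDDEN=256,
#                                               IMAGE_SIZE=SimpleNamespace(H=32)))
#   model = get_crnn(cfg)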


if __name__ == '__main__':
    from lib.utils.utils import model_info

    model = MobileNetV3_Large(1000)
    model_info(model)

    model = CRNN(32, 1, 1000, 256)
    model_info(model)

    # imported for comparison with the timm implementations; currently unused
    from timm.models import MobileNetV3
    from timm.models import MobileNetV3Features
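
    # Sanity check (illustrative, assumes torch is installed): push dummy batches
    # through both models to confirm the (T, b, nclass) output layout noted above.
    import torch
    with torch.no_grad():
        print(CRNN(32, 1, 37, 256)(torch.randn(2, 1, 32, 100)).size())    # expected [26, 2, 37]
        print(MobileNetV3_Large(37)(torch.randn(2, 3, 32, 320)).size())   # expected [10, 2, 37]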