crnn.py

import torch.nn as nn
import torch.nn.functional as F


class BidirectionalLSTM(nn.Module):
    """Bidirectional LSTM over a feature sequence, projected to nOut classes."""

    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        # The two directions are concatenated, hence nHidden * 2 features.
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        recurrent, _ = self.rnn(input)    # [T, b, nHidden * 2]
        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)  # merge time and batch for the linear
        output = self.embedding(t_rec)    # [T * b, nOut]
        output = output.view(T, b, -1)    # [T, b, nOut]
        return output
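
# Illustrative shape check (the sizes are assumptions chosen to match the
# CRNN defaults below; running it also needs `import torch`):
#
#     rnn = BidirectionalLSTM(nIn=512, nHidden=256, nOut=37)
#     x = torch.randn(26, 4, 512)   # [T, b, nIn], sequence-first like nn.LSTM
#     rnn(x).shape                  # torch.Size([26, 4, 37])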

class CRNN(nn.Module):

    def __init__(self, imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False):
        super(CRNN, self).__init__()
        assert imgH % 16 == 0, 'imgH must be a multiple of 16'

        # Per-layer kernel sizes, paddings, strides and output channels.
        # (n_rnn is accepted for compatibility but unused: the recurrent head
        # below is a fixed stack of two BidirectionalLSTM layers.)
        ks = [3, 3, 3, 3, 3, 3, 2]
        ps = [1, 1, 1, 1, 1, 1, 0]
        ss = [1, 1, 1, 1, 1, 1, 1]
        nm = [64, 128, 256, 256, 512, 512, 512]

        cnn = nn.Sequential()

        def convRelu(i, batchNormalization=False):
            nIn = nc if i == 0 else nm[i - 1]
            nOut = nm[i]
            cnn.add_module('conv{0}'.format(i),
                           nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
            if batchNormalization:
                cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
            if leakyRelu:
                cnn.add_module('relu{0}'.format(i),
                               nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module('relu{0}'.format(i), nn.ReLU(True))

        # Shape comments assume a 1 x 32 x 100 input (c x h x w).
        convRelu(0)
        cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))  # 64x16x50
        convRelu(1)
        cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))  # 128x8x25
        convRelu(2, True)
        convRelu(3)
        # Rectangular pooling: halve the height but keep (roughly) the full
        # width, preserving horizontal resolution for narrow characters.
        cnn.add_module('pooling{0}'.format(2),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256x4x26
        convRelu(4, True)
        convRelu(5)
        cnn.add_module('pooling{0}'.format(3),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512x2x27
        convRelu(6, True)  # 512x1x26

        self.cnn = cnn
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh),
            BidirectionalLSTM(nh, nh, nclass))

    def forward(self, input):
        # conv features
        conv = self.cnn(input)
        b, c, h, w = conv.size()
        assert h == 1, "the height of conv must be 1"
        conv = conv.squeeze(2)        # [b, c, w]
        conv = conv.permute(2, 0, 1)  # [w, b, c]: sequence-first for the LSTM

        # Per-timestep log-probabilities over the character classes.
        output = F.log_softmax(self.rnn(conv), dim=2)
        return output
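
# End-to-end shape contract of CRNN.forward, assuming an illustrative
# 1 x 32 x 100 grayscale input: [b, 1, 32, 100] -> cnn -> [b, 512, 1, 26]
# -> squeeze/permute -> [26, b, 512] -> rnn + log_softmax -> [26, b, nclass],
# i.e. the sequence-first (T, N, C) layout that CTC-style losses consume.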

def weights_init(m):
    # DCGAN-style initialisation: N(0, 0.02) for conv weights,
    # N(1, 0.02) with zero bias for batch-norm layers.
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

def get_crnn(config):
    # nc=1 (grayscale input); nclass is NUM_CLASSES plus one extra label
    # (the CTC blank).
    model = CRNN(config.MODEL.IMAGE_SIZE.H, 1,
                 config.MODEL.NUM_CLASSES + 1, config.MODEL.NUM_HIDDEN)
    model.apply(weights_init)
    return model
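

# Minimal smoke test, as a sketch: the SimpleNamespace stands in for the
# project's real config object (the attribute names mirror what get_crnn
# reads; all values here are assumptions).
if __name__ == '__main__':
    import torch
    from types import SimpleNamespace

    cfg = SimpleNamespace(MODEL=SimpleNamespace(
        IMAGE_SIZE=SimpleNamespace(H=32),
        NUM_CLASSES=36,   # e.g. 10 digits + 26 letters; blank is added inside
        NUM_HIDDEN=256))
    model = get_crnn(cfg)

    x = torch.randn(4, 1, 32, 100)  # [b, nc, imgH, W] grayscale batch
    print(model(x).shape)           # torch.Size([26, 4, 37]) when W == 100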