model_torch.py 1.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. import torch.nn as nn
  2. from torchvision.models import resnet18, resnet34, resnet50, resnet101
  3. class crnn_ctc_equation_torch6(nn.Module):
  4. def __init__(self, class_num):
  5. super(crnn_ctc_equation_torch6, self).__init__()
  6. self.class_num = class_num
  7. # 18 - 512
  8. # 50 - 2048
  9. # 101 - 2048
  10. resnet = resnet50(pretrained=False)
  11. modules = list(resnet.children())[:-3]
  12. self.resnet = nn.Sequential(*modules)
  13. self.fc1 = nn.Linear(2048, 256)
  14. self.fc2 = nn.Linear(256, self.class_num)
  15. self.gru1 = nn.GRU(input_size=256, hidden_size=256)
  16. self.dropout = nn.Dropout(0.5)
  17. self.relu = nn.LeakyReLU()
  18. self.log_softmax = nn.LogSoftmax(2)
  19. def forward(self, x):
  20. x = x.permute(0, 3, 1, 2).contiguous()
  21. x = self.resnet(x)
  22. # print('x0', x.shape)
  23. x = x.permute(0, 3, 1, 2).contiguous()
  24. # print('x1', x.shape)
  25. x = x.view(x.shape[0], x.shape[1], -1)
  26. # print('x2', x.shape)
  27. x = self.fc1(x)
  28. x = self.relu(x)
  29. x = self.dropout(x)
  30. # print('x3', x.shape)
  31. output, _ = self.gru1(x)
  32. # print('x4', x.shape)
  33. x = self.fc2(output)
  34. # x = self.relu(x)
  35. x = self.dropout(x)
  36. # print('x5', x.shape)
  37. # x = x.permute(1, 0, 2)
  38. x = self.log_softmax(x)
  39. return x