import torch.nn as nn
from torchvision.models import resnet18, resnet34, resnet50, resnet101


class crnn_ctc_equation_torch6(nn.Module):
    def __init__(self, class_num):
        super(crnn_ctc_equation_torch6, self).__init__()
        self.class_num = class_num
        # Final channel count of the full backbones: resnet18/34 -> 512, resnet50/101 -> 2048.
        # Here the backbone is truncated after layer3 (1024 channels for resnet50), so with a
        # 32-pixel-high input the flattened feature size per step is 1024 * 2 = 2048.
        resnet = resnet50(pretrained=False)
        modules = list(resnet.children())[:-3]  # drop layer4, avgpool, and fc
        self.resnet = nn.Sequential(*modules)
        self.fc1 = nn.Linear(2048, 256)
        self.fc2 = nn.Linear(256, self.class_num)
        # batch_first=True so the GRU recurs over the width (time) axis, not the batch axis.
        self.gru1 = nn.GRU(input_size=256, hidden_size=256, batch_first=True)
        self.dropout = nn.Dropout(0.5)
        self.relu = nn.LeakyReLU()
        self.log_softmax = nn.LogSoftmax(dim=2)

    def forward(self, x):
        # (B, H, W, C) -> (B, C, H, W)
        x = x.permute(0, 3, 1, 2).contiguous()
        x = self.resnet(x)                       # (B, 1024, H/16, W/16)
        # Treat the width dimension as the sequence axis.
        x = x.permute(0, 3, 1, 2).contiguous()   # (B, W/16, 1024, H/16)
        x = x.view(x.shape[0], x.shape[1], -1)   # (B, W/16, 2048)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout(x)
        output, _ = self.gru1(x)                 # (B, W/16, 256)
        x = self.fc2(output)                     # (B, W/16, class_num)
        x = self.dropout(x)
        # nn.CTCLoss expects (T, N, C); permute(1, 0, 2) before computing the loss.
        # x = x.permute(1, 0, 2)
        x = self.log_softmax(x)
        return x
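As a quick sanity check, the sketch below instantiates the model and runs a dummy forward pass. The class count, batch size, and image width are illustrative assumptions, not values from the source; the input height must be 32 so that the flattened ResNet features match fc1's expected 2048 inputs.

import torch

model = crnn_ctc_equation_torch6(class_num=100)  # hypothetical vocabulary size
model.eval()

# Input is NHWC; height 32 gives 1024 channels * height 2 = 2048 features per step.
images = torch.randn(4, 32, 160, 3)  # (batch, height, width, channels)

with torch.no_grad():
    log_probs = model(images)

print(log_probs.shape)  # (batch, width/16, class_num) -> torch.Size([4, 10, 100])

# For nn.CTCLoss, permute to (T, N, C) = (sequence, batch, classes):
# log_probs = log_probs.permute(1, 0, 2)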