import torch.nn as nn
from torchvision.models import resnet18, resnet34, resnet50, resnet101


class crnn_ctc_equation_torch6(nn.Module):
    """Convolutional-recurrent network emitting per-column log-probabilities.

    Pipeline: truncated ResNet-50 feature extractor -> linear projection to
    256 features per column -> single-layer GRU over the column axis ->
    linear classifier -> log-softmax over the class axis.

    Args:
        class_num: Size of the output class axis (width of ``fc2``).
    """

    def __init__(self, class_num):
        super(crnn_ctc_equation_torch6, self).__init__()
        self.class_num = class_num

        # Randomly initialised backbone. Calling resnet50() without arguments
        # is equivalent to the deprecated ``pretrained=False`` and works on
        # both old and new torchvision releases.
        backbone = resnet50()
        # Drop the last three children of the backbone; only the remaining
        # leading stages are used as the feature extractor.
        self.resnet = nn.Sequential(*list(backbone.children())[:-3])

        # fc1 input width depends on the backbone (per the original notes:
        # resnet18 -> 512, resnet50 -> 2048, resnet101 -> 2048). It must equal
        # channels * feature-map height after the truncated backbone.
        # NOTE(review): for resnet50 this presumably means 1024 channels and a
        # feature-map height of 2 (e.g. 32-pixel-high inputs) - confirm against
        # the data pipeline.
        self.fc1 = nn.Linear(2048, 256)
        self.fc2 = nn.Linear(256, self.class_num)

        # forward() hands the GRU a (batch, width, features) tensor, so the
        # GRU must be batch-first. With the default (batch_first=False) it
        # would recur across the samples of a batch instead of along the
        # width axis, making predictions depend on batch composition.
        self.gru1 = nn.GRU(input_size=256, hidden_size=256, batch_first=True)

        self.dropout = nn.Dropout(0.5)
        self.relu = nn.LeakyReLU()
        self.log_softmax = nn.LogSoftmax(2)

    def forward(self, x):
        """Compute per-column class log-probabilities.

        Args:
            x: 4-D tensor whose last axis is moved to position 1 before the
                backbone, i.e. a channels-last layout
                ``(batch, height, width, channels)`` is expected.

        Returns:
            Tensor of shape ``(batch, width', class_num)`` holding
            log-probabilities normalised over the last axis, where ``width'``
            is the feature-map width produced by the backbone.
        """
        # Channels-last -> channels-first for the convolutional backbone.
        x = x.permute(0, 3, 1, 2).contiguous()
        x = self.resnet(x)

        # (batch, channels, h, w) -> (batch, w, channels, h), then merge
        # channels and height so each column becomes one feature vector.
        x = x.permute(0, 3, 1, 2).contiguous()
        x = x.view(x.shape[0], x.shape[1], -1)

        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout(x)

        output, _ = self.gru1(x)

        x = self.fc2(output)
        # NOTE(review): dropout is applied to the logits themselves; kept
        # as-is to preserve training behaviour - confirm this is intended.
        x = self.dropout(x)

        # The result stays batch-first. A consumer that needs a time-first
        # (width', batch, class_num) layout must permute(1, 0, 2) itself.
        x = self.log_softmax(x)
        return x