@@ -73,6 +73,100 @@ class TableHeadModel(nn.Module):
         cnn3d_x = torch.permute(cnn3d_x, [2, 3, 1, 0])
         cnn3d_x = cnn3d_x.contiguous().view(row, col, char_num * self.char_embed_expand)
 
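+        # cnn3d_x: [row, col, char_num * char_embed_expand], one flat feature vector per cell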
+        # dnn
+        x = self.dense3(cnn3d_x)
+        x = self.ln_dnn_2(x)
+        x = self.relu(x)
+        x = self.dense4(x)
+        x = self.sigmoid(x)
+        x = torch.squeeze(x, -1)
+        return x
+
+
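+# Variant of TableHeadModel: a Transformer encoder over each cell's characters
+# replaces the (commented-out) 1-D convs; the 3-D conv stack and dnn head mirror
+# the original model above.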
+class TableHeadModel2(nn.Module):
+    def __init__(self):
+        super(TableHeadModel2, self).__init__()
+        self.char_num = 20
+        self.char_embed = 60
+        self.char_embed_expand = 128
+
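+        # expand each 60-dim char embedding to char_embed_expand dims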
+        self.dense0 = nn.Linear(self.char_embed, self.char_embed_expand)
+
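+        # per-cell head: flattened char features -> 64 -> 1 score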
+        self.dense3 = nn.Linear(self.char_num * self.char_embed_expand, 64)
+        self.dense4 = nn.Linear(64, 1)
+
+        self.sigmoid = nn.Sigmoid()
+
+        self.ln_dnn_2 = nn.LayerNorm([64])
+
+        self.device = torch.device("cpu")
+
+        self.relu = nn.LeakyReLU()
+        self.dropout = nn.Dropout(0.6)
+
+        # self.cnn1d_0 = nn.Conv1d(self.char_embed_expand,
+        #                          self.char_embed_expand,
+        #                          (3,), padding=self.get_padding(3))
+        # self.cnn1d_1 = nn.Conv1d(self.char_embed_expand,
+        #                          self.char_embed_expand,
+        #                          (3,), padding=self.get_padding(3))
+
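+        # 2 encoder layers attending over the char_num tokens of each cell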
+        encoder_layer1 = nn.TransformerEncoderLayer(d_model=self.char_embed_expand, nhead=2,
+                                                    dim_feedforward=128, batch_first=True)
+        self.transformer1 = nn.TransformerEncoder(encoder_layer1, 2)
+        self.ln_encoder_0 = nn.LayerNorm([self.char_embed_expand])
+
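+        # 3x3x3 convs over the (char_num, row, col) volume with 'same' padding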
+        self.cnn3d_0 = nn.Conv3d(self.char_embed_expand, self.char_embed_expand,
+                                 (3, 3, 3), padding=self.get_padding(3))
+        self.cnn3d_1 = nn.Conv3d(self.char_embed_expand, self.char_embed_expand,
+                                 (3, 3, 3), padding=self.get_padding(3))
+        # self.cnn3d_2 = nn.Conv3d(self.char_embed, self.char_embed,
+        #                          (3, 3, 3), padding=self.get_padding(3))
+
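+    # padding that keeps the spatial size unchanged for odd kernels at stride 1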
+    def get_padding(self, kernel_size, stride=1):
+        return (kernel_size - 1) // 2 * stride
+
+    def forward(self, x):
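+        # x: [batch, row, col, char_num, char_embed]; batch is assumed to be 1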
+        batch, row, col, char_num, char_embed = x.shape
+
+        # Embedding
+        x = torch.squeeze(x, 0)
+        x = x.view([row*col, char_num, char_embed])
+        x = self.dense0(x)
+
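+        # the row*col cells act as the batch, each cell's chars as the sequence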
+        # transformer
+        box_attention = self.transformer1(x)
+        box_attention = self.ln_encoder_0(box_attention)
+        # reshape back to the grid: [row*col, char_num, 128] -> [1, row, col, char_num, 128]
+        box_attention = box_attention.contiguous().view(row, col, char_num, self.char_embed_expand)
+        box_attention = torch.unsqueeze(box_attention, 0)
+
+        # cnn1d_x = torch.permute(cnn1d_x, [0, 2, 1])
+        # cnn1d_x = self.cnn1d_0(cnn1d_x)
+        # cnn1d_x = self.relu(cnn1d_x)
+        # cnn1d_x = self.dropout(cnn1d_x)
+        # cnn1d_x = self.cnn1d_1(cnn1d_x)
+        # cnn1d_x = self.relu(cnn1d_x)
+        # cnn1d_x = self.dropout(cnn1d_x)
+        #
+        # cnn1d_x = torch.permute(cnn1d_x, [0, 2, 1])
+        # cnn1d_x = cnn1d_x.contiguous().view(row, col, char_num, self.char_embed_expand)
+        # cnn1d_x = torch.unsqueeze(cnn1d_x, 0)
+        # print(cnn1d_x.shape)
+
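+        # Conv3d expects [N, C, D, H, W]: channels = embedding, depth = chars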
+        # cnn 3d
+        cnn3d_x = torch.permute(box_attention, [0, 4, 3, 1, 2])
+        cnn3d_x = self.cnn3d_0(cnn3d_x)
+        cnn3d_x = self.relu(cnn3d_x)
+        cnn3d_x = self.dropout(cnn3d_x)
+        cnn3d_x = self.cnn3d_1(cnn3d_x)
+        cnn3d_x = self.relu(cnn3d_x)
+        cnn3d_x = self.dropout(cnn3d_x)
+
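+        # collapse back to one flat feature vector per cell for the dnn head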
+        cnn3d_x = torch.squeeze(cnn3d_x, 0)
+        cnn3d_x = torch.permute(cnn3d_x, [2, 3, 1, 0])
+        cnn3d_x = cnn3d_x.contiguous().view(row, col, char_num * self.char_embed_expand)
+
         # dnn
         x = self.dense3(cnn3d_x)
         x = self.ln_dnn_2(x)