Ver código fonte

Merge remote-tracking branch 'origin/master'

lsm 9 meses atrás
pai
commit
1b996c453c

+ 1 - 1
BiddingKG/dl/interface/Preprocessing.py

@@ -8,7 +8,7 @@ import time
 import codecs
 
 from BiddingKG.dl.ratio.re_ratio import extract_ratio
-from BiddingKG.dl.table_head.predict import predict
+from BiddingKG.dl.table_head.predict_torch import predict
 
 sys.setrecursionlimit(1000000)
 sys.path.append(os.path.abspath("../.."))

BIN
BiddingKG/dl/table_head/model_40_0.959.pth


+ 83 - 0
BiddingKG/dl/table_head/models/model_torch.py

@@ -0,0 +1,83 @@
+import torch.nn as nn
+import torch
+
+
+class TableHeadModel(nn.Module):
+    """Score every cell of a table with a value in (0, 1).
+
+    The forward pass expands each character vector with a linear layer,
+    applies two 1-D convolutions along the characters of each cell, two 3-D
+    convolutions across (characters, rows, columns), and finally a small
+    dense head ending in a sigmoid.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.char_num = 20
+        self.char_embed = 60
+        self.char_embed_expand = 128
+
+        width = self.char_embed_expand
+        pad = self.get_padding(3)
+
+        # Sub-modules keep their original attribute names and creation order:
+        # the names are the state_dict keys, and the order decides how the
+        # random initialisation is consumed.
+        self.dense0 = nn.Linear(self.char_embed, width)
+
+        self.dense3 = nn.Linear(self.char_num * width, 64)
+        self.dense4 = nn.Linear(64, 1)
+
+        self.sigmoid = nn.Sigmoid()
+
+        self.ln_dnn_2 = nn.LayerNorm([64])
+
+        self.device = torch.device("cpu")
+
+        self.relu = nn.LeakyReLU()
+        self.dropout = nn.Dropout(0.3)
+
+        self.cnn1d_0 = nn.Conv1d(width, width, (3,), padding=pad)
+        self.cnn1d_1 = nn.Conv1d(width, width, (3,), padding=pad)
+
+        self.cnn3d_0 = nn.Conv3d(width, width, (3, 3, 3), padding=pad)
+        self.cnn3d_1 = nn.Conv3d(width, width, (3, 3, 3), padding=pad)
+
+    def get_padding(self, kernel_size, stride=1):
+        """Return ``(kernel_size - 1) // 2 * stride``, the padding used by the conv layers."""
+        half_kernel = (kernel_size - 1) // 2
+        return half_kernel * stride
+
+    def forward(self, x):
+        """Compute per-cell scores.
+
+        :param x: tensor unpacked as (batch, row, col, char_num, char_embed).
+            Dimension 0 is squeezed and the rest is viewed as ``row * col``
+            cells, so this only works when ``batch`` is 1.
+        :return: tensor of shape (row, col) holding sigmoid outputs.
+        """
+        _batch, row, col, char_num, char_embed = x.shape
+        width = self.char_embed_expand
+
+        # 1-D convolutions over the characters of every cell.
+        cells = x.squeeze(0).view(row * col, char_num, char_embed)
+        cells = self.dense0(cells)
+        cells = cells.permute(0, 2, 1)  # channels first for Conv1d
+        for conv in (self.cnn1d_0, self.cnn1d_1):
+            cells = self.dropout(self.relu(conv(cells)))
+        cells = cells.permute(0, 2, 1).contiguous()
+        grid = cells.view(row, col, char_num, width).unsqueeze(0)
+
+        # 3-D convolutions over (characters, rows, columns).
+        volume = grid.permute(0, 4, 3, 1, 2)
+        for conv in (self.cnn3d_0, self.cnn3d_1):
+            volume = self.dropout(self.relu(conv(volume)))
+        volume = volume.squeeze(0).permute(2, 3, 1, 0).contiguous()
+        features = volume.view(row, col, char_num * width)
+
+        # Dense head: one score per cell.
+        hidden = self.relu(self.ln_dnn_2(self.dense3(features)))
+        scores = self.sigmoid(self.dense4(hidden))
+        return scores.squeeze(-1)

+ 132 - 0
BiddingKG/dl/table_head/pre_process_torch.py

@@ -0,0 +1,132 @@
+#coding=utf-8
+import os
+import sys
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+
+sys.path.append(os.path.abspath(os.path.dirname(__file__) + "/../../../"))
+from BiddingKG.dl.common.Utils import embedding_word, embedding_word_forward
+
+
+def set_label(row, row_label):
+    """Return the labels of one table row with trivial cells forced to 0.
+
+    :param row: list of cell texts.
+    :param row_label: list of labels, one per cell of ``row``.
+    :return: a new list; all zeros when the row has a single cell or only one
+        distinct value, otherwise ``row_label`` with the entries of
+        placeholder cells replaced by 0.
+    """
+    # A one-cell row, or a row repeating a single value, is labelled all zeros.
+    if len(row) == 1 or len(set(row)) == 1:
+        return [0] * len(row)
+
+    placeholders = ["", " ", "/", '无', '-', '~~']
+    cleaned = []
+    for idx, cell in enumerate(row):
+        cleaned.append(0 if cell in placeholders else row_label[idx])
+    return cleaned
+
+
+def set_same_table_head(inputs, y_pred1):
+    """Make adjacent cells with identical embeddings share one prediction side.
+
+    Neighbouring cells are compared first along every row (left to right) and
+    then along every column (top to bottom). When two neighbours have the same
+    embedding (all element-wise differences below 1e-4) but their predictions
+    fall on different sides of 0.5, the later cell takes the earlier cell's
+    prediction. Updates are applied immediately, so they can cascade.
+
+    :param inputs: embedding tensor; dimension 0 is squeezed, after which it
+        is indexed as (row, col, ...).
+    :param y_pred1: per-cell predictions indexed as (row, col); modified in
+        place.
+    :return: the same ``y_pred1`` object.
+    """
+    cells = torch.squeeze(inputs, 0)
+    row_cnt, col_cnt = cells.shape[0], cells.shape[1]
+
+    def _propagate(src, dst):
+        # src / dst are (row, col) index pairs of two adjacent cells.
+        if not (torch.abs(cells[src] - cells[dst]) < 1e-4).all():
+            return
+        first, second = y_pred1[src], y_pred1[dst]
+        both_low = first <= 0.5 and second <= 0.5
+        both_high = first > 0.5 and second > 0.5
+        if not (both_low or both_high):
+            y_pred1[dst] = y_pred1[src]
+
+    # Horizontal neighbours.
+    for r in range(row_cnt):
+        for c in range(col_cnt - 1):
+            _propagate((r, c), (r, c + 1))
+
+    # Vertical neighbours.
+    for c in range(col_cnt):
+        for r in range(row_cnt - 1):
+            _propagate((r, c), (r + 1, c))
+
+    return y_pred1
+
+
+def data_to_numpy29(data_list, data_label_list):
+    """
+    Convert text tables (and optional labels) into float32 numpy arrays.
+
+    Output tables are intended to have shape (table_cnt, row, col, 20, 60);
+    each row is embedded with ``embedding_word_forward`` using
+    ``shape=(len(row), 20, 60)``.
+
+    :param data_list: list of tables, each a list of rows of cell texts.
+    :param data_label_list: matching labels per table/row/cell; when empty,
+        no labels are produced.
+    :return: (embeddings, labels, mask) as float32 arrays; mask is 0 for the
+        cells "", " " and "/" and 1 otherwise.
+    """
+    blank_cells = ["", " ", "/"]
+    tables_out, labels_out, masks_out = [], [], []
+
+    for table_idx, table in enumerate(data_list):
+        table_label = data_label_list[table_idx] if data_label_list else []
+        row_embeds, row_labels, row_masks = [], [], []
+
+        for row_idx, cells in enumerate(table):
+            row_masks.append([0 if cell in blank_cells else 1 for cell in cells])
+            row_embeds.append(embedding_word_forward(cells, shape=(len(cells), 20, 60)))
+            if data_label_list:
+                int_labels = [int(value) for value in table_label[row_idx]]
+                # Labels are cleaned against the original cell texts.
+                row_labels.append(set_label(cells, int_labels))
+
+        tables_out.append(np.array(row_embeds, dtype=np.float32))
+        labels_out.append(np.array(row_labels, dtype=np.float32))
+        masks_out.append(np.array(row_masks, dtype=np.float32))
+
+    embeddings = np.array(tables_out, dtype=np.float32)
+    labels = np.array(labels_out, dtype=np.float32)
+    masks = np.array(masks_out, dtype=np.float32)
+    return embeddings, labels, masks
+
+
+class CustomDatasetTiny40(Dataset):
+    """Dataset of text tables that embeds one table per item on access.
+
+    :param data_x: sequence of tables (lists of rows of cell texts).
+    :param data_y: sequence of label tables aligned with ``data_x``.
+    :param mode: 0 keeps the last 90% of the samples (train split), 1 keeps
+        the first 10% (test split); any other value keeps every sample.
+    """
+
+    def __init__(self, data_x, data_y, mode=0):
+        if mode in [0, 1]:
+            # The first 10% of the samples form the test split.
+            split_size = int(len(data_x) * 0.1)
+            if mode == 0:
+                self.data = data_x[split_size:]
+                self.targets = data_y[split_size:]
+            else:
+                self.data = data_x[:split_size]
+                self.targets = data_y[:split_size]
+        else:
+            # Previously this branch left data/targets unset, so len() and
+            # indexing raised AttributeError; use the whole input instead.
+            self.data = data_x
+            self.targets = data_y
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        """Return (embedding, labels, mask) arrays for the table at ``idx``."""
+        # data_to_numpy29 works on lists of tables; wrap the single sample and
+        # unwrap the leading table dimension again.
+        x, y, mask = data_to_numpy29([self.data[idx]], [self.targets[idx]])
+        return x[0], y[0], mask[0]

+ 68 - 0
BiddingKG/dl/table_head/predict_torch.py

@@ -0,0 +1,68 @@
+import copy
+import os
+import sys
+import torch
+from torch.utils.data import DataLoader
+
+sys.path.append(os.path.abspath(os.path.dirname(__file__) + "/../../../"))
+from BiddingKG.dl.table_head.models.model_torch import TableHeadModel
+from BiddingKG.dl.table_head.pre_process_torch import CustomDatasetTiny40, set_same_table_head, set_label
+
+# Inference is pinned to the CPU.
+device = torch.device("cpu")
+# The weights file is looked up in the directory that contains this module.
+model_path = os.path.abspath(os.path.dirname(__file__)) + '/model_40_0.959.pth'
+# Batch size handed to the DataLoader in predict(); TableHeadModel.forward
+# squeezes dimension 0 of its input, so it is fed one table chunk at a time.
+batch_size = 1
+
+
+def predict(table_text_list):
+    """Predict, for every cell of a table, whether it is a header cell.
+
+    :param table_text_list: table as a list of rows, each a list of cell
+        texts. The column count is taken from the first row.
+    :return: list of rows of labels (1 = header, 0 = not header) aligned with
+        the input. An empty input gives ``[]``; tables with no columns or
+        with 50 or more columns are returned as all zeros without running
+        the model.
+    """
+    if len(table_text_list) <= 0:
+        return []
+
+    data_x = copy.deepcopy(table_text_list)
+    data_y = [[0 for col in row] for row in data_x]
+
+    row_len = len(data_x)
+    col_len = len(data_x[0])
+
+    # Nothing to classify without columns; very wide tables are skipped.
+    # Both checks run before the model is touched so that trivial inputs do
+    # not trigger loading the weights.
+    if col_len == 0 or col_len >= 50:
+        return data_y
+
+    # Wider tables are processed in smaller row chunks.
+    if col_len >= 20:
+        batch_row_len = 50
+    else:
+        batch_row_len = 100
+
+    # The model is created lazily and cached in the module globals under the
+    # name "model".
+    model = globals().get("model")
+    if model is None:
+        print("="*15, "init table_head model", "="*15)
+        # Instantiate the model and load the trained weights.
+        model = TableHeadModel()
+        model.to(device)
+        model.load_state_dict(torch.load(model_path, map_location=device))
+        # Switch to evaluation mode.
+        model.eval()
+        globals()["model"] = model
+
+    result_list = []
+    for start in range(0, row_len, batch_row_len):
+        stop = start + batch_row_len
+        batch_data_x = data_x[start:stop]
+        # Pass the label rows that belong to this chunk, not the whole table.
+        batch_data_y = data_y[start:stop]
+        dataset = CustomDatasetTiny40([batch_data_x], [batch_data_y], mode=0)
+        data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
+        # Collect the predictions of this chunk.
+        with torch.no_grad():
+            for data, _targets, _mask in data_loader:
+                data = data.to(device)
+                outputs = model(data)
+                outputs = set_same_table_head(data, outputs)
+                result = torch.zeros_like(outputs)
+                result[outputs >= 0.5] = 1
+                result_list += result.numpy().tolist()
+
+    # Force the labels of trivial rows/cells (see set_label) to 0.
+    return [set_label(row, row_label)
+            for row, row_label in zip(table_text_list, result_list)]