|
@@ -0,0 +1,132 @@
|
|
|
+#coding=utf-8
|
|
|
+import os
|
|
|
+import sys
|
|
|
+import numpy as np
|
|
|
+import torch
|
|
|
+from torch.utils.data import Dataset
|
|
|
+
|
|
|
+sys.path.append(os.path.abspath(os.path.dirname(__file__) + "/../../../"))
|
|
|
+from BiddingKG.dl.common.Utils import embedding_word, embedding_word_forward
|
|
|
+
|
|
|
+
|
|
|
+def set_label(row, row_label):
|
|
|
+ if len(row) == 1:
|
|
|
+ row_label = [0 for x in row]
|
|
|
+ elif len(set(row)) == 1:
|
|
|
+ row_label = [0 for x in row]
|
|
|
+ else:
|
|
|
+ row_label = [0 if x in ["", " ", "/", '无', '-'] else row_label[i] for i, x in enumerate(row)]
|
|
|
+ return row_label
|
|
|
+
|
|
|
+
|
|
|
+def set_same_table_head(inputs, y_pred1):
|
|
|
+ inputs = torch.squeeze(inputs, 0)
|
|
|
+
|
|
|
+ for i in range(inputs.shape[0]):
|
|
|
+ for j in range(inputs.shape[1]-1):
|
|
|
+ col1 = inputs[i, j, :, :]
|
|
|
+ col2 = inputs[i, j+1, :, :]
|
|
|
+ if (torch.abs(col1 - col2) < 1e-4).all():
|
|
|
+ # print('same value', col1[abs(col1) > 0.], col2[abs(col1) > 0.])
|
|
|
+ if (y_pred1[i, j] <= 0.5 and y_pred1[i, j+1] <= 0.5) or (y_pred1[i, j] > 0.5 and y_pred1[i, j+1] > 0.5):
|
|
|
+ continue
|
|
|
+ else:
|
|
|
+ # print('differ label', y_pred[i, j], y_pred[i, j+1])
|
|
|
+ y_pred1[i, j+1] = y_pred1[i, j]
|
|
|
+
|
|
|
+ for i in range(inputs.shape[1]):
|
|
|
+ for j in range(inputs.shape[0]-1):
|
|
|
+ row1 = inputs[j, i, :, :]
|
|
|
+ row2 = inputs[j+1, i, :, :]
|
|
|
+ if (torch.abs(row1 - row2) < 1e-4).all():
|
|
|
+ if (y_pred1[j, i] <= 0.5 and y_pred1[j+1, i] <= 0.5) or (y_pred1[j, i] > 0.5 and y_pred1[j+1, i] > 0.5):
|
|
|
+ continue
|
|
|
+ else:
|
|
|
+ # print('same value', row1[abs(row1) > 0.], row2[abs(row2) > 0.])
|
|
|
+ # print('differ label', y_pred[i, j], y_pred[i, j+1])
|
|
|
+ # print('before', x11[0, j, i], x11[0, j+1, i])
|
|
|
+ y_pred1[j+1, i] = y_pred1[j, i]
|
|
|
+ # print('after', x1[0, j, i], x1[0, j+1, i])
|
|
|
+ return y_pred1
|
|
|
+
|
|
|
+
|
|
|
+def data_to_numpy29(data_list, data_label_list):
|
|
|
+ """
|
|
|
+ 输出表格 (table_cnt, row, col, 20, 60)
|
|
|
+
|
|
|
+ :param data_list:
|
|
|
+ :param data_label_list:
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ data_num = len(data_list)
|
|
|
+
|
|
|
+ new_data_list = []
|
|
|
+ new_label_list = []
|
|
|
+ mask_list = []
|
|
|
+ for i in range(len(data_list)):
|
|
|
+ table = data_list[i]
|
|
|
+ table_label = []
|
|
|
+ if data_label_list:
|
|
|
+ table_label = data_label_list[i]
|
|
|
+ embed_list = []
|
|
|
+ label_list = []
|
|
|
+ mask = []
|
|
|
+ for j in range(len(table)):
|
|
|
+ row = table[j]
|
|
|
+ blank_list = [0 if x in ["", " ", "/"] else 1 for x in row]
|
|
|
+ mask.append(blank_list)
|
|
|
+ row = embedding_word_forward(row, shape=(len(row), 20, 60))
|
|
|
+ embed_list.append(row)
|
|
|
+ if data_label_list:
|
|
|
+ row_label = table_label[j]
|
|
|
+ # print(j, row_label)
|
|
|
+ row_label = [int(x) for x in row_label]
|
|
|
+ row_label = set_label(table[j], row_label)
|
|
|
+ label_list.append(row_label)
|
|
|
+ embed_list = np.array(embed_list, dtype=np.float32)
|
|
|
+ label_list = np.array(label_list, dtype=np.float32)
|
|
|
+ mask = np.array(mask, dtype=np.float32)
|
|
|
+ # print('embed_list.shape', embed_list.shape)
|
|
|
+ # print('label_list.shape', label_list.shape)
|
|
|
+ new_data_list.append(embed_list)
|
|
|
+ new_label_list.append(label_list)
|
|
|
+ mask_list.append(mask)
|
|
|
+
|
|
|
+ new_data_list = np.array(new_data_list, dtype=np.float32)
|
|
|
+ new_label_list = np.array(new_label_list, dtype=np.float32)
|
|
|
+ mask_list = np.array(mask_list, dtype=np.float32)
|
|
|
+ # print(new_data_list.shape)
|
|
|
+
|
|
|
+ return new_data_list, new_label_list, mask_list
|
|
|
+
|
|
|
+
|
|
|
+class CustomDatasetTiny40(Dataset):
|
|
|
+ def __init__(self, data_x, data_y, mode=0):
|
|
|
+ if mode in [0, 1]:
|
|
|
+ # Split -> Train, Test
|
|
|
+ split_size = int(len(data_x)*0.1)
|
|
|
+ test_x, test_y = data_x[:split_size], data_y[:split_size]
|
|
|
+ train_x, train_y = data_x[split_size:], data_y[split_size:]
|
|
|
+
|
|
|
+ if mode == 0:
|
|
|
+ self.data = train_x
|
|
|
+ self.targets = train_y
|
|
|
+ else:
|
|
|
+ self.data = test_x
|
|
|
+ self.targets = test_y
|
|
|
+ else:
|
|
|
+ pass
|
|
|
+
|
|
|
+ # self.data = data
|
|
|
+ # self.targets = targets
|
|
|
+
|
|
|
+ def __len__(self):
|
|
|
+ return len(self.data)
|
|
|
+
|
|
|
+ def __getitem__(self, idx):
|
|
|
+ # x, y = data_to_numpy12([self.data[idx]], [self.targets[idx]])
|
|
|
+ x, y, mask = data_to_numpy29([self.data[idx]], [self.targets[idx]])
|
|
|
+ x = x[0]
|
|
|
+ y = y[0]
|
|
|
+ mask = mask[0]
|
|
|
+ return x, y, mask
|