|
@@ -1,8 +1,10 @@
|
|
|
|
+import os
|
|
import random
|
|
import random
|
|
-
|
|
|
|
|
|
+import sys
|
|
import psycopg2
|
|
import psycopg2
|
|
import numpy as np
|
|
import numpy as np
|
|
-from BiddingKG.dl.common.Utils import embedding_word
|
|
|
|
|
|
+sys.path.append(os.path.dirname(__file__) + "/../")
|
|
|
|
+from common.Utils import embedding_word, embedding_word_forward
|
|
|
|
|
|
|
|
|
|
def get_sentence_index_list(sentence, dict_path='utils/ppocr_keys_v1.txt'):
|
|
def get_sentence_index_list(sentence, dict_path='utils/ppocr_keys_v1.txt'):
|
|
@@ -42,22 +44,28 @@ def postgresql_util(sql, limit):
|
|
return all_rows
|
|
return all_rows
|
|
|
|
|
|
|
|
|
|
-def get_data_from_sql(dim=10):
|
|
|
|
|
|
+def get_data_from_sql(dim=10, whole_table=False, padding=True):
|
|
|
|
+ sql = """
|
|
|
|
+ select table_text, pre_label, post_label, id
|
|
|
|
+ from label_table_head_info
|
|
|
|
+ where status = 0 and (update_user='test9' or update_user='test1' or update_user='test7' or update_user='test26')
|
|
|
|
+ ;
|
|
|
|
+ """
|
|
# sql = """
|
|
# sql = """
|
|
# select table_text, pre_label, post_label, id
|
|
# select table_text, pre_label, post_label, id
|
|
# from label_table_head_info
|
|
# from label_table_head_info
|
|
- # where update_user <> 'test27' and update_user <> 'test20' and table_box_cnt >= 4 and table_box_cnt <= 200
|
|
|
|
|
|
+ # where status = 1 and update_time >= '2022-01-17' and update_time <= '2022-01-22'
|
|
# ;
|
|
# ;
|
|
# """
|
|
# """
|
|
- sql = """
|
|
|
|
- select table_text, pre_label, post_label, id
|
|
|
|
- from label_table_head_info
|
|
|
|
- where status = 1 and update_time >= '2022-01-17'
|
|
|
|
- ;
|
|
|
|
- """
|
|
|
|
|
|
|
|
result_list = postgresql_util(sql, limit=1000000)
|
|
result_list = postgresql_util(sql, limit=1000000)
|
|
|
|
|
|
|
|
+ # 需排除的id
|
|
|
|
+ with open(r"C:\Users\Administrator\Desktop\table_not_eval.txt", "r") as f:
|
|
|
|
+ delete_id_list = eval(f.read())
|
|
|
|
+ with open(r"C:\Users\Administrator\Desktop\table_delete.txt", "r") as f:
|
|
|
|
+ delete_id_list += eval(f.read())
|
|
|
|
+
|
|
all_data_list = []
|
|
all_data_list = []
|
|
all_data_label_list = []
|
|
all_data_label_list = []
|
|
i = 0
|
|
i = 0
|
|
@@ -71,6 +79,10 @@ def get_data_from_sql(dim=10):
|
|
post_label = eval(table[2])
|
|
post_label = eval(table[2])
|
|
_id = table[3]
|
|
_id = table[3]
|
|
|
|
|
|
|
|
+ if _id in delete_id_list:
|
|
|
|
+ print("pass", _id)
|
|
|
|
+ continue
|
|
|
|
+
|
|
# table_text需要特殊处理
|
|
# table_text需要特殊处理
|
|
try:
|
|
try:
|
|
table_text = table[0]
|
|
table_text = table[0]
|
|
@@ -84,17 +96,35 @@ def get_data_from_sql(dim=10):
|
|
print("无法识别table_text", _id)
|
|
print("无法识别table_text", _id)
|
|
continue
|
|
continue
|
|
|
|
|
|
- # 只有一行的也不要
|
|
|
|
- if len(post_label) >= 2:
|
|
|
|
- data_list, data_label_list = table_pre_process(table_text, post_label, _id)
|
|
|
|
- elif len(pre_label) >= 2:
|
|
|
|
- data_list, data_label_list = table_pre_process(table_text, pre_label, _id)
|
|
|
|
|
|
+ if whole_table:
|
|
|
|
+ if len(post_label) >= 2:
|
|
|
|
+ data_list, data_label_list = table_pre_process_2(table_text, post_label,
|
|
|
|
+ _id, padding=padding)
|
|
|
|
+ elif len(pre_label) >= 2:
|
|
|
|
+ data_list, data_label_list = table_pre_process_2(table_text, pre_label,
|
|
|
|
+ _id, padding=padding)
|
|
|
|
+ else:
|
|
|
|
+ data_list, data_label_list = [], []
|
|
else:
|
|
else:
|
|
- data_list, data_label_list = [], []
|
|
|
|
|
|
+ # 只有一行的也不要
|
|
|
|
+ if len(post_label) >= 2:
|
|
|
|
+ data_list, data_label_list = table_pre_process(table_text, post_label, _id)
|
|
|
|
+ elif len(pre_label) >= 2:
|
|
|
|
+ data_list, data_label_list = table_pre_process(table_text, pre_label, _id)
|
|
|
|
+ else:
|
|
|
|
+ data_list, data_label_list = [], []
|
|
|
|
|
|
all_data_list += data_list
|
|
all_data_list += data_list
|
|
all_data_label_list += data_label_list
|
|
all_data_label_list += data_label_list
|
|
|
|
|
|
|
|
+ # 按维度大小排序
|
|
|
|
+ if whole_table:
|
|
|
|
+ _list = []
|
|
|
|
+ for data, label in zip(all_data_list, all_data_label_list):
|
|
|
|
+ _list.append([data, label])
|
|
|
|
+ _list.sort(key=lambda x: (len(x[0]), len(x[0][0])))
|
|
|
|
+ all_data_list[:], all_data_label_list[:] = zip(*_list)
|
|
|
|
+
|
|
print("len(all_data_list)", len(all_data_list))
|
|
print("len(all_data_list)", len(all_data_list))
|
|
return all_data_list, all_data_label_list
|
|
return all_data_list, all_data_label_list
|
|
|
|
|
|
@@ -206,7 +236,84 @@ def table_pre_process(text_list, label_list, _id, is_train=True):
|
|
return data_list
|
|
return data_list
|
|
|
|
|
|
|
|
|
|
-def get_data_from_file(file_type):
|
|
|
|
|
|
def table_pre_process_2(text_list, label_list, _id, is_train=True, padding=True):
    """
    Prepare one whole table as a single sample, optionally padding both axes.

    The row count and the column count are each rounded up with
    ``get_best_padding_size(..., min_len=8)``. Text cells are padded with
    ``None`` and label cells with ``"0"``. When ``is_train`` is true every
    label is converted to ``int`` before returning.

    Note that ``text_list`` and ``label_list`` are modified in place; the
    returned lists wrap those same objects.

    :param text_list: table as a list of rows, each row a list of cells
    :param label_list: per-cell labels laid out like ``text_list``
                       (only read when ``is_train`` is true)
    :param _id: record id, only used in the diagnostic message
    :param is_train: return ``([table], [labels])`` when true, ``[table]`` otherwise
    :param padding: pad rows and columns up to the bucket size when true
    :return: ``([text_list], [label_list])`` or ``[text_list]``; empty list(s)
             when the table is empty or the row counts do not match
    """
    empty_result = ([], []) if is_train else []

    # An empty table has no row to measure, so there is nothing to build.
    if not text_list:
        return empty_result

    # Work out the target size of both axes.
    row_len = len(text_list)
    best_row_len = get_best_padding_size(row_len, min_len=8)
    # Size the columns from the widest row so ragged tables come out rectangular.
    widest_row = max(len(row) for row in text_list)
    best_col_len = get_best_padding_size(widest_row, min_len=8)
    # Defensive guard: bail out if the helper ever reports "no suitable size".
    if best_row_len is None or best_col_len is None:
        return empty_result

    if is_train and len(text_list) != len(label_list):
        print("文字单元格与标注单元格数量不匹配!", _id)
        print("len(text_list)", len(text_list), "len(label_list)", len(label_list))
        return [], []

    if padding:
        # Widen every existing row (a negative difference pads nothing).
        for i, row in enumerate(text_list):
            row.extend([None] * (best_col_len - len(row)))
            if is_train:
                label_row = label_list[i]
                label_row.extend(["0"] * (best_col_len - len(label_row)))

        # Append the missing rows; each one is a distinct list so that later
        # in-place edits of one padded row cannot leak into the others.
        missing_rows = best_row_len - row_len
        text_list.extend([None] * best_col_len for _ in range(missing_rows))
        if is_train:
            label_list.extend(["0"] * best_col_len for _ in range(missing_rows))

    if is_train:
        for label_row in label_list:
            label_row[:] = [int(label) for label in label_row]
        return [text_list], [label_list]
    return [text_list]
|
|
|
|
+
|
|
|
|
+
|
|
|
|
def get_best_padding_size(axis_len, min_len=3, max_len=300):
    """
    Round ``axis_len`` up to the nearest padding bucket.

    Buckets are the positive multiples of ``min_len`` that do not exceed
    ``max_len`` (3, 6, 9, ... 300 with the defaults).

    :param axis_len: current length of the axis
    :param min_len: bucket step; a value <= 0 means "no padding"
    :param max_len: upper bound for the bucket sizes
    :return: the smallest bucket that is >= ``axis_len``; ``axis_len`` itself
             when it is larger than every bucket or when no bucket exists
    """
    # No usable bucket at all: leave the axis untouched instead of failing.
    if min_len <= 0 or min_len > max_len:
        return axis_len

    largest_bucket = (max_len // min_len) * min_len
    if axis_len > largest_bucket:
        return axis_len

    # Ceiling division; at least one bucket so that a length of 0 still maps
    # to the first bucket.
    bucket_count = max(1, -(-axis_len // min_len))
    return bucket_count * min_len
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def get_data_from_file(file_type, model_id=1):
|
|
if file_type == 'np':
|
|
if file_type == 'np':
|
|
data_path = 'train_data/data_3.npy'
|
|
data_path = 'train_data/data_3.npy'
|
|
data_label_path = 'train_data/data_label_3.npy'
|
|
data_label_path = 'train_data/data_label_3.npy'
|
|
@@ -215,17 +322,20 @@ def get_data_from_file(file_type):
|
|
array2 = np.load(data_label_path)
|
|
array2 = np.load(data_label_path)
|
|
return array1, array2
|
|
return array1, array2
|
|
elif file_type == 'txt':
|
|
elif file_type == 'txt':
|
|
- data_path = 'train_data/data3.txt'
|
|
|
|
- data_label_path = 'train_data/data_label3.txt'
|
|
|
|
-
|
|
|
|
|
|
+ if model_id == 1:
|
|
|
|
+ data_path = 'train_data/data1.txt'
|
|
|
|
+ data_label_path = 'train_data/data_label1.txt'
|
|
|
|
+ elif model_id == 2:
|
|
|
|
+ data_path = 'train_data/data2.txt'
|
|
|
|
+ data_label_path = 'train_data/data_label2.txt'
|
|
|
|
+ elif model_id == 3:
|
|
|
|
+ data_path = 'train_data/data3.txt'
|
|
|
|
+ data_label_path = 'train_data/data_label3.txt'
|
|
with open(data_path, 'r') as f:
|
|
with open(data_path, 'r') as f:
|
|
data_list = f.readlines()
|
|
data_list = f.readlines()
|
|
with open(data_label_path, 'r') as f:
|
|
with open(data_label_path, 'r') as f:
|
|
data_label_list = f.readlines()
|
|
data_label_list = f.readlines()
|
|
|
|
|
|
- # for i in range(len(data_list)):
|
|
|
|
- # data_list[i] = eval(data_list[i][:-1])
|
|
|
|
- # data_label_list[i] = eval(data_label_list[i][:-1])
|
|
|
|
return data_list, data_label_list
|
|
return data_list, data_label_list
|
|
else:
|
|
else:
|
|
print("file type error! only np and txt supported")
|
|
print("file type error! only np and txt supported")
|
|
@@ -245,18 +355,19 @@ def processed_save_to_np():
|
|
# f.write(str(line) + "\n")
|
|
# f.write(str(line) + "\n")
|
|
|
|
|
|
|
|
|
|
def processed_save_to_txt(whole_table=False, padding=True):
    """
    Load the samples via ``get_data_from_sql``, shuffle them and dump them to text files.

    One sample per line is written with ``str()`` to ``train_data/data1.txt``
    and the matching label to the same line of ``train_data/data_label1.txt``.
    An empty result set produces two empty files.

    :param whole_table: forwarded to ``get_data_from_sql``
    :param padding: forwarded to ``get_data_from_sql``
    """
    list1, list2 = get_data_from_sql(whole_table=whole_table, padding=padding)

    # Shuffle samples and labels as pairs so that they stay aligned.
    pairs = list(zip(list1, list2))
    random.shuffle(pairs)

    with open('train_data/data1.txt', 'w') as f:
        f.writelines(str(sample) + "\n" for sample, _ in pairs)
    with open('train_data/data_label1.txt', 'w') as f:
        f.writelines(str(label) + "\n" for _, label in pairs)
|
|
|
|
|
|
@@ -287,7 +398,7 @@ def my_data_loader(data_list, data_label_list, batch_size, is_train=True):
|
|
data_num = len(data_list)
|
|
data_num = len(data_list)
|
|
|
|
|
|
# 定义Embedding输出
|
|
# 定义Embedding输出
|
|
- output_shape = (6, 10, 60)
|
|
|
|
|
|
+ output_shape = (6, 20, 60)
|
|
|
|
|
|
# batch循环取数据
|
|
# batch循环取数据
|
|
i = 0
|
|
i = 0
|
|
@@ -349,8 +460,109 @@ def my_data_loader(data_list, data_label_list, batch_size, is_train=True):
|
|
'input_4': X[3], 'input_5': X[4], 'input_6': X[5], }
|
|
'input_4': X[3], 'input_5': X[4], 'input_6': X[5], }
|
|
|
|
|
|
|
|
|
|
|
|
def my_data_loader_2(table_list, table_label_list, batch_size, is_train=True):
    """
    Endless batch generator for whole-table samples.

    Each element of ``table_list`` is either a table (list of rows) or its
    ``str()`` form as read from a text file; the same holds for
    ``table_label_list``. Every row is converted with
    ``embedding_word_forward``. When ``batch_size`` > 1, consecutive tables are
    grouped into one batch only while they share the same (rows, columns) shape.

    :param table_list: tables, as lists or as serialized text lines
    :param table_label_list: labels aligned with ``table_list``
                             (only read when ``is_train`` is true)
    :param batch_size: maximum number of tables per batch
    :param is_train: when true the data is reshuffled with ``get_random`` and
                     ``(inputs, {"output": Y})`` is yielded; otherwise only
                     the inputs dict is yielded
    :return: generator of ``{"input_1": X, "input_2": zeros}`` dicts
             (paired with the labels when training)
    """
    pad_len = 0

    table_num = len(table_list)
    if is_train and batch_size == 1:
        table_list, table_label_list = get_random(table_list, table_label_list)

    # Expected trailing two dimensions of the embedded batch.
    output_shape = (20, 60)

    # Loop forever, yielding one batch per iteration.
    i = 0
    last_shape = None
    while True:
        new_table_list = []
        new_table_label_list = []
        for _ in range(batch_size):
            if i >= table_num:
                # End of the data: start over, reshuffling when training.
                i = 0
                if is_train:
                    table_list, table_label_list = get_random(table_list, table_label_list,
                                                              seed=random.randint(1, 40))

            # NOTE(review): eval() executes whatever the line contains; it is
            # only acceptable because the files are produced by this project.
            raw_table = table_list[i]
            table = raw_table if isinstance(raw_table, list) else eval(raw_table.strip())

            if batch_size > 1:
                shape = (len(table), len(table[0]))
                if last_shape is None:
                    last_shape = shape
                elif shape != last_shape:
                    # Close the current batch; this table opens the next one.
                    last_shape = shape
                    break

            if is_train:
                raw_label = table_label_list[i]
                table_label = raw_label if isinstance(raw_label, list) else eval(raw_label.strip())

            # Map the characters of every row to embeddings. A new list is
            # built so that tables passed in as lists are not overwritten.
            embedded_table = [
                embedding_word_forward(row, (len(row), output_shape[0], output_shape[1]))
                for row in table
            ]
            new_table_list.append(embedded_table)
            if is_train:
                new_table_label_list.append(table_label)
            i += 1

        X = np.array(new_table_list)
        # Skip batches whose embedding shape is unexpected; an empty batch
        # (shape change on its very first table) is skipped here as well.
        if X.shape[-2:] != output_shape:
            continue

        # Determine the padding size.
        pad_height = get_best_padding_size(X.shape[1], pad_len)
        pad_width = get_best_padding_size(X.shape[2], pad_len)
        # NOTE(review): the leading dimension is fixed to 1 even when the batch
        # holds several tables — confirm this is what the model expects.
        input_2 = np.zeros([1, X.shape[1], X.shape[2], pad_height, pad_width])

        if is_train:
            Y = np.array(new_table_label_list)
            yield {"input_1": X, "input_2": input_2}, \
                  {"output": Y}
        else:
            yield {"input_1": X, "input_2": input_2}
|
|
|
|
+
|
|
|
|
+
|
|
|
|
def check_train_data():
    """
    Print the sizes of every model-2 training table that is not a multiple of 8.

    For a table whose row count is off, the row count and the length of its
    first row are printed; for every row whose length is off, the table's row
    count and that row's length are printed.
    """
    block = 8
    serialized_tables, _ = get_data_from_file('txt', model_id=2)
    for serialized in serialized_tables:
        table = eval(serialized)
        row_count = len(table)
        if row_count % block:
            print(row_count)
            print(len(table[0]))
        for cells in table:
            if len(cells) % block:
                print(row_count)
                print(len(cells))
|
|
|
|
+
|
|
|
|
+
|
|
|
|
def get_random(text_list, label_list, seed=42):
    """
    Shuffle two parallel lists in place with the same permutation.

    The module-level ``random`` generator is seeded with ``seed`` first, so
    the permutation is reproducible for a given seed. Empty input is returned
    unchanged.

    :param text_list: samples; overwritten in place with the shuffled order
    :param label_list: labels aligned with ``text_list``; shuffled identically
    :param seed: seed passed to ``random.seed``
    :return: the same ``text_list`` and ``label_list`` objects
    """
    random.seed(seed)
    pairs = list(zip(text_list, label_list))
    # zip(*[]) yields nothing to unpack, so there is nothing to shuffle here.
    if not pairs:
        return text_list, label_list
    random.shuffle(pairs)
    text_list[:], label_list[:] = zip(*pairs)
    return text_list, label_list
|
|
|
|
+
|
|
|
|
+
|
|
if __name__ == '__main__':
|
|
if __name__ == '__main__':
|
|
- processed_save_to_txt()
|
|
|
|
|
|
+ processed_save_to_txt(whole_table=False, padding=False)
|
|
# data_balance()
|
|
# data_balance()
|
|
|
|
|
|
# test_embedding()
|
|
# test_embedding()
|
|
|
|
+ # check_train_data()
|
|
|
|
+
|
|
|
|
+ # _list = []
|
|
|
|
+ # for i in range(1, 100):
|
|
|
|
+ # _list.append(i*3)
|
|
|
|
+ # print(_list)
|
|
|
|
+
|
|
|
|
+ # print(get_best_padding_size(9, 5))
|