@@ -1,3 +1,5 @@
+import random
+
 import psycopg2
 import numpy as np
 from BiddingKG.dl.common.Utils import embedding_word
@@ -41,17 +43,19 @@ def postgresql_util(sql, limit):


 def get_data_from_sql(dim=10):
+    # sql = """
+    # select table_text, pre_label, post_label, id
+    # from label_table_head_info
+    # where update_user <> 'test27' and update_user <> 'test20' and table_box_cnt >= 4 and table_box_cnt <= 200
+    # ;
+    # """
     sql = """
     select table_text, pre_label, post_label, id
     from label_table_head_info
-    where update_user <> 'test27' and table_box_cnt >= 4 and table_box_cnt <= 200
+    where status = 1 and update_time >= '2022-01-17'
     ;
     """
-    # sql = """
-    # select table_text, pre_label, post_label, id
-    # from label_table_head_info
-    # where id = 843
-    # """
+
     result_list = postgresql_util(sql, limit=1000000)

     all_data_list = []
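Note: the hunk above only swaps the training query; `postgresql_util` itself is outside this diff. For context, a minimal sketch of what it presumably does, assuming a plain psycopg2 connection whose parameters are placeholders, not the project's real config:

    import psycopg2

    def postgresql_util(sql, limit):
        # placeholder connection settings; the real project config is not in this diff
        conn = psycopg2.connect(host="127.0.0.1", dbname="postgres",
                                user="postgres", password="postgres")
        try:
            with conn.cursor() as cursor:
                cursor.execute(sql)
                # fetchmany caps the number of returned rows at `limit`
                result_list = cursor.fetchmany(limit)
        finally:
            conn.close()
        return result_list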
@@ -92,50 +96,6 @@ def get_data_from_sql(dim=10):
         all_data_label_list += data_label_list

     print("len(all_data_list)", len(all_data_list))
-
-    #
-    # new_data_list = []
-    # for data in data_list:
-    #     # map Chinese characters to indexes
-    #     # data[0] = get_sentence_index_list(data[0])
-    #     # data[1] = get_sentence_index_list(data[1])
-    #     # pad with mask 0 when the dimension falls short
-    #     # if len(data[0]) < dim:
-    #     #     data[0] = data[0] + [0]*(dim-len(data[0]))
-    #     # elif len(data[0]) > dim:
-    #     #     data[0] = data[0][:dim]
-    #     # if len(data[1]) < dim:
-    #     #     data[1] = data[1] + [0]*(dim-len(data[1]))
-    #     # elif len(data[1]) > dim:
-    #     #     data[1] = data[1][:dim]
-    #
-    #     # map Chinese characters to embeddings
-    #     data = embedding_word(data, input_shape)
-    #     new_data_list.append(data)
-    #
-    # new_data_list = np.array(new_data_list)
-    # data_label_list = np.array(data_label_list)
-    # if np.array(new_data_list).shape[1:] == input_shape:
-    #     all_data_list.append(new_data_list)
-    #     all_data_label_list.append(data_label_list)
-
-    # # avoid concat becoming too slow
-    # split_len = 1000
-    # _len = int(len(all_data_list) / split_len)
-    # all_data_list_1 = []
-    # all_data_list_2 = []
-    # for i in range(_len):
-    #     if i == _len - 1:
-    #         array1 = np.concatenate(all_data_list[i*split_len:])
-    #         array2 = np.concatenate(all_data_label_list[i*split_len:])
-    #     else:
-    #         array1 = np.concatenate(all_data_list[i*split_len:i*split_len+split_len])
-    #         array2 = np.concatenate(all_data_label_list[i*split_len:i*split_len+split_len])
-    #     all_data_list_1.append(array1)
-    #     all_data_list_2.append(array2)
-    # all_data_list = np.concatenate(all_data_list_1)
-    # all_data_label_list = np.concatenate(all_data_list_2)
-
     return all_data_list, all_data_label_list


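Note: the deleted block was dead code, but it documented the shaping convention that still applies downstream: each cell string is padded or truncated to `dim` characters before embedding. A standalone sketch of that step (illustrative only; the live code delegates it to `embedding_word`):

    def pad_or_truncate(seq, dim=10, mask=0):
        # pad with the mask value when too short, cut off when too long
        if len(seq) < dim:
            return seq + [mask] * (dim - len(seq))
        return seq[:dim]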
@@ -191,49 +151,55 @@ def table_pre_process(text_list, label_list, _id, is_train=True):
                 if is_train:
                     last_col_label = row_label[j-1]
             else:
-                last_col = None
+                last_col = col
                 if is_train:
-                    last_col_label = 0
+                    last_col_label = col_label

             if j < len(row) - 1:
                 next_col = row[j+1]
                 if is_train:
                     next_col_label = row_label[j+1]
             else:
-                next_col = None
+                next_col = col
                 if is_train:
-                    next_col_label = 0
+                    next_col_label = col_label

             if last_row:
                 last_row_col = last_row[j]
                 if is_train:
                     last_row_col_label = last_row_label[j]
             else:
-                last_row_col = None
+                last_row_col = col
                 if is_train:
-                    last_row_col_label = 0
+                    last_row_col_label = col_label

             if next_row:
                 next_row_col = next_row[j]
                 if is_train:
                     next_row_col_label = next_row_label[j]
             else:
-                next_row_col = None
+                next_row_col = col
                 if is_train:
-                    next_row_col_label = 0
+                    next_row_col_label = col_label
+
+            # data_list.append([last_col, col, next_col])
+            # if is_train:
+            #     data_label_list.append([int(last_col_label), int(col_label),
+            #                             int(next_col_label)])
+            #
+            # data_list.append([last_row_col, col, next_row_col])
+            # if is_train:
+            #     data_label_list.append([int(last_row_col_label), int(col_label),
+            #                             int(next_row_col_label)])

-            # use the triple as a sample if any pair in it differs
-            # if col != next_col or col != last_col:
-            data_list.append([last_col, col, next_col])
             if is_train:
-                data_label_list.append([int(last_col_label), int(col_label),
-                                        int(next_col_label)])
-
-            # if col != next_row_col or col != last_row_col:
-            data_list.append([last_row_col, col, next_row_col])
-            if is_train:
-                data_label_list.append([int(last_row_col_label), int(col_label),
-                                        int(next_row_col_label)])
+                dup_list = [str(x) for x in data_list]
+                data = [last_col, col, next_col, last_row_col, col, next_row_col]
+                if str(data) not in dup_list:
+                    data_list.append(data)
+                    data_label_list.append(int(col_label))
+            else:
+                data_list.append([last_col, col, next_col, last_row_col, col, next_row_col])
     if is_train:
         return data_list, data_label_list
     else:
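Note: `dup_list` is rebuilt from `data_list` on every cell, which makes training-time dedup quadratic in the total number of cells. If that ever becomes a bottleneck, a set of seen keys kept across iterations gives O(1) membership tests; a sketch (`seen` and `append_unique` are new names, not part of the patch):

    def append_unique(data_list, data_label_list, data, col_label, seen):
        # `seen` holds stringified samples across all rows/columns of all tables
        key = str(data)
        if key not in seen:
            seen.add(key)
            data_list.append(data)
            data_label_list.append(int(col_label))

called per cell as:

    append_unique(data_list, data_label_list,
                  [last_col, col, next_col, last_row_col, col, next_row_col],
                  col_label, seen)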
@@ -249,8 +215,8 @@ def get_data_from_file(file_type):
         array2 = np.load(data_label_path)
         return array1, array2
     elif file_type == 'txt':
-        data_path = 'train_data/data.txt'
-        data_label_path = 'train_data/data_label.txt'
+        data_path = 'train_data/data3.txt'
+        data_label_path = 'train_data/data_label3.txt'

         with open(data_path, 'r') as f:
             data_list = f.readlines()
@@ -281,22 +247,27 @@ def processed_save_to_np():

 def processed_save_to_txt():
     list1, list2 = get_data_from_sql()
-    with open('train_data/data.txt', 'w') as f:
+
+    # shuffle the two lists in unison
+    zip_list = list(zip(list1, list2))
+    random.shuffle(zip_list)
+    list1[:], list2[:] = zip(*zip_list)
+
+    with open('train_data/data3.txt', 'w') as f:
         for line in list1:
             f.write(str(line) + "\n")
-    with open('train_data/data_label.txt', 'w') as f:
+    with open('train_data/data_label3.txt', 'w') as f:
         for line in list2:
             f.write(str(line) + "\n")


 def data_balance():
-    array1, array2 = get_data_from_file()
-    data_list = array2.tolist()
-    all_cnt = len(data_list)
+    data_list, data_label_list = get_data_from_file('txt')
+    all_cnt = len(data_label_list)
     cnt_0 = 0
     cnt_1 = 0
-    for data in data_list:
-        if data[0] == 1 or data[1] == 1:
+    for data in data_label_list:
+        if eval(data[:-1]) == 1:
             cnt_1 += 1
         else:
             cnt_0 += 1
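Note: the original `+` line read `eval(data[:-1])[1] == 1`, but with the new format each line of data_label3.txt holds a single 0/1 int, so indexing `[1]` would raise TypeError; the subscript is dropped above. With single-int labels the balance check reduces to counting lines; an equivalent sketch using collections.Counter (`label_balance` is an assumed helper, not part of the patch):

    from collections import Counter

    def label_balance(path='train_data/data_label3.txt'):
        # one 0/1 label per line in the new format
        with open(path, 'r') as f:
            counts = Counter(int(line.strip()) for line in f if line.strip())
        print("all", sum(counts.values()),
              "cnt_1", counts.get(1, 0), "cnt_0", counts.get(0, 0))
        return counts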
@@ -316,7 +287,7 @@ def my_data_loader(data_list, data_label_list, batch_size, is_train=True):
     data_num = len(data_list)

     # define the Embedding output shape
-    output_shape = (3, 10, 60)
+    output_shape = (6, 10, 60)

     # fetch data batch by batch
     i = 0
@@ -343,8 +314,19 @@ def my_data_loader(data_list, data_label_list, batch_size, is_train=True):
             Y = new_data_label_list
-            # (table_num, 3 sentences, dim characters, embedding) -> (3, table_num, dim, embedding)
+            # (table_num, 6 sentences, dim characters, embedding) -> (6, table_num, dim, embedding)
             X = np.transpose(X, (1, 0, 2, 3))
-            # print("input_1", X[0].shape, "input_2", X[1].shape, "input_3", X[2].shape, "Y", Y.shape)
-            yield {'input_1': X[0], 'input_2': X[1], 'input_3': X[2]}, {'output': Y}
+            # zero out neighbour inputs that merely duplicate the centre cell
+            if (X[0] == X[1]).all():
+                X[0] = np.zeros_like(X[1], dtype='float32')
+            if (X[2] == X[1]).all():
+                X[2] = np.zeros_like(X[1], dtype='float32')
+            if (X[3] == X[1]).all():
+                X[3] = np.zeros_like(X[1], dtype='float32')
+            if (X[5] == X[1]).all():
+                X[5] = np.zeros_like(X[1], dtype='float32')
+
+            yield {'input_1': X[0], 'input_2': X[1], 'input_3': X[2],
+                   'input_4': X[3], 'input_5': X[4], 'input_6': X[5]}, \
+                  {'output': Y}

         else:
             while True:
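Note: `(X[k] == X[1]).all()` collapses over the whole batch, so a neighbour input is zeroed only when every sample in the batch duplicates its centre cell. If per-sample masking of the padded neighbours is the intent (an assumption; the patch may deliberately act batch-wise), a per-row variant would look like:

    import numpy as np

    def mask_duplicate_neighbours(X):
        # X: array shaped (6, batch, dim, embedding); X[1] is the centre cell.
        # Zeroes, per sample, the neighbour slots that equal the centre cell.
        for k in (0, 2, 3, 5):
            dup = (X[k] == X[1]).all(axis=(1, 2))  # per-sample duplicate flag
            X[k][dup] = 0.0
        return X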
@@ -363,7 +345,8 @@ def my_data_loader(data_list, data_label_list, batch_size, is_train=True):
             new_data_list = np.array(new_data_list)
             X = new_data_list
             X = np.transpose(X, (1, 0, 2, 3))
-            yield {'input_1': X[0], 'input_2': X[1], 'input_3': X[2]}
+            yield {'input_1': X[0], 'input_2': X[1], 'input_3': X[2],
+                   'input_4': X[3], 'input_5': X[4], 'input_6': X[5]}


 if __name__ == '__main__':
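Note: consumers of the generator must now feed six named inputs instead of three. A quick smoke test of the new contract (hypothetical snippet, assuming data_list and data_label_list were loaded as above):

    gen = my_data_loader(data_list, data_label_list, batch_size=32, is_train=True)
    inputs, targets = next(gen)
    # expect six (32, 10, 60) input arrays plus one label batch
    print({k: v.shape for k, v in inputs.items()})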