|
@@ -1,9 +1,8 @@
|
|
-import sys
|
|
|
|
-import os
|
|
|
|
-sys.path.append(os.path.abspath("../.."))
|
|
|
|
import psycopg2
|
|
import psycopg2
|
|
import numpy as np
|
|
import numpy as np
|
|
|
|
|
|
|
|
+from BiddingKG.dl.common.Utils import embedding_word
|
|
|
|
+
|
|
|
|
|
|
def get_sentence_index_list(sentence, dict_path='utils/ppocr_keys_v1.txt'):
|
|
def get_sentence_index_list(sentence, dict_path='utils/ppocr_keys_v1.txt'):
|
|
with open(dict_path, 'r') as f:
|
|
with open(dict_path, 'r') as f:
|
|
@@ -47,14 +46,14 @@ def get_data_from_sql(dim=10):
|
|
select table_text, pre_label, post_label, id
|
|
select table_text, pre_label, post_label, id
|
|
from label_table_head_info
|
|
from label_table_head_info
|
|
where update_user <> 'test27' and table_box_cnt >= 4 and table_box_cnt <= 200
|
|
where update_user <> 'test27' and table_box_cnt >= 4 and table_box_cnt <= 200
|
|
- limit 1000;
|
|
|
|
|
|
+ ;
|
|
"""
|
|
"""
|
|
# sql = """
|
|
# sql = """
|
|
# select table_text, pre_label, post_label, id
|
|
# select table_text, pre_label, post_label, id
|
|
# from label_table_head_info
|
|
# from label_table_head_info
|
|
# where id = 843
|
|
# where id = 843
|
|
# """
|
|
# """
|
|
- result_list = postgresql_util(sql, limit=10000)
|
|
|
|
|
|
+ result_list = postgresql_util(sql, limit=1000000)
|
|
|
|
|
|
all_data_list = []
|
|
all_data_list = []
|
|
all_data_label_list = []
|
|
all_data_label_list = []
|
|
@@ -90,22 +89,54 @@ def get_data_from_sql(dim=10):
|
|
else:
|
|
else:
|
|
data_list, data_label_list = [], []
|
|
data_list, data_label_list = [], []
|
|
|
|
|
|
- for data in data_list:
|
|
|
|
- # 中文字符映射为index
|
|
|
|
- data[0] = get_sentence_index_list(data[0])
|
|
|
|
- data[1] = get_sentence_index_list(data[1])
|
|
|
|
-
|
|
|
|
- # 维度不够,填充掩码0
|
|
|
|
- if len(data[0]) < dim:
|
|
|
|
- data[0] = data[0] + [0]*(dim-len(data[0]))
|
|
|
|
- elif len(data[0]) > dim:
|
|
|
|
- data[0] = data[0][:dim]
|
|
|
|
- if len(data[1]) < dim:
|
|
|
|
- data[1] = data[1] + [0]*(dim-len(data[1]))
|
|
|
|
- elif len(data[1]) > dim:
|
|
|
|
- data[1] = data[1][:dim]
|
|
|
|
all_data_list += data_list
|
|
all_data_list += data_list
|
|
all_data_label_list += data_label_list
|
|
all_data_label_list += data_label_list
|
|
|
|
+
|
|
|
|
+ print("len(all_data_list)", len(all_data_list))
|
|
|
|
+
|
|
|
|
+ #
|
|
|
|
+ # new_data_list = []
|
|
|
|
+ # for data in data_list:
|
|
|
|
+ # # 中文字符映射为index
|
|
|
|
+ # # data[0] = get_sentence_index_list(data[0])
|
|
|
|
+ # # data[1] = get_sentence_index_list(data[1])
|
|
|
|
+ # # 维度不够,填充掩码0
|
|
|
|
+ # # if len(data[0]) < dim:
|
|
|
|
+ # # data[0] = data[0] + [0]*(dim-len(data[0]))
|
|
|
|
+ # # elif len(data[0]) > dim:
|
|
|
|
+ # # data[0] = data[0][:dim]
|
|
|
|
+ # # if len(data[1]) < dim:
|
|
|
|
+ # # data[1] = data[1] + [0]*(dim-len(data[1]))
|
|
|
|
+ # # elif len(data[1]) > dim:
|
|
|
|
+ # # data[1] = data[1][:dim]
|
|
|
|
+ #
|
|
|
|
+ # # 中文字符映射为Embedding
|
|
|
|
+ # data = embedding_word(data, input_shape)
|
|
|
|
+ # new_data_list.append(data)
|
|
|
|
+ #
|
|
|
|
+ # new_data_list = np.array(new_data_list)
|
|
|
|
+ # data_label_list = np.array(data_label_list)
|
|
|
|
+ # if np.array(new_data_list).shape[1:] == input_shape:
|
|
|
|
+ # all_data_list.append(new_data_list)
|
|
|
|
+ # all_data_label_list.append(data_label_list)
|
|
|
|
+
|
|
|
|
+ # # 防止concat太慢
|
|
|
|
+ # split_len = 1000
|
|
|
|
+ # _len = int(len(all_data_list) / split_len)
|
|
|
|
+ # all_data_list_1 = []
|
|
|
|
+ # all_data_list_2 = []
|
|
|
|
+ # for i in range(_len):
|
|
|
|
+ # if i == _len - 1:
|
|
|
|
+ # array1 = np.concatenate(all_data_list[i*split_len:])
|
|
|
|
+ # array2 = np.concatenate(all_data_label_list[i*split_len:])
|
|
|
|
+ # else:
|
|
|
|
+ # array1 = np.concatenate(all_data_list[i*split_len:i*split_len+split_len])
|
|
|
|
+ # array2 = np.concatenate(all_data_label_list[i*split_len:i*split_len+split_len])
|
|
|
|
+ # all_data_list_1.append(array1)
|
|
|
|
+ # all_data_list_2.append(array2)
|
|
|
|
+ # all_data_list = np.concatenate(all_data_list_1)
|
|
|
|
+ # all_data_label_list = np.concatenate(all_data_list_2)
|
|
|
|
+
|
|
return all_data_list, all_data_label_list
|
|
return all_data_list, all_data_label_list
|
|
|
|
|
|
|
|
|
|
@@ -121,6 +152,13 @@ def table_process(text_list, label_list, _id):
|
|
row = text_list[i]
|
|
row = text_list[i]
|
|
row_label = label_list[i]
|
|
row_label = label_list[i]
|
|
|
|
|
|
|
|
+ if i > 0:
|
|
|
|
+ last_row = text_list[i-1]
|
|
|
|
+ last_row_label = label_list[i-1]
|
|
|
|
+ else:
|
|
|
|
+ last_row = []
|
|
|
|
+ last_row_label = []
|
|
|
|
+
|
|
if i < len(text_list) - 1:
|
|
if i < len(text_list) - 1:
|
|
next_row = text_list[i+1]
|
|
next_row = text_list[i+1]
|
|
next_row_label = label_list[i+1]
|
|
next_row_label = label_list[i+1]
|
|
@@ -132,47 +170,86 @@ def table_process(text_list, label_list, _id):
|
|
col = row[j]
|
|
col = row[j]
|
|
col_label = row_label[j]
|
|
col_label = row_label[j]
|
|
|
|
|
|
|
|
+ # 超出表格置为None, 0
|
|
|
|
+ if j > 0:
|
|
|
|
+ last_col = row[j-1]
|
|
|
|
+ last_col_label = row_label[j-1]
|
|
|
|
+ else:
|
|
|
|
+ last_col = None
|
|
|
|
+ last_col_label = 0
|
|
|
|
+
|
|
if j < len(row) - 1:
|
|
if j < len(row) - 1:
|
|
next_col = row[j+1]
|
|
next_col = row[j+1]
|
|
next_col_label = row_label[j+1]
|
|
next_col_label = row_label[j+1]
|
|
else:
|
|
else:
|
|
- next_col = ""
|
|
|
|
- next_col_label = ""
|
|
|
|
|
|
+ next_col = None
|
|
|
|
+ next_col_label = 0
|
|
|
|
+
|
|
|
|
+ if last_row:
|
|
|
|
+ last_row_col = last_row[j]
|
|
|
|
+ last_row_col_label = last_row_label[j]
|
|
|
|
+ else:
|
|
|
|
+ last_row_col = None
|
|
|
|
+ last_row_col_label = 0
|
|
|
|
|
|
if next_row:
|
|
if next_row:
|
|
next_row_col = next_row[j]
|
|
next_row_col = next_row[j]
|
|
next_row_col_label = next_row_label[j]
|
|
next_row_col_label = next_row_label[j]
|
|
else:
|
|
else:
|
|
- next_row_col = ""
|
|
|
|
- next_row_col_label = ""
|
|
|
|
-
|
|
|
|
- if next_col:
|
|
|
|
- if col != next_col:
|
|
|
|
- data_list.append([col, next_col])
|
|
|
|
- data_label_list.append([int(col_label), int(next_col_label)])
|
|
|
|
- if next_row_col:
|
|
|
|
- if col != next_row_col:
|
|
|
|
- data_list.append([col, next_row_col])
|
|
|
|
- data_label_list.append([int(col_label), int(next_row_col_label)])
|
|
|
|
|
|
+ next_row_col = None
|
|
|
|
+ next_row_col_label = 0
|
|
|
|
+
|
|
|
|
+ # 三元组有一对不相等就作为数据
|
|
|
|
+ # if col != next_col or col != last_col:
|
|
|
|
+ data_list.append([last_col, col, next_col])
|
|
|
|
+ data_label_list.append([int(last_col_label), int(col_label),
|
|
|
|
+ int(next_col_label)])
|
|
|
|
+
|
|
|
|
+ # if col != next_row_col or col != last_row_col:
|
|
|
|
+ data_list.append([last_row_col, col, next_row_col])
|
|
|
|
+ data_label_list.append([int(last_row_col_label), int(col_label),
|
|
|
|
+ int(next_row_col_label)])
|
|
|
|
|
|
return data_list, data_label_list
|
|
return data_list, data_label_list
|
|
|
|
|
|
|
|
|
|
-def get_data_from_file():
|
|
|
|
- data_path = 'train_data/data.txt'
|
|
|
|
- data_label_path = 'train_data/data_label.txt'
|
|
|
|
|
|
+def get_data_from_file(file_type):
|
|
|
|
+ if file_type == 'np':
|
|
|
|
+ data_path = 'train_data/data_3.npy'
|
|
|
|
+ data_label_path = 'train_data/data_label_3.npy'
|
|
|
|
|
|
- with open(data_path, 'r') as f:
|
|
|
|
- data_list = f.readlines()
|
|
|
|
- with open(data_label_path, 'r') as f:
|
|
|
|
- data_label_list = f.readlines()
|
|
|
|
|
|
+ array1 = np.load(data_path)
|
|
|
|
+ array2 = np.load(data_label_path)
|
|
|
|
+ return array1, array2
|
|
|
|
+ elif file_type == 'txt':
|
|
|
|
+ data_path = 'train_data/data.txt'
|
|
|
|
+ data_label_path = 'train_data/data_label.txt'
|
|
|
|
|
|
- for i in range(len(data_list)):
|
|
|
|
- data_list[i] = eval(data_list[i][:-1])
|
|
|
|
- data_label_list[i] = eval(data_label_list[i][:-1])
|
|
|
|
|
|
+ with open(data_path, 'r') as f:
|
|
|
|
+ data_list = f.readlines()
|
|
|
|
+ with open(data_label_path, 'r') as f:
|
|
|
|
+ data_label_list = f.readlines()
|
|
|
|
|
|
- print(len(data_list))
|
|
|
|
- return data_list, data_label_list
|
|
|
|
|
|
+ # for i in range(len(data_list)):
|
|
|
|
+ # data_list[i] = eval(data_list[i][:-1])
|
|
|
|
+ # data_label_list[i] = eval(data_label_list[i][:-1])
|
|
|
|
+ return data_list, data_label_list
|
|
|
|
+ else:
|
|
|
|
+ print("file type error! only np and txt supported")
|
|
|
|
+ raise Exception
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def processed_save_to_np():
|
|
|
|
+ array1, array2 = get_data_from_sql()
|
|
|
|
+ np.save('train_data/data_3.npy', array1)
|
|
|
|
+ np.save('train_data/data_label_3.npy', array2)
|
|
|
|
+
|
|
|
|
+ # with open('train_data/data.txt', 'w') as f:
|
|
|
|
+ # for line in list1:
|
|
|
|
+ # f.write(str(line) + "\n")
|
|
|
|
+ # with open('train_data/data_label.txt', 'w') as f:
|
|
|
|
+ # for line in list2:
|
|
|
|
+ # f.write(str(line) + "\n")
|
|
|
|
|
|
|
|
|
|
def processed_save_to_txt():
|
|
def processed_save_to_txt():
|
|
@@ -185,5 +262,62 @@ def processed_save_to_txt():
|
|
f.write(str(line) + "\n")
|
|
f.write(str(line) + "\n")
|
|
|
|
|
|
|
|
|
|
|
|
+def data_balance():
|
|
|
|
+ array1, array2 = get_data_from_file()
|
|
|
|
+ data_list = array2.tolist()
|
|
|
|
+ all_cnt = len(data_list)
|
|
|
|
+ cnt_0 = 0
|
|
|
|
+ cnt_1 = 0
|
|
|
|
+ for data in data_list:
|
|
|
|
+ if data[0] == 1 or data[1] == 1:
|
|
|
|
+ cnt_1 += 1
|
|
|
|
+ else:
|
|
|
|
+ cnt_0 += 1
|
|
|
|
+ print("all_cnt", all_cnt)
|
|
|
|
+ print("label has 1", cnt_1)
|
|
|
|
+ print("label all 0", cnt_0)
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def test_embedding():
|
|
|
|
+ output_shape = (2, 1, 60)
|
|
|
|
+ data = [[None], [None]]
|
|
|
|
+ result = embedding_word(data, output_shape)
|
|
|
|
+ print(result)
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def my_data_loader(data_list, data_label_list, batch_size):
|
|
|
|
+ data_num = len(data_list)
|
|
|
|
+
|
|
|
|
+ # 定义Embedding输出
|
|
|
|
+ output_shape = (3, 10, 60)
|
|
|
|
+
|
|
|
|
+ # batch循环取数据
|
|
|
|
+ i = 0
|
|
|
|
+ while True:
|
|
|
|
+ new_data_list = []
|
|
|
|
+ for j in range(batch_size):
|
|
|
|
+ if i >= data_num:
|
|
|
|
+ i = 0
|
|
|
|
+
|
|
|
|
+ # 中文字符映射为Embedding
|
|
|
|
+ data = eval(data_list[i][:-1])
|
|
|
|
+ data = embedding_word(data, output_shape)
|
|
|
|
+ if data.shape == output_shape:
|
|
|
|
+ new_data_list.append(data)
|
|
|
|
+ i += 1
|
|
|
|
+
|
|
|
|
+ new_data_list = np.array(new_data_list)
|
|
|
|
+ data_label_list = np.array(data_label_list)
|
|
|
|
+ X = new_data_list
|
|
|
|
+ Y = data_label_list
|
|
|
|
+ # (table_num, 3 sentences, dim characters, embedding) -> (3, table_num, dim, embedding)
|
|
|
|
+ X = np.transpose(X, (1, 0, 2, 3))
|
|
|
|
+
|
|
|
|
+ yield [X[0], X[1], X[2]], Y
|
|
|
|
+
|
|
|
|
+
|
|
if __name__ == '__main__':
|
|
if __name__ == '__main__':
|
|
- get_data_from_file()
|
|
|
|
|
|
+ processed_save_to_txt()
|
|
|
|
+ # data_balance()
|
|
|
|
+
|
|
|
|
+ # test_embedding()
|