|
@@ -1,6 +1,5 @@
|
|
import psycopg2
|
|
import psycopg2
|
|
import numpy as np
|
|
import numpy as np
|
|
-
|
|
|
|
from BiddingKG.dl.common.Utils import embedding_word
|
|
from BiddingKG.dl.common.Utils import embedding_word
|
|
|
|
|
|
|
|
|
|
@@ -83,9 +82,9 @@ def get_data_from_sql(dim=10):
|
|
|
|
|
|
# 只有一行的也不要
|
|
# 只有一行的也不要
|
|
if len(post_label) >= 2:
|
|
if len(post_label) >= 2:
|
|
- data_list, data_label_list = table_process(table_text, post_label, _id)
|
|
|
|
|
|
+ data_list, data_label_list = table_pre_process(table_text, post_label, _id)
|
|
elif len(pre_label) >= 2:
|
|
elif len(pre_label) >= 2:
|
|
- data_list, data_label_list = table_process(table_text, pre_label, _id)
|
|
|
|
|
|
+ data_list, data_label_list = table_pre_process(table_text, pre_label, _id)
|
|
else:
|
|
else:
|
|
data_list, data_label_list = [], []
|
|
data_list, data_label_list = [], []
|
|
|
|
|
|
@@ -140,77 +139,105 @@ def get_data_from_sql(dim=10):
|
|
return all_data_list, all_data_label_list
|
|
return all_data_list, all_data_label_list
|
|
|
|
|
|
|
|
|
|
-def table_process(text_list, label_list, _id):
|
|
|
|
- if len(text_list) != len(label_list):
|
|
|
|
- print("文字单元格与标注单元格数量不匹配!", _id)
|
|
|
|
- print("len(text_list)", len(text_list), "len(label_list)", len(label_list))
|
|
|
|
- return [], []
|
|
|
|
|
|
+def table_pre_process(text_list, label_list, _id, is_train=True):
|
|
|
|
+ """
|
|
|
|
+ 表格处理,每个单元格生成2条数据,横竖各1条
|
|
|
|
+
|
|
|
|
+ :param text_list:
|
|
|
|
+ :param label_list:
|
|
|
|
+ :param _id:
|
|
|
|
+ :param is_train:
|
|
|
|
+ :return:
|
|
|
|
+ """
|
|
|
|
+ if is_train:
|
|
|
|
+ if len(text_list) != len(label_list):
|
|
|
|
+ print("文字单元格与标注单元格数量不匹配!", _id)
|
|
|
|
+ print("len(text_list)", len(text_list), "len(label_list)", len(label_list))
|
|
|
|
+ return [], []
|
|
|
|
|
|
data_list = []
|
|
data_list = []
|
|
data_label_list = []
|
|
data_label_list = []
|
|
for i in range(len(text_list)):
|
|
for i in range(len(text_list)):
|
|
row = text_list[i]
|
|
row = text_list[i]
|
|
- row_label = label_list[i]
|
|
|
|
|
|
+ if is_train:
|
|
|
|
+ row_label = label_list[i]
|
|
|
|
|
|
if i > 0:
|
|
if i > 0:
|
|
last_row = text_list[i-1]
|
|
last_row = text_list[i-1]
|
|
- last_row_label = label_list[i-1]
|
|
|
|
|
|
+ if is_train:
|
|
|
|
+ last_row_label = label_list[i-1]
|
|
else:
|
|
else:
|
|
last_row = []
|
|
last_row = []
|
|
- last_row_label = []
|
|
|
|
|
|
+ if is_train:
|
|
|
|
+ last_row_label = []
|
|
|
|
|
|
if i < len(text_list) - 1:
|
|
if i < len(text_list) - 1:
|
|
next_row = text_list[i+1]
|
|
next_row = text_list[i+1]
|
|
- next_row_label = label_list[i+1]
|
|
|
|
|
|
+ if is_train:
|
|
|
|
+ next_row_label = label_list[i+1]
|
|
else:
|
|
else:
|
|
next_row = []
|
|
next_row = []
|
|
- next_row_label = []
|
|
|
|
|
|
+ if is_train:
|
|
|
|
+ next_row_label = []
|
|
|
|
|
|
for j in range(len(row)):
|
|
for j in range(len(row)):
|
|
col = row[j]
|
|
col = row[j]
|
|
- col_label = row_label[j]
|
|
|
|
|
|
+ if is_train:
|
|
|
|
+ col_label = row_label[j]
|
|
|
|
|
|
# 超出表格置为None, 0
|
|
# 超出表格置为None, 0
|
|
if j > 0:
|
|
if j > 0:
|
|
last_col = row[j-1]
|
|
last_col = row[j-1]
|
|
- last_col_label = row_label[j-1]
|
|
|
|
|
|
+ if is_train:
|
|
|
|
+ last_col_label = row_label[j-1]
|
|
else:
|
|
else:
|
|
last_col = None
|
|
last_col = None
|
|
- last_col_label = 0
|
|
|
|
|
|
+ if is_train:
|
|
|
|
+ last_col_label = 0
|
|
|
|
|
|
if j < len(row) - 1:
|
|
if j < len(row) - 1:
|
|
next_col = row[j+1]
|
|
next_col = row[j+1]
|
|
- next_col_label = row_label[j+1]
|
|
|
|
|
|
+ if is_train:
|
|
|
|
+ next_col_label = row_label[j+1]
|
|
else:
|
|
else:
|
|
next_col = None
|
|
next_col = None
|
|
- next_col_label = 0
|
|
|
|
|
|
+ if is_train:
|
|
|
|
+ next_col_label = 0
|
|
|
|
|
|
if last_row:
|
|
if last_row:
|
|
last_row_col = last_row[j]
|
|
last_row_col = last_row[j]
|
|
- last_row_col_label = last_row_label[j]
|
|
|
|
|
|
+ if is_train:
|
|
|
|
+ last_row_col_label = last_row_label[j]
|
|
else:
|
|
else:
|
|
last_row_col = None
|
|
last_row_col = None
|
|
- last_row_col_label = 0
|
|
|
|
|
|
+ if is_train:
|
|
|
|
+ last_row_col_label = 0
|
|
|
|
|
|
if next_row:
|
|
if next_row:
|
|
next_row_col = next_row[j]
|
|
next_row_col = next_row[j]
|
|
- next_row_col_label = next_row_label[j]
|
|
|
|
|
|
+ if is_train:
|
|
|
|
+ next_row_col_label = next_row_label[j]
|
|
else:
|
|
else:
|
|
next_row_col = None
|
|
next_row_col = None
|
|
- next_row_col_label = 0
|
|
|
|
|
|
+ if is_train:
|
|
|
|
+ next_row_col_label = 0
|
|
|
|
|
|
# 三元组有一对不相等就作为数据
|
|
# 三元组有一对不相等就作为数据
|
|
# if col != next_col or col != last_col:
|
|
# if col != next_col or col != last_col:
|
|
data_list.append([last_col, col, next_col])
|
|
data_list.append([last_col, col, next_col])
|
|
- data_label_list.append([int(last_col_label), int(col_label),
|
|
|
|
- int(next_col_label)])
|
|
|
|
|
|
+ if is_train:
|
|
|
|
+ data_label_list.append([int(last_col_label), int(col_label),
|
|
|
|
+ int(next_col_label)])
|
|
|
|
|
|
# if col != next_row_col or col != last_row_col:
|
|
# if col != next_row_col or col != last_row_col:
|
|
data_list.append([last_row_col, col, next_row_col])
|
|
data_list.append([last_row_col, col, next_row_col])
|
|
- data_label_list.append([int(last_row_col_label), int(col_label),
|
|
|
|
- int(next_row_col_label)])
|
|
|
|
-
|
|
|
|
- return data_list, data_label_list
|
|
|
|
|
|
+ if is_train:
|
|
|
|
+ data_label_list.append([int(last_row_col_label), int(col_label),
|
|
|
|
+ int(next_row_col_label)])
|
|
|
|
+ if is_train:
|
|
|
|
+ return data_list, data_label_list
|
|
|
|
+ else:
|
|
|
|
+ return data_list
|
|
|
|
|
|
|
|
|
|
def get_data_from_file(file_type):
|
|
def get_data_from_file(file_type):
|
|
@@ -285,7 +312,7 @@ def test_embedding():
|
|
print(result)
|
|
print(result)
|
|
|
|
|
|
|
|
|
|
-def my_data_loader(data_list, data_label_list, batch_size):
|
|
|
|
|
|
+def my_data_loader(data_list, data_label_list, batch_size, is_train=True):
|
|
data_num = len(data_list)
|
|
data_num = len(data_list)
|
|
|
|
|
|
# 定义Embedding输出
|
|
# 定义Embedding输出
|
|
@@ -293,30 +320,50 @@ def my_data_loader(data_list, data_label_list, batch_size):
|
|
|
|
|
|
# batch循环取数据
|
|
# batch循环取数据
|
|
i = 0
|
|
i = 0
|
|
- while True:
|
|
|
|
- new_data_list = []
|
|
|
|
- new_data_label_list = []
|
|
|
|
- for j in range(batch_size):
|
|
|
|
- if i >= data_num:
|
|
|
|
- i = 0
|
|
|
|
-
|
|
|
|
- # 中文字符映射为Embedding
|
|
|
|
- data = eval(data_list[i][:-1])
|
|
|
|
- data_label = eval(data_label_list[i][:-1])
|
|
|
|
- data = embedding_word(data, output_shape)
|
|
|
|
- if data.shape == output_shape:
|
|
|
|
- new_data_list.append(data)
|
|
|
|
- new_data_label_list.append(data_label)
|
|
|
|
- i += 1
|
|
|
|
-
|
|
|
|
- new_data_list = np.array(new_data_list)
|
|
|
|
- new_data_label_list = np.array(new_data_label_list)
|
|
|
|
- X = new_data_list
|
|
|
|
- Y = new_data_label_list
|
|
|
|
- # (table_num, 3 sentences, dim characters, embedding) -> (3, table_num, dim, embedding)
|
|
|
|
- X = np.transpose(X, (1, 0, 2, 3))
|
|
|
|
- # print("input_1", X[0].shape, "input_2", X[1].shape, "input_3", X[2].shape, "Y", Y.shape)
|
|
|
|
- yield {'input_1': X[0], 'input_2': X[1], 'input_3': X[2]}, {'output': Y}
|
|
|
|
|
|
+ if is_train:
|
|
|
|
+ while True:
|
|
|
|
+ new_data_list = []
|
|
|
|
+ new_data_label_list = []
|
|
|
|
+ for j in range(batch_size):
|
|
|
|
+ if i >= data_num:
|
|
|
|
+ i = 0
|
|
|
|
+
|
|
|
|
+ # 中文字符映射为Embedding
|
|
|
|
+ data = eval(data_list[i][:-1])
|
|
|
|
+ data_label = eval(data_label_list[i][:-1])
|
|
|
|
+ data = embedding_word(data, output_shape)
|
|
|
|
+ if data.shape == output_shape:
|
|
|
|
+ new_data_list.append(data)
|
|
|
|
+ new_data_label_list.append(data_label)
|
|
|
|
+ i += 1
|
|
|
|
+
|
|
|
|
+ new_data_list = np.array(new_data_list)
|
|
|
|
+ new_data_label_list = np.array(new_data_label_list)
|
|
|
|
+ X = new_data_list
|
|
|
|
+ Y = new_data_label_list
|
|
|
|
+ # (table_num, 3 sentences, dim characters, embedding) -> (3, table_num, dim, embedding)
|
|
|
|
+ X = np.transpose(X, (1, 0, 2, 3))
|
|
|
|
+ # print("input_1", X[0].shape, "input_2", X[1].shape, "input_3", X[2].shape, "Y", Y.shape)
|
|
|
|
+ yield {'input_1': X[0], 'input_2': X[1], 'input_3': X[2]}, {'output': Y}
|
|
|
|
+
|
|
|
|
+ else:
|
|
|
|
+ while True:
|
|
|
|
+ new_data_list = []
|
|
|
|
+ for j in range(batch_size):
|
|
|
|
+ if i >= data_num:
|
|
|
|
+ i = 0
|
|
|
|
+
|
|
|
|
+ # 中文字符映射为Embedding
|
|
|
|
+ data = data_list[i]
|
|
|
|
+ data = embedding_word(data, output_shape)
|
|
|
|
+ if data.shape == output_shape:
|
|
|
|
+ new_data_list.append(data)
|
|
|
|
+ i += 1
|
|
|
|
+
|
|
|
|
+ new_data_list = np.array(new_data_list)
|
|
|
|
+ X = new_data_list
|
|
|
|
+ X = np.transpose(X, (1, 0, 2, 3))
|
|
|
|
+ yield {'input_1': X[0], 'input_2': X[1], 'input_3': X[2]}
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
if __name__ == '__main__':
|