123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323 |
- import psycopg2
- import numpy as np
- from BiddingKG.dl.common.Utils import embedding_word
def get_sentence_index_list(sentence, dict_path='utils/ppocr_keys_v1.txt'):
    """Map each character of ``sentence`` to its 1-based line number in a dictionary file.

    The dictionary file holds one entry per line. Index 0 is used for
    characters that are empty or not present in the dictionary.

    Args:
        sentence: Iterable of characters (typically a ``str``).
        dict_path: Path of the character dictionary, read as UTF-8.

    Returns:
        A list of ints, one per element of ``sentence``.
    """
    with open(dict_path, 'r', encoding='utf-8') as f:
        # Strip only the line terminator, so a final line without a trailing
        # newline keeps its last character.
        character_list = [line.rstrip('\n') for line in f]

    # Build the lookup once instead of scanning the list for every character.
    # setdefault keeps the first occurrence, which is what list.index returned.
    char_to_index = {}
    for position, character in enumerate(character_list, start=1):
        char_to_index.setdefault(character, position)

    index_list = []
    for character in sentence:
        if character == '':
            # Empty elements always map to the padding index.
            index_list.append(0)
        else:
            index_list.append(char_to_index.get(character, 0))
    return index_list
def postgresql_util(sql, limit):
    """Run ``sql`` against the ``table_head_label`` database and return the fetched rows.

    Rows are fetched in batches of 1000. ``limit`` is checked between
    batches, so the result may exceed ``limit`` by up to one batch.

    Args:
        sql: SQL statement to execute; it is also printed to stdout.
        limit: Stop fetching once at least this many rows have been collected.

    Returns:
        A list of the row tuples returned by the cursor.
    """
    # NOTE(review): connection parameters are hard-coded here; consider moving
    # them to configuration.
    conn = psycopg2.connect(dbname="table_head_label", user="postgres", password="postgres",
                            host="192.168.2.103")
    try:
        cursor = conn.cursor()
        try:
            cursor.execute(sql)
            print(sql)

            all_rows = []
            rows = cursor.fetchmany(1000)
            while rows and len(all_rows) < limit:
                all_rows += rows
                rows = cursor.fetchmany(1000)
            return all_rows
        finally:
            cursor.close()
    finally:
        # Always release the connection, even when execute/fetch raises.
        conn.close()
def get_data_from_sql(dim=10):
    """Load labelled tables from the database and turn them into training triples.

    Each fetched row is one table. ``post_label`` is used when it has at least
    two rows, otherwise ``pre_label`` when it has at least two rows; tables
    with neither are skipped.

    Args:
        dim: Currently unused; kept for backward compatibility.

    Returns:
        ``(all_data_list, all_data_label_list)``: the concatenated outputs of
        ``table_process`` over all usable tables.
    """
    sql = """
    select table_text, pre_label, post_label, id
    from label_table_head_info
    where update_user <> 'test27' and table_box_cnt >= 4 and table_box_cnt <= 200
    ;
    """
    result_list = postgresql_util(sql, limit=1000000)

    all_data_list = []
    all_data_label_list = []
    # One row is one table.
    for i, table in enumerate(result_list, start=1):
        if i % 100 == 0:
            print("Loop", i)

        # SECURITY NOTE: eval() executes whatever is stored in these columns.
        # It is only acceptable while the database content is fully trusted;
        # consider ast.literal_eval if the stored values are plain literals.
        pre_label = eval(table[1])
        post_label = eval(table[2])
        _id = table[3]

        # table_text needs special handling: it may be stored as a quoted
        # string that wraps the actual list literal.
        try:
            table_text = table[0]
            if table_text[0] == '"':
                table_text = eval(table_text)
            # Backslashes would be read as escape sequences by the next eval.
            table_text = table_text.replace('\\', '/')
            table_text = eval(table_text)
        except Exception:
            # Skip tables whose text cannot be parsed, but do not swallow
            # KeyboardInterrupt / SystemExit as a bare ``except`` would.
            print("无法识别table_text", _id)
            continue

        # Tables with only a single labelled row are not used either.
        if len(post_label) >= 2:
            data_list, data_label_list = table_process(table_text, post_label, _id)
        elif len(pre_label) >= 2:
            data_list, data_label_list = table_process(table_text, pre_label, _id)
        else:
            data_list, data_label_list = [], []

        all_data_list += data_list
        all_data_label_list += data_label_list

    print("len(all_data_list)", len(all_data_list))
    return all_data_list, all_data_label_list
def _cell_or_padding(row, row_label, j):
    """Return ``(text, label)`` for column ``j`` of ``row``, or ``(None, 0)`` if ``j`` is outside it."""
    if 0 <= j < len(row):
        return row[j], row_label[j]
    return None, 0


def table_process(text_list, label_list, _id):
    """Build (previous, current, next) triples for every cell of a table.

    Two samples are produced per cell: one along the row (left neighbour,
    cell, right neighbour) and one along the column (cell above, cell, cell
    below). Neighbours that fall outside the table, including positions
    missing from a shorter adjacent row, are represented as text ``None``
    with label ``0``.

    Args:
        text_list: Table as a list of rows, each a list of cell texts.
        label_list: Labels with the same layout as ``text_list``; each label
            must be convertible with ``int()``.
        _id: Identifier of the table, used only in diagnostic output.

    Returns:
        ``(data_list, data_label_list)``: parallel lists of 3-element text
        triples and 3-element int label triples. Both are empty when the
        number of text rows and label rows differ.

    Raises:
        IndexError: If a label row is shorter than its text row.
    """
    if len(text_list) != len(label_list):
        print("文字单元格与标注单元格数量不匹配!", _id)
        print("len(text_list)", len(text_list), "len(label_list)", len(label_list))
        return [], []

    data_list = []
    data_label_list = []
    row_count = len(text_list)
    for i, (row, row_label) in enumerate(zip(text_list, label_list)):
        # Rows above the first / below the last are modelled as empty rows.
        if i > 0:
            last_row, last_row_label = text_list[i - 1], label_list[i - 1]
        else:
            last_row, last_row_label = [], []
        if i < row_count - 1:
            next_row, next_row_label = text_list[i + 1], label_list[i + 1]
        else:
            next_row, next_row_label = [], []

        for j, col in enumerate(row):
            col_label = row_label[j]

            # Anything beyond the table is set to (None, 0).
            last_col, last_col_label = _cell_or_padding(row, row_label, j - 1)
            next_col, next_col_label = _cell_or_padding(row, row_label, j + 1)
            last_row_col, last_row_col_label = _cell_or_padding(last_row, last_row_label, j)
            next_row_col, next_row_col_label = _cell_or_padding(next_row, next_row_label, j)

            # Horizontal triple.
            data_list.append([last_col, col, next_col])
            data_label_list.append([int(last_col_label), int(col_label),
                                    int(next_col_label)])

            # Vertical triple.
            data_list.append([last_row_col, col, next_row_col])
            data_label_list.append([int(last_row_col_label), int(col_label),
                                    int(next_row_col_label)])
    return data_list, data_label_list
def get_data_from_file(file_type):
    """Load previously exported training data and labels from ``train_data/``.

    Args:
        file_type: ``'np'`` to load ``data_3.npy`` / ``data_label_3.npy`` with
            ``np.load``, or ``'txt'`` to read ``data.txt`` / ``data_label.txt``.

    Returns:
        ``(data, labels)``. For ``'np'`` these are the loaded arrays. For
        ``'txt'`` they are lists of raw lines as returned by ``readlines()``;
        the lines are not parsed here.

    Raises:
        ValueError: If ``file_type`` is neither ``'np'`` nor ``'txt'``.
    """
    if file_type == 'np':
        data_path = 'train_data/data_3.npy'
        data_label_path = 'train_data/data_label_3.npy'
        array1 = np.load(data_path)
        array2 = np.load(data_label_path)
        return array1, array2

    if file_type == 'txt':
        data_path = 'train_data/data.txt'
        data_label_path = 'train_data/data_label.txt'
        with open(data_path, 'r') as f:
            data_list = f.readlines()
        with open(data_label_path, 'r') as f:
            data_label_list = f.readlines()
        return data_list, data_label_list

    # ValueError is still caught by callers that handle ``Exception``, and it
    # carries the explanation that used to be printed separately.
    raise ValueError(
        "file type error! only np and txt supported, got %r" % (file_type,))
def processed_save_to_np():
    """Fetch the training data from the database and store it as two ``.npy`` files.

    The samples go to ``train_data/data_3.npy`` and the labels to
    ``train_data/data_label_3.npy``.
    """
    samples, labels = get_data_from_sql()
    targets = (
        ('train_data/data_3.npy', samples),
        ('train_data/data_label_3.npy', labels),
    )
    for path, payload in targets:
        np.save(path, payload)
def processed_save_to_txt():
    """Fetch the training data from the database and store it as two text files.

    Each sample is written as ``str(sample)`` on its own line of
    ``train_data/data.txt``; each label likewise to
    ``train_data/data_label.txt``.
    """
    samples, labels = get_data_from_sql()
    outputs = (
        ('train_data/data.txt', samples),
        ('train_data/data_label.txt', labels),
    )
    for path, entries in outputs:
        with open(path, 'w') as out:
            for entry in entries:
                out.write(str(entry) + "\n")
def data_balance():
    """Print how many stored label entries contain a positive label.

    Loads the ``.npy`` label array through ``get_data_from_file`` and prints
    the total count, the number of entries whose first or second element is
    1, and the number of remaining entries.
    """
    # get_data_from_file requires the file type; the labels are used as an
    # array (``.tolist()``), so the 'np' variant is the one needed here.
    _, label_array = get_data_from_file('np')

    label_entries = label_array.tolist()
    all_cnt = len(label_entries)
    # NOTE(review): only elements 0 and 1 are inspected; if entries are
    # 3-element triples the last element is ignored - confirm this is intended.
    cnt_1 = sum(1 for entry in label_entries if entry[0] == 1 or entry[1] == 1)
    cnt_0 = all_cnt - cnt_1

    print("all_cnt", all_cnt)
    print("label has 1", cnt_1)
    print("label all 0", cnt_0)
def test_embedding():
    """Print the result of ``embedding_word`` for two cells that each hold ``None``."""
    placeholder_cells = [[None], [None]]
    print(embedding_word(placeholder_cells, (2, 1, 60)))
def my_data_loader(data_list, data_label_list, batch_size):
    """Endlessly yield ``([X0, X1, X2], Y)`` batches for training.

    Samples are consumed in order and the position wraps around to the start
    when the end of ``data_list`` is reached. A sample whose embedding does
    not have the expected shape is dropped together with its label, so a
    batch may contain fewer than ``batch_size`` items.

    Args:
        data_list: Sequence of text lines, each the ``str()`` of one sample
            followed by a newline (as read back from ``data.txt``).
        data_label_list: Sequence parallel to ``data_list``. String entries
            are parsed as Python literals; other entries are used as-is.
        batch_size: Number of samples visited per batch.

    Yields:
        ``([X[0], X[1], X[2]], Y)`` where ``X`` is the batch of embeddings
        transposed so the three sentences come first, and ``Y`` holds the
        labels of exactly the samples kept in ``X``.
    """
    import ast  # local import: only needed to parse label lines

    data_num = len(data_list)

    # Expected shape of one embedded sample.
    output_shape = (3, 10, 60)

    # Walk through the data batch by batch.
    i = 0
    while True:
        batch_data = []
        batch_labels = []
        for _ in range(batch_size):
            if i >= data_num:
                i = 0

            # Map the characters of the sample to embeddings.
            # SECURITY NOTE: eval() executes the line's content; this is only
            # acceptable for files produced by this project itself.
            data = eval(data_list[i][:-1])
            data = embedding_word(data, output_shape)
            if data.shape == output_shape:
                label = data_label_list[i]
                if isinstance(label, str):
                    label = ast.literal_eval(label.strip())
                # Keep data and label together so X and Y stay aligned.
                batch_data.append(data)
                batch_labels.append(label)
            i += 1

        X = np.array(batch_data)
        Y = np.array(batch_labels)

        # (table_num, 3 sentences, dim characters, embedding) -> (3, table_num, dim, embedding)
        X = np.transpose(X, (1, 0, 2, 3))
        yield [X[0], X[1], X[2]], Y
# Script entry point: export the training data from the database to text files.
# data_balance() and test_embedding() can be called here instead for debugging.
if __name__ == '__main__':
    processed_save_to_txt()
|