import psycopg2
import numpy as np
from BiddingKG.dl.common.Utils import embedding_word


def get_sentence_index_list(sentence, dict_path='utils/ppocr_keys_v1.txt'):
    """Map each character of a sentence to its index in the OCR dict (0 = unknown/padding)."""
    with open(dict_path, 'r') as f:
        character_list = f.readlines()
    for i in range(len(character_list)):
        character_list[i] = character_list[i][:-1]  # strip trailing newline

    index_list = []
    for character in sentence:
        if character == '':
            index_list.append(0)
        elif character in character_list:
            # indices start at 1; 0 is reserved for padding/unknown
            _index = character_list.index(character) + 1
            index_list.append(_index)
        else:
            index_list.append(0)
    return index_list


def postgresql_util(sql, limit):
    conn = psycopg2.connect(dbname="table_head_label", user="postgres",
                            password="postgres", host="192.168.2.103")
    cursor = conn.cursor()
    cursor.execute(sql)
    print(sql)
    rows = cursor.fetchmany(1000)
    cnt = 0
    all_rows = []
    while rows:
        if cnt >= limit:
            break
        all_rows += rows
        cnt += len(rows)
        rows = cursor.fetchmany(1000)
    # close the connection instead of leaking it
    cursor.close()
    conn.close()
    return all_rows


def get_data_from_sql(dim=10):
    sql = """
    select table_text, pre_label, post_label, id
    from label_table_head_info
    where update_user <> 'test27' and table_box_cnt >= 4 and table_box_cnt <= 200
    ;
    """
    # sql = """
    # select table_text, pre_label, post_label, id
    # from label_table_head_info
    # where id = 843
    # """
    result_list = postgresql_util(sql, limit=1000000)

    all_data_list = []
    all_data_label_list = []
    i = 0
    # each result row is one table
    for table in result_list:
        i += 1
        if i % 100 == 0:
            print("Loop", i)
        pre_label = eval(table[1])
        post_label = eval(table[2])
        _id = table[3]
        # table_text needs special handling: it may be wrapped in an extra
        # pair of double quotes, so it takes one or two rounds of eval
        try:
            table_text = table[0]
            if table_text[0] == '"':
                table_text = eval(table_text)  # unwrap the outer quoting first
            table_text = table_text.replace('\\', '/')
            table_text = eval(table_text)
        except Exception:
            print("failed to parse table_text", _id)
            continue

        # skip tables with fewer than two rows
        if len(post_label) >= 2:
            data_list, data_label_list = table_pre_process(table_text, post_label, _id)
        elif len(pre_label) >= 2:
            data_list, data_label_list = table_pre_process(table_text, pre_label, _id)
        else:
            data_list, data_label_list = [], []

        all_data_list += data_list
        all_data_label_list += data_label_list

    print("len(all_data_list)", len(all_data_list))

    # new_data_list = []
    # for data in data_list:
    #     # map Chinese characters to indices
    #     # data[0] = get_sentence_index_list(data[0])
    #     # data[1] = get_sentence_index_list(data[1])
    #     # pad with mask value 0 below dim, truncate above
    #     # if len(data[0]) < dim:
    #     #     data[0] = data[0] + [0]*(dim-len(data[0]))
    #     # elif len(data[0]) > dim:
    #     #     data[0] = data[0][:dim]
    #     # if len(data[1]) < dim:
    #     #     data[1] = data[1] + [0]*(dim-len(data[1]))
    #     # elif len(data[1]) > dim:
    #     #     data[1] = data[1][:dim]
    #
    #     # map Chinese characters to embeddings
    #     data = embedding_word(data, input_shape)
    #     new_data_list.append(data)
    #
    # new_data_list = np.array(new_data_list)
    # data_label_list = np.array(data_label_list)
    # if np.array(new_data_list).shape[1:] == input_shape:
    #     all_data_list.append(new_data_list)
    #     all_data_label_list.append(data_label_list)

    # # concatenate in chunks, otherwise one big concat is too slow
    # split_len = 1000
    # _len = int(len(all_data_list) / split_len)
    # all_data_list_1 = []
    # all_data_list_2 = []
    # for i in range(_len):
    #     if i == _len - 1:
    #         array1 = np.concatenate(all_data_list[i*split_len:])
    #         array2 = np.concatenate(all_data_label_list[i*split_len:])
    #     else:
    #         array1 = np.concatenate(all_data_list[i*split_len:i*split_len+split_len])
    #         array2 = np.concatenate(all_data_label_list[i*split_len:i*split_len+split_len])
    #     all_data_list_1.append(array1)
    #     all_data_list_2.append(array2)
    # all_data_list = np.concatenate(all_data_list_1)
    # all_data_label_list = np.concatenate(all_data_list_2)

    return all_data_list, all_data_label_list
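
# A minimal sketch (not part of the original pipeline) illustrating the
# two-step table_text decoding performed above: values stored with an extra
# outer pair of double quotes need eval twice, once to unwrap the quoting and
# once to parse the cell grid. The literal below is hypothetical sample data.
def _demo_decode_table_text():
    raw = '''"[['No.', 'Name'], ['1', 'Device']]"'''
    if raw[0] == '"':
        raw = eval(raw)  # unwrap outer quotes -> inner literal string
    raw = raw.replace('\\', '/')  # same backslash normalization as above
    table_text = eval(raw)  # parse the grid: list of rows of cell strings
    assert table_text == [['No.', 'Name'], ['1', 'Device']]
    return table_text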
def table_pre_process(text_list, label_list, _id, is_train=True):
    """
    Table preprocessing: each cell generates 2 samples, one horizontal and one vertical.
    :param text_list: 2D list of cell texts, one inner list per table row
    :param label_list: 2D list of 0/1 head labels, same shape as text_list
    :param _id: table id, only used in error messages
    :param is_train: when False, no labels are read or returned
    :return: (data_list, data_label_list) when training, else data_list
    """
    if is_train:
        if len(text_list) != len(label_list):
            print("number of text cells does not match number of label cells!", _id)
            print("len(text_list)", len(text_list), "len(label_list)", len(label_list))
            return [], []

    data_list = []
    data_label_list = []
    for i in range(len(text_list)):
        row = text_list[i]
        if is_train:
            row_label = label_list[i]

        if i > 0:
            last_row = text_list[i-1]
            if is_train:
                last_row_label = label_list[i-1]
        else:
            last_row = []
            if is_train:
                last_row_label = []

        if i < len(text_list) - 1:
            next_row = text_list[i+1]
            if is_train:
                next_row_label = label_list[i+1]
        else:
            next_row = []
            if is_train:
                next_row_label = []

        for j in range(len(row)):
            col = row[j]
            if is_train:
                col_label = row_label[j]

            # neighbors outside the table become None, label 0
            if j > 0:
                last_col = row[j-1]
                if is_train:
                    last_col_label = row_label[j-1]
            else:
                last_col = None
                if is_train:
                    last_col_label = 0

            if j < len(row) - 1:
                next_col = row[j+1]
                if is_train:
                    next_col_label = row_label[j+1]
            else:
                next_col = None
                if is_train:
                    next_col_label = 0

            # bound j so a shorter neighbor row cannot raise IndexError
            if last_row and j < len(last_row):
                last_row_col = last_row[j]
                if is_train:
                    last_row_col_label = last_row_label[j]
            else:
                last_row_col = None
                if is_train:
                    last_row_col_label = 0

            if next_row and j < len(next_row):
                next_row_col = next_row[j]
                if is_train:
                    next_row_col_label = next_row_label[j]
            else:
                next_row_col = None
                if is_train:
                    next_row_col_label = 0

            # originally only triples with at least one differing pair were kept
            # if col != next_col or col != last_col:
            data_list.append([last_col, col, next_col])
            if is_train:
                data_label_list.append([int(last_col_label), int(col_label), int(next_col_label)])
            # if col != next_row_col or col != last_row_col:
            data_list.append([last_row_col, col, next_row_col])
            if is_train:
                data_label_list.append([int(last_row_col_label), int(col_label), int(next_row_col_label)])

    if is_train:
        return data_list, data_label_list
    else:
        return data_list


def get_data_from_file(file_type):
    if file_type == 'np':
        data_path = 'train_data/data_3.npy'
        data_label_path = 'train_data/data_label_3.npy'
        array1 = np.load(data_path)
        array2 = np.load(data_label_path)
        return array1, array2
    elif file_type == 'txt':
        data_path = 'train_data/data.txt'
        data_label_path = 'train_data/data_label.txt'
        with open(data_path, 'r') as f:
            data_list = f.readlines()
        with open(data_label_path, 'r') as f:
            data_label_list = f.readlines()
        # for i in range(len(data_list)):
        #     data_list[i] = eval(data_list[i][:-1])
        #     data_label_list[i] = eval(data_label_list[i][:-1])
        return data_list, data_label_list
    else:
        raise ValueError("file type error! only np and txt supported")
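
# A small illustration (hypothetical 2x2 grid, not from the dataset) of the
# triples table_pre_process above emits: one left/self/right triple and one
# above/self/below triple per cell, with None / label 0 standing in for
# neighbors beyond the table edge.
def _demo_table_pre_process():
    text_list = [['Name', 'Qty'], ['Device', '10']]
    label_list = [[1, 1], [0, 0]]  # 1 marks a table-head cell
    data, labels = table_pre_process(text_list, label_list, _id=-1)
    assert len(data) == 8  # 4 cells * 2 directions
    # top-left cell, horizontal direction: no left neighbor
    assert data[0] == [None, 'Name', 'Qty'] and labels[0] == [0, 1, 1]
    return data, labels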
only np and txt supported") raise Exception def processed_save_to_np(): array1, array2 = get_data_from_sql() np.save('train_data/data_3.npy', array1) np.save('train_data/data_label_3.npy', array2) # with open('train_data/data.txt', 'w') as f: # for line in list1: # f.write(str(line) + "\n") # with open('train_data/data_label.txt', 'w') as f: # for line in list2: # f.write(str(line) + "\n") def processed_save_to_txt(): list1, list2 = get_data_from_sql() with open('train_data/data.txt', 'w') as f: for line in list1: f.write(str(line) + "\n") with open('train_data/data_label.txt', 'w') as f: for line in list2: f.write(str(line) + "\n") def data_balance(): array1, array2 = get_data_from_file() data_list = array2.tolist() all_cnt = len(data_list) cnt_0 = 0 cnt_1 = 0 for data in data_list: if data[0] == 1 or data[1] == 1: cnt_1 += 1 else: cnt_0 += 1 print("all_cnt", all_cnt) print("label has 1", cnt_1) print("label all 0", cnt_0) def test_embedding(): output_shape = (2, 1, 60) data = [[None], [None]] result = embedding_word(data, output_shape) print(result) def my_data_loader(data_list, data_label_list, batch_size, is_train=True): data_num = len(data_list) # 定义Embedding输出 output_shape = (3, 10, 60) # batch循环取数据 i = 0 if is_train: while True: new_data_list = [] new_data_label_list = [] for j in range(batch_size): if i >= data_num: i = 0 # 中文字符映射为Embedding data = eval(data_list[i][:-1]) data_label = eval(data_label_list[i][:-1]) data = embedding_word(data, output_shape) if data.shape == output_shape: new_data_list.append(data) new_data_label_list.append(data_label) i += 1 new_data_list = np.array(new_data_list) new_data_label_list = np.array(new_data_label_list) X = new_data_list Y = new_data_label_list # (table_num, 3 sentences, dim characters, embedding) -> (3, table_num, dim, embedding) X = np.transpose(X, (1, 0, 2, 3)) # print("input_1", X[0].shape, "input_2", X[1].shape, "input_3", X[2].shape, "Y", Y.shape) yield {'input_1': X[0], 'input_2': X[1], 'input_3': X[2]}, {'output': Y} else: while True: new_data_list = [] for j in range(batch_size): if i >= data_num: i = 0 # 中文字符映射为Embedding data = data_list[i] data = embedding_word(data, output_shape) if data.shape == output_shape: new_data_list.append(data) i += 1 new_data_list = np.array(new_data_list) X = new_data_list X = np.transpose(X, (1, 0, 2, 3)) yield {'input_1': X[0], 'input_2': X[1], 'input_3': X[2]} if __name__ == '__main__': processed_save_to_txt() # data_balance() # test_embedding()