import random

import numpy as np
import psycopg2

from BiddingKG.dl.common.Utils import embedding_word


def get_sentence_index_list(sentence, dict_path='utils/ppocr_keys_v1.txt'):
    """Map each character of *sentence* to its 1-based index in the OCR
    character dictionary; unknown characters (and '') map to 0.

    :param sentence: iterable of single characters (typically a str)
    :param dict_path: path to the dictionary file, one character per line
    :return: list[int] of the same length as *sentence*
    """
    with open(dict_path, 'r') as f:
        # drop the trailing newline of each line, as the original did with [:-1]
        character_list = [line[:-1] for line in f.readlines()]
    # first occurrence wins, matching list.index() semantics, but O(1) lookup
    char_to_index = {}
    for pos, ch in enumerate(character_list):
        char_to_index.setdefault(ch, pos + 1)
    index_list = []
    for character in sentence:
        if character == '':
            index_list.append(0)
        else:
            index_list.append(char_to_index.get(character, 0))
    return index_list


def postgresql_util(sql, limit):
    """Execute *sql* against the label database and return at most ~*limit*
    rows, fetched in batches of 1000.

    The connection and cursor are always closed, even on error (the original
    leaked both).

    :param sql: SQL text to execute
    :param limit: stop fetching once at least this many rows are collected
    :return: list of result tuples
    """
    conn = psycopg2.connect(dbname="table_head_label", user="postgres",
                            password="postgres", host="192.168.2.103")
    try:
        with conn.cursor() as cursor:
            cursor.execute(sql)
            print(sql)
            all_rows = []
            cnt = 0
            rows = cursor.fetchmany(1000)
            # same stopping rule as the original: check before appending,
            # so the last batch may push the total slightly past `limit`... no —
            # the batch is only appended while cnt < limit.
            while rows and cnt < limit:
                all_rows += rows
                cnt += len(rows)
                rows = cursor.fetchmany(1000)
    finally:
        conn.close()
    return all_rows


def get_data_from_sql(dim=10):
    """Load labeled tables from the database and flatten them into training
    samples via :func:`table_pre_process`.

    :param dim: unused here; kept for interface compatibility
    :return: (all_data_list, all_data_label_list)
    """
    sql = """
    select table_text, pre_label, post_label, id
    from label_table_head_info
    where status = 1 and update_time >= '2022-01-17'
    ;
    """
    result_list = postgresql_util(sql, limit=1000000)

    all_data_list = []
    all_data_label_list = []
    i = 0
    # one result row == one table
    for table in result_list:
        i += 1
        if i % 100 == 0:
            print("Loop", i)

        # NOTE(security): eval() on DB-stored text executes arbitrary code if
        # the database is ever tainted; consider ast.literal_eval instead.
        pre_label = eval(table[1])
        post_label = eval(table[2])
        _id = table[3]

        # table_text needs special handling: it may be double-quoted/escaped
        try:
            table_text = table[0]
            if table_text[0] == '"':
                table_text = eval(table_text)
            table_text = table_text.replace('\\', '/')
            table_text = eval(table_text)
        except Exception:
            print("无法识别table_text", _id)
            continue

        # skip tables with fewer than 2 rows
        if len(post_label) >= 2:
            data_list, data_label_list = table_pre_process(table_text, post_label, _id)
        elif len(pre_label) >= 2:
            data_list, data_label_list = table_pre_process(table_text, pre_label, _id)
        else:
            data_list, data_label_list = [], []
        all_data_list += data_list
        all_data_label_list += data_label_list

    print("len(all_data_list)", len(all_data_list))
    return all_data_list, all_data_label_list


def table_pre_process(text_list, label_list, _id, is_train=True):
    """Turn one table into per-cell samples.

    For each cell, build a 6-element context
    [left, cell, right, above, cell, below]; neighbors beyond the table edge
    default to the cell itself. In training mode, duplicate contexts are
    dropped and the cell's own label is kept.

    :param text_list: list of rows, each a list of cell strings
    :param label_list: matching list of rows of labels (training mode only)
    :param _id: table id, used only for diagnostics
    :param is_train: when True also return labels and deduplicate
    :return: (data_list, data_label_list) when training, else data_list
    """
    if is_train and len(text_list) != len(label_list):
        print("文字单元格与标注单元格数量不匹配!", _id)
        print("len(text_list)", len(text_list), "len(label_list)", len(label_list))
        return [], []

    data_list = []
    data_label_list = []
    # incrementally-maintained dedup keys; the original rebuilt
    # [str(x) for x in data_list] inside the inner loop (O(n^2))
    seen = set()

    for i in range(len(text_list)):
        row = text_list[i]
        row_label = label_list[i] if is_train else None

        last_row = text_list[i - 1] if i > 0 else []
        next_row = text_list[i + 1] if i < len(text_list) - 1 else []

        for j in range(len(row)):
            col = row[j]
            # neighbors default to the cell itself at the table edge
            last_col = row[j - 1] if j > 0 else col
            next_col = row[j + 1] if j < len(row) - 1 else col
            # NOTE(review): assumes rows are at least as wide as this one;
            # a ragged shorter neighbor row would raise IndexError (as before)
            last_row_col = last_row[j] if last_row else col
            next_row_col = next_row[j] if next_row else col

            data = [last_col, col, next_col, last_row_col, col, next_row_col]
            if is_train:
                key = str(data)
                if key not in seen:
                    seen.add(key)
                    data_list.append(data)
                    data_label_list.append(int(row_label[j]))
            else:
                data_list.append(data)

    if is_train:
        return data_list, data_label_list
    return data_list


def get_data_from_file(file_type):
    """Load previously saved training data.

    :param file_type: 'np' for .npy arrays, 'txt' for line-per-sample text
    :return: (data, labels) as arrays ('np') or lists of raw lines ('txt')
    :raises ValueError: on any other *file_type*
    """
    if file_type == 'np':
        array1 = np.load('train_data/data_3.npy')
        array2 = np.load('train_data/data_label_3.npy')
        return array1, array2
    elif file_type == 'txt':
        with open('train_data/data3.txt', 'r') as f:
            data_list = f.readlines()
        with open('train_data/data_label3.txt', 'r') as f:
            data_label_list = f.readlines()
        return data_list, data_label_list
    else:
        # the original printed this message and then raised a bare Exception;
        # ValueError (a subclass) keeps existing `except Exception` callers working
        raise ValueError("file type error! only np and txt supported")


def processed_save_to_np():
    """Fetch data from SQL and persist it as .npy arrays."""
    array1, array2 = get_data_from_sql()
    np.save('train_data/data_3.npy', array1)
    np.save('train_data/data_label_3.npy', array2)


def processed_save_to_txt():
    """Fetch data from SQL, shuffle samples and labels in lockstep, and
    persist one sample per line as text."""
    list1, list2 = get_data_from_sql()
    zip_list = list(zip(list1, list2))
    random.shuffle(zip_list)
    list1[:], list2[:] = zip(*zip_list)
    with open('train_data/data3.txt', 'w') as f:
        for line in list1:
            f.write(str(line) + "\n")
    with open('train_data/data_label3.txt', 'w') as f:
        for line in list2:
            f.write(str(line) + "\n")


def data_balance():
    """Print the class balance of the saved txt labels.

    Handles both the current format (one int per line) and the legacy
    triple format (the original indexed [1] unconditionally, which raises
    TypeError on int labels).
    """
    data_list, data_label_list = get_data_from_file('txt')
    all_cnt = len(data_label_list)
    cnt_0 = 0
    cnt_1 = 0
    for data in data_label_list:
        label = eval(data[:-1])
        flag = label[1] if isinstance(label, (list, tuple)) else label
        if flag == 1:
            cnt_1 += 1
        else:
            cnt_0 += 1
    print("all_cnt", all_cnt)
    print("label has 1", cnt_1)
    print("label all 0", cnt_0)


def test_embedding():
    """Smoke-test embedding_word on a tiny None-filled input."""
    output_shape = (2, 1, 60)
    data = [[None], [None]]
    result = embedding_word(data, output_shape)
    print(result)


def my_data_loader(data_list, data_label_list, batch_size, is_train=True):
    """Infinite batch generator for Keras-style fit/predict.

    Each sample is 6 sentences embedded to (6, 10, 60); samples whose
    embedding does not match that shape are skipped.

    :param data_list: raw samples (text lines when training, lists otherwise)
    :param data_label_list: raw label lines (training mode only)
    :param batch_size: samples per yielded batch
    :param is_train: training mode evals the text lines and yields labels too
    :yields: ({'input_1'..'input_6': arrays}, {'output': labels}) when
             training, else just the input dict
    """
    data_num = len(data_list)
    # embedding output per sample: 6 sentences x 10 characters x 60 dims
    output_shape = (6, 10, 60)
    i = 0
    if is_train:
        while True:
            new_data_list = []
            new_data_label_list = []
            for j in range(batch_size):
                if i >= data_num:
                    i = 0
                # map characters to embeddings; lines are str(list) + "\n"
                data = eval(data_list[i][:-1])
                data_label = eval(data_label_list[i][:-1])
                data = embedding_word(data, output_shape)
                if data.shape == output_shape:
                    new_data_list.append(data)
                    new_data_label_list.append(data_label)
                i += 1
            X = np.array(new_data_list)
            Y = np.array(new_data_label_list)
            # (batch, 6 sentences, chars, embedding) -> (6, batch, chars, embedding)
            X = np.transpose(X, (1, 0, 2, 3))
            # zero out neighbor inputs that duplicate the center cell
            if (X[0] == X[1]).all():
                X[0] = np.zeros_like(X[1], dtype='float32')
            if (X[2] == X[1]).all():
                X[2] = np.zeros_like(X[1], dtype='float32')
            if (X[3] == X[1]).all():
                X[3] = np.zeros_like(X[1], dtype='float32')
            if (X[5] == X[1]).all():
                X[5] = np.zeros_like(X[1], dtype='float32')
            yield {'input_1': X[0], 'input_2': X[1], 'input_3': X[2],
                   'input_4': X[3], 'input_5': X[4], 'input_6': X[5]}, \
                  {'output': Y}
    else:
        while True:
            new_data_list = []
            for j in range(batch_size):
                if i >= data_num:
                    i = 0
                data = data_list[i]
                data = embedding_word(data, output_shape)
                if data.shape == output_shape:
                    new_data_list.append(data)
                i += 1
            X = np.array(new_data_list)
            X = np.transpose(X, (1, 0, 2, 3))
            yield {'input_1': X[0], 'input_2': X[1], 'input_3': X[2],
                   'input_4': X[3], 'input_5': X[4], 'input_6': X[5]}


if __name__ == '__main__':
    processed_save_to_txt()
    # data_balance()
    # test_embedding()