# Utilities for loading, cleaning, saving and batching "table head" labelling
# data: rows come from an internal PostgreSQL labelling DB, are turned into
# per-cell or whole-table training samples, and fed to Keras-style models
# through generator loaders.
#
# NOTE(review): this module eval()s text coming from the database and from
# local label files. That is only acceptable because both are trusted,
# internal sources — never point these helpers at untrusted input
# (ast.literal_eval would be the safer choice).

import os
import random
import sys

import numpy as np

sys.path.append(os.path.dirname(__file__) + "/../")
from common.Utils import embedding_word, embedding_word_forward


def get_sentence_index_list(sentence, dict_path='utils/ppocr_keys_v1.txt'):
    """Map every character of *sentence* to its 1-based index in the dict file.

    Characters missing from the dictionary (and empty strings) map to 0.

    :param sentence: iterable of characters (a str or list of 1-char strings)
    :param dict_path: one character per line, PP-OCR key file format
    :return: list of int indices, same length as *sentence*
    """
    with open(dict_path, 'r') as f:
        character_list = f.readlines()
    # Strip the trailing newline of every entry.
    # NOTE(review): assumes the file ends with a newline; otherwise the last
    # character of the final entry is chopped off — confirm the file format.
    for i in range(len(character_list)):
        character_list[i] = character_list[i][:-1]
    # Build the lookup table once instead of calling list.index() per
    # character (that was O(n) per lookup). First occurrence wins, matching
    # list.index() semantics for duplicate entries.
    char_to_index = {}
    for pos, character in enumerate(character_list):
        if character not in char_to_index:
            char_to_index[character] = pos + 1
    return [0 if character == '' else char_to_index.get(character, 0)
            for character in sentence]


def postgresql_util(sql, limit):
    """Run *sql* against the labelling DB and return at most ~limit rows.

    Rows are fetched in batches of 1000, so the result may overshoot *limit*
    by up to one batch (behaviour kept from the original implementation).

    :param sql: query string (trusted, internally built)
    :param limit: soft cap on the number of rows returned
    :return: list of row tuples
    """
    import psycopg2  # local import keeps the hard dependency optional
    conn = psycopg2.connect(dbname="table_head_label", user="postgres",
                            password="postgres", host="192.168.2.103")
    try:
        cursor = conn.cursor()
        cursor.execute(sql)
        print(sql)
        all_rows = []
        cnt = 0
        rows = cursor.fetchmany(1000)
        while rows:
            if cnt >= limit:
                break
            all_rows += rows
            cnt += len(rows)
            rows = cursor.fetchmany(1000)
        return all_rows
    finally:
        # fix: the connection used to leak — it was never closed
        conn.close()


def get_data_from_sql(dim=10, whole_table=False, padding=True):
    """Load labelled tables from the DB and turn them into training samples.

    :param dim: unused, kept for interface compatibility
    :param whole_table: one sample per table (True) or per cell (False)
    :param padding: pad whole tables to a multiple-of-8 grid (whole_table only)
    :return: (all_data_list, all_data_label_list)
    """
    sql = """
    select table_text, pre_label, post_label, id
    from label_table_head_info
    where status = 0 and (update_user='test9' or update_user='test1' or update_user='test7' or update_user='test26')
    ;
    """
    result_list = postgresql_util(sql, limit=1000000)

    # IDs to exclude (hand-maintained local blacklist files).
    # NOTE(review): eval() on trusted local files only — see module note.
    with open(r"C:\Users\Administrator\Desktop\table_not_eval.txt", "r") as f:
        delete_id_list = eval(f.read())
    with open(r"C:\Users\Administrator\Desktop\table_delete.txt", "r") as f:
        delete_id_list += eval(f.read())
    # fix: set gives O(1) membership in the loop below (was O(n) per row)
    delete_id_set = set(delete_id_list)

    all_data_list = []
    all_data_label_list = []
    # one DB row == one table
    for i, table in enumerate(result_list, start=1):
        if i % 100 == 0:
            print("Loop", i)
        pre_label = eval(table[1])
        post_label = eval(table[2])
        _id = table[3]
        if _id in delete_id_set:
            print("pass", _id)
            continue

        # table_text needs special handling: it may be doubly escaped and/or
        # wrapped in an extra repr layer starting with a double quote.
        try:
            table_text = table[0]
            if table_text[0] == '"':
                table_text = eval(table_text)
            table_text = table_text.replace('\\', '/')
            table_text = eval(table_text)
        except Exception:  # fix: bare except also swallowed KeyboardInterrupt
            print("无法识别table_text", _id)
            continue

        # Prefer post_label over pre_label; tables with fewer than two
        # labelled rows are skipped entirely.
        if whole_table:
            if len(post_label) >= 2:
                data_list, data_label_list = table_pre_process_2(
                    table_text, post_label, _id, padding=padding)
            elif len(pre_label) >= 2:
                data_list, data_label_list = table_pre_process_2(
                    table_text, pre_label, _id, padding=padding)
            else:
                data_list, data_label_list = [], []
        else:
            if len(post_label) >= 2:
                data_list, data_label_list = table_pre_process(
                    table_text, post_label, _id)
            elif len(pre_label) >= 2:
                data_list, data_label_list = table_pre_process(
                    table_text, pre_label, _id)
            else:
                data_list, data_label_list = [], []
        all_data_list += data_list
        all_data_label_list += data_label_list

    # Sort by table dimensions so similarly shaped tables batch together.
    if whole_table and all_data_list:  # fix: zip(*[]) used to crash on empty
        _list = sorted(zip(all_data_list, all_data_label_list),
                       key=lambda x: (len(x[0]), len(x[0][0])))
        all_data_list[:], all_data_label_list[:] = zip(*_list)
    print("len(all_data_list)", len(all_data_list))
    return all_data_list, all_data_label_list


def table_pre_process(text_list, label_list, _id, is_train=True):
    """Per-cell samples: for each cell emit [left, cell, right, up, cell, down].

    Neighbours that fall outside the table are replaced by the cell itself.
    In training mode each sample is paired with the cell's int label and
    duplicate samples are dropped.

    :param text_list: table as a list of rows of cell strings
    :param label_list: same shape as text_list, labels as strings (train only)
    :param _id: table id, used for logging only
    :param is_train: also build labels and deduplicate when True
    :return: (data_list, data_label_list) when training, else data_list
    """
    if is_train and len(text_list) != len(label_list):
        print("文字单元格与标注单元格数量不匹配!", _id)
        print("len(text_list)", len(text_list),
              "len(label_list)", len(label_list))
        return [], []

    data_list = []
    data_label_list = []
    # fix: dedup used to rebuild [str(x) for x in data_list] for every cell,
    # which made the whole pass O(n^2); an incremental seen-set is O(n).
    seen = set()
    for i in range(len(text_list)):
        row = text_list[i]
        row_label = label_list[i] if is_train else None
        last_row = text_list[i - 1] if i > 0 else []
        next_row = text_list[i + 1] if i < len(text_list) - 1 else []
        if is_train:
            last_row_label = label_list[i - 1] if i > 0 else []
            next_row_label = label_list[i + 1] if i < len(text_list) - 1 else []

        for j in range(len(row)):
            col = row[j]
            col_label = row_label[j] if is_train else None
            # out-of-table neighbours fall back to the cell itself
            if j > 0:
                last_col = row[j - 1]
                last_col_label = row_label[j - 1] if is_train else None
            else:
                last_col = col
                last_col_label = col_label
            if j < len(row) - 1:
                next_col = row[j + 1]
                next_col_label = row_label[j + 1] if is_train else None
            else:
                next_col = col
                next_col_label = col_label
            if last_row:
                # NOTE(review): assumes rectangular tables — a shorter
                # previous row would raise IndexError here (as before)
                last_row_col = last_row[j]
                last_row_col_label = last_row_label[j] if is_train else None
            else:
                last_row_col = col
                last_row_col_label = col_label
            if next_row:
                next_row_col = next_row[j]
                next_row_col_label = next_row_label[j] if is_train else None
            else:
                next_row_col = col
                next_row_col_label = col_label

            data = [last_col, col, next_col, last_row_col, col, next_row_col]
            if is_train:
                key = str(data)
                if key not in seen:
                    seen.add(key)
                    data_list.append(data)
                    data_label_list.append(int(col_label))
            else:
                data_list.append(data)

    if is_train:
        return data_list, data_label_list
    return data_list


def table_pre_process_2(text_list, label_list, _id, is_train=True, padding=True):
    """Whole-table sample: optionally pad the table to a multiple-of-8 grid.

    :param text_list: table as a list of rows of cell strings (mutated!)
    :param label_list: same shape, labels as strings (mutated; train only)
    :param _id: table id, used for logging only
    :param is_train: also process labels when True
    :param padding: pad rows/columns with None / "0" up to the best size
    :return: ([text_list], [label_list]) when training, else [text_list];
             empty results when the table size is unusable
    """
    # sanity-check the table dimensions
    row_len = len(text_list)
    best_row_len = get_best_padding_size(row_len, min_len=8)
    col_len = len(text_list[0])
    best_col_len = get_best_padding_size(col_len, min_len=8)
    if best_row_len is None or best_col_len is None:
        # NOTE(review): get_best_padding_size never returns None today;
        # kept as a defensive guard from the original code.
        return ([], []) if is_train else []
    if is_train and len(text_list) != len(label_list):
        print("文字单元格与标注单元格数量不匹配!", _id)
        print("len(text_list)", len(text_list),
              "len(label_list)", len(label_list))
        return [], []

    if padding:
        # pad every row out to best_col_len ...
        for i in range(row_len):
            col_len = len(text_list[i])
            text_list[i] += [None] * (best_col_len - col_len)
            if is_train:
                label_list[i] += ["0"] * (best_col_len - col_len)
        # ... then append blank rows up to best_row_len.
        # fix: `[[None]*w] * n` appended n references to ONE shared row
        # object; a comprehension builds distinct rows.
        text_list += [[None] * best_col_len
                      for _ in range(best_row_len - row_len)]
        if is_train:
            label_list += [["0"] * best_col_len
                           for _ in range(best_row_len - row_len)]

    if is_train:
        for i in range(len(label_list)):
            for j in range(len(label_list[i])):
                label_list[i][j] = int(label_list[i][j])
        return [text_list], [label_list]
    return [text_list]


def get_best_padding_size(axis_len, min_len=3, max_len=300):
    """Smallest multiple of *min_len* (≤ *max_len*) that fits *axis_len*.

    Falls back to *axis_len* itself when it exceeds the largest multiple.

    :param axis_len: length to be padded
    :param min_len: padding step; callers pass 0 to mean "no padding"
    :param max_len: largest padded size considered
    :return: the padded length (int)
    """
    if min_len > 0:
        sizes = list(range(min_len, max_len + 1, min_len))
    else:
        # min_len <= 0 degenerates to "no padding": any positive axis_len
        # is simply returned unchanged (matches the original behaviour).
        sizes = [0]
    if axis_len > sizes[-1]:
        return axis_len
    for size in sizes:
        if axis_len <= size:
            return size
    return sizes[-1]


def get_data_from_file(file_type, model_id=1):
    """Load previously saved training data from .npy or .txt files.

    :param file_type: 'np' or 'txt'
    :param model_id: which txt pair to read (1, 2 or 3; txt mode only)
    :return: (data, labels) as arrays ('np') or lists of lines ('txt')
    :raises ValueError: unknown model_id in txt mode
    :raises Exception: unknown file_type (kept from original interface)
    """
    if file_type == 'np':
        array1 = np.load('train_data/data_3.npy')
        array2 = np.load('train_data/data_label_3.npy')
        return array1, array2
    if file_type == 'txt':
        paths = {
            1: ('train_data/data1.txt', 'train_data/data_label1.txt'),
            2: ('train_data/data2.txt', 'train_data/data_label2.txt'),
            3: ('train_data/data3.txt', 'train_data/data_label3.txt'),
        }
        try:
            data_path, data_label_path = paths[model_id]
        except KeyError:
            # fix: an unknown model_id used to crash later with NameError
            raise ValueError("model_id must be 1, 2 or 3, got %r" % model_id)
        with open(data_path, 'r') as f:
            data_list = f.readlines()
        with open(data_label_path, 'r') as f:
            data_label_list = f.readlines()
        return data_list, data_label_list
    print("file type error! only np and txt supported")
    raise Exception


def processed_save_to_np():
    """Fetch per-cell samples from the DB and save them as .npy arrays."""
    array1, array2 = get_data_from_sql()
    np.save('train_data/data_3.npy', array1)
    np.save('train_data/data_label_3.npy', array2)


def processed_save_to_txt(whole_table=False, padding=True):
    """Fetch samples from the DB, shuffle them, save one repr() per line."""
    list1, list2 = get_data_from_sql(whole_table=whole_table, padding=padding)
    # shuffle data and labels in unison
    zip_list = list(zip(list1, list2))
    if zip_list:  # fix: zip(*[]) used to crash when nothing was fetched
        random.shuffle(zip_list)
        list1[:], list2[:] = zip(*zip_list)
    with open('train_data/data1.txt', 'w') as f:
        for line in list1:
            f.write(str(line) + "\n")
    with open('train_data/data_label1.txt', 'w') as f:
        for line in list2:
            f.write(str(line) + "\n")


def data_balance():
    """Print the 0/1 label distribution of the saved per-cell training data."""
    data_list, data_label_list = get_data_from_file('txt')
    all_cnt = len(data_label_list)
    cnt_0 = 0
    cnt_1 = 0
    for data in data_label_list:
        # each line is the repr of a label triple; [1] is the centre cell
        if eval(data[:-1])[1] == 1:
            cnt_1 += 1
        else:
            cnt_0 += 1
    print("all_cnt", all_cnt)
    print("label has 1", cnt_1)
    print("label all 0", cnt_0)


def test_embedding():
    """Smoke-test embedding_word on a tiny all-None input."""
    output_shape = (2, 1, 60)
    data = [[None], [None]]
    result = embedding_word(data, output_shape)
    print(result)


def my_data_loader(data_list, data_label_list, batch_size, is_train=True):
    """Generator feeding the per-cell model.

    Training mode loops forever over the data, yielding
    ({'input_1'..'input_6'}, {'output': Y}); inference mode yields input
    dicts only, in file order. A neighbour input identical to the centre
    cell means "no neighbour" and is zeroed out.

    :param data_list: lines of repr'd 6-sentence samples (train) or raw
        samples (inference)
    :param data_label_list: lines of repr'd labels (train only)
    :param batch_size: samples per yielded batch
    :param is_train: choose the endless training loop vs one inference pass
    """
    data_num = len(data_list)
    # embedding output shape: (sentences, chars, embedding dim)
    output_shape = (6, 20, 60)
    i = 0
    if is_train:
        while True:
            new_data_list = []
            new_data_label_list = []
            for j in range(batch_size):
                if i >= data_num:
                    i = 0
                # map characters to embeddings
                data = eval(data_list[i][:-1])
                data_label = eval(data_label_list[i][:-1])
                data = embedding_word(data, output_shape)
                if data.shape == output_shape:
                    new_data_list.append(data)
                    new_data_label_list.append(data_label)
                i += 1
            X = np.array(new_data_list)
            Y = np.array(new_data_label_list)
            # (batch, 6 sentences, chars, embed) -> (6, batch, chars, embed)
            X = np.transpose(X, (1, 0, 2, 3))
            # neighbour == centre cell means "no neighbour": zero it out
            if (X[0] == X[1]).all():
                X[0] = np.zeros_like(X[1], dtype='float32')
            if (X[2] == X[1]).all():
                X[2] = np.zeros_like(X[1], dtype='float32')
            if (X[3] == X[1]).all():
                X[3] = np.zeros_like(X[1], dtype='float32')
            if (X[5] == X[1]).all():
                X[5] = np.zeros_like(X[1], dtype='float32')
            yield {'input_1': X[0], 'input_2': X[1], 'input_3': X[2],
                   'input_4': X[3], 'input_5': X[4], 'input_6': X[5]}, \
                  {'output': Y}
    else:
        new_data_list = []
        for j in range(len(data_list)):
            # map characters to embeddings
            data = embedding_word(data_list[i], output_shape)
            if data.shape == output_shape:
                new_data_list.append(data)
            i += 1
        for j in range(0, len(data_list), batch_size):
            X = np.array(new_data_list[j: j + batch_size])
            X = np.transpose(X, (1, 0, 2, 3))
            yield {'input_1': X[0], 'input_2': X[1], 'input_3': X[2],
                   'input_4': X[3], 'input_5': X[4], 'input_6': X[5]}


def my_data_loader_predict(data_list, data_label_list, batch_size):
    """Embed the whole data_list at once and return it model-input shaped.

    :param data_list: raw 6-sentence samples
    :param data_label_list: unused, kept for interface compatibility
    :param batch_size: unused, kept for interface compatibility
    :return: ndarray of shape (6, n, 20, 60)
    """
    output_shape = (6, 20, 60)
    new_data_list = []
    for data in data_list:
        # map characters to embeddings; drop malformed samples
        data = embedding_word(data, output_shape)
        if data.shape == output_shape:
            new_data_list.append(data)
    X = np.array(new_data_list)
    return np.transpose(X, (1, 0, 2, 3))


def my_data_loader_2(table_list, table_label_list, batch_size, is_train=True):
    """Generator feeding the whole-table model.

    Tables are embedded cell-by-cell; 'input_2' is a zero tensor sized by
    the padded table dimensions. With batch_size > 1 a batch is cut short
    as soon as a table of a different shape is met (tables are pre-sorted
    by shape), and ragged/mis-shaped batches are silently skipped.

    NOTE(review): the i / last_shape bookkeeping is order-sensitive and is
    preserved exactly from the original implementation.
    """
    pad_len = 0
    table_num = len(table_list)
    if is_train and batch_size == 1:
        table_list, table_label_list = get_random(table_list, table_label_list)
    # per-cell embedding output shape: (chars, embedding dim)
    output_shape = (20, 60)
    i = 0
    last_shape = None
    while True:
        new_table_list = []
        new_table_label_list = []
        for j in range(batch_size):
            if i >= table_num:
                i = 0
                # reshuffle between epochs
                if is_train:
                    table_list, table_label_list = get_random(
                        table_list, table_label_list,
                        seed=random.randint(1, 40))
            # entries may be raw lists or repr'd lines
            if type(table_list[i]) != list:
                table = eval(table_list[i][:-1])
            else:
                table = table_list[i]
            if batch_size > 1:
                if last_shape is None:
                    last_shape = (len(table), len(table[0]))
                    continue
                if (len(table), len(table[0])) != last_shape:
                    # shape changed: close this batch early
                    last_shape = (len(table), len(table[0]))
                    break
            if is_train:
                table_label = eval(table_label_list[i][:-1])
            # map characters of every row to embeddings
            for k in range(len(table)):
                table[k] = embedding_word_forward(
                    table[k],
                    (len(table[k]), output_shape[0], output_shape[1]))
            new_table_list.append(table)
            if is_train:
                new_table_label_list.append(table_label)
            i += 1
        X = np.array(new_table_list)
        if X.shape[-2:] != output_shape:
            # ragged or mis-shaped batch: skip it
            continue
        # zero auxiliary input sized by the padded table dimensions
        pad_height = get_best_padding_size(X.shape[1], pad_len)
        pad_width = get_best_padding_size(X.shape[2], pad_len)
        input_2 = np.zeros([1, X.shape[1], X.shape[2], pad_height, pad_width])
        if is_train:
            Y = np.array(new_table_label_list)
            yield {"input_1": X, "input_2": input_2}, \
                  {"output": Y}
        else:
            yield {"input_1": X, "input_2": input_2}


def check_train_data():
    """Report saved whole-table data whose dimensions are not multiples of 8."""
    data_list, label_list = get_data_from_file('txt', model_id=2)
    for data in data_list:
        data = eval(data)
        if len(data) % 8 != 0:
            print(len(data))
            print(len(data[0]))
        for row in data:
            if len(row) % 8 != 0:
                print(len(data))
                print(len(row))


def get_random(text_list, label_list, seed=42):
    """Shuffle *text_list* and *label_list* in unison (in place) with *seed*.

    :return: the same two (now shuffled) lists
    """
    random.seed(seed)
    zip_list = list(zip(text_list, label_list))
    random.shuffle(zip_list)
    text_list[:], label_list[:] = zip(*zip_list)
    return text_list, label_list


if __name__ == '__main__':
    processed_save_to_txt(whole_table=False, padding=False)
    # data_balance()
    # test_embedding()
    # check_train_data()