"""Build (cell, neighbour-cell) training pairs for table-head labelling.

Tables are read either from the PostgreSQL labelling database or from the
cached text files under ``train_data/``.
"""
import ast
import functools
import os
import sys

sys.path.append(os.path.abspath("../.."))

try:
    import psycopg2
except ImportError:
    # The file-based helpers (table_process, get_data_from_file, ...) do not
    # need the database driver; postgresql_util reports the missing
    # dependency when it is actually called.
    psycopg2 = None
import numpy as np

# Number of rows requested from the cursor per round trip.
_FETCH_BATCH_SIZE = 1000

# Everything that parsing a stored literal can reasonably raise.
_PARSE_ERRORS = (
    ValueError,
    SyntaxError,
    TypeError,
    IndexError,
    AttributeError,
    MemoryError,
    RecursionError,
)


@functools.lru_cache(maxsize=8)
def _load_character_index(dict_path):
    """Return a ``{character: 1-based line number}`` map for *dict_path*.

    The result is cached per path, so the dictionary file is read once per
    process instead of once per sentence.
    """
    index_by_char = {}
    with open(dict_path, 'r', encoding='utf-8') as f:
        for position, line in enumerate(f, start=1):
            # Strip only the line terminator: a final line without a trailing
            # newline must keep its last character.  setdefault keeps the
            # first occurrence of a duplicated character.
            index_by_char.setdefault(line.rstrip('\n'), position)
    # An empty dictionary line must never match; '' always maps to 0.
    index_by_char.pop('', None)
    return index_by_char


def get_sentence_index_list(sentence, dict_path='utils/ppocr_keys_v1.txt'):
    """Map every character of *sentence* to its index in the dictionary.

    Args:
        sentence: Iterable of characters (normally a string).
        dict_path: Path of the character dictionary, one entry per line.

    Returns:
        A list with one int per character: the 1-based line number of the
        character in the dictionary, or 0 when it is empty or unknown.
    """
    index_by_char = _load_character_index(dict_path)
    return [index_by_char.get(character, 0) for character in sentence]


def postgresql_util(sql, limit):
    """Run *sql* against the labelling database and return at most *limit* rows.

    Args:
        sql: The query text to execute.
        limit: Maximum number of rows to return.

    Returns:
        A list of row tuples.

    Raises:
        ImportError: If psycopg2 is not installed.
    """
    if psycopg2 is None:
        raise ImportError("psycopg2 is required to query the database")

    # NOTE(review): credentials and host are hard-coded; consider moving them
    # to configuration.
    conn = psycopg2.connect(dbname="table_head_label", user="postgres",
                            password="postgres", host="192.168.2.103")
    try:
        cursor = conn.cursor()
        try:
            cursor.execute(sql)
            print(sql)
            all_rows = []
            while len(all_rows) < limit:
                # Never request more than is still allowed, so the result
                # cannot overshoot the limit.
                batch_size = min(_FETCH_BATCH_SIZE, limit - len(all_rows))
                rows = cursor.fetchmany(batch_size)
                if not rows:
                    break
                all_rows += rows
            return all_rows
        finally:
            cursor.close()
    finally:
        # Always release the connection, even when the query fails.
        conn.close()


def _parse_literal(text):
    """Parse a stored Python literal (list, str, number, ...) from *text*.

    ast.literal_eval is used instead of eval so that database or file content
    can never execute code.
    """
    if isinstance(text, str):
        # Older literal_eval versions reject leading whitespace.
        text = text.strip()
    return ast.literal_eval(text)


def _parse_table_text(raw):
    """Decode the ``table_text`` column into a list of rows of cell strings."""
    if raw[0] == '"':
        # The value was stored as a quoted string: unwrap one level first.
        raw = _parse_literal(raw)
    # Backslashes would otherwise be read as escape sequences.
    return _parse_literal(raw.replace('\\', '/'))


def _encode_cell(text, dim):
    """Convert a cell's text to dictionary indices of length *dim*."""
    indices = get_sentence_index_list(text)
    # Right-pad with the mask value 0 when too short, truncate when too long.
    return (indices + [0] * (dim - len(indices)))[:dim]


def get_data_from_sql(dim=10):
    """Load labelled tables from the database and build training pairs.

    Args:
        dim: Length every encoded cell is padded or truncated to.

    Returns:
        ``(all_data_list, all_data_label_list)``: the encoded cell pairs and
        the matching ``[label, neighbour_label]`` pairs.
    """
    # To debug a single table, filter with ``where id = <id>`` instead.
    sql = """
    select table_text, pre_label, post_label, id
    from label_table_head_info
    where update_user <> 'test27' and table_box_cnt >= 4 and table_box_cnt <= 200
    limit 1000;
    """
    result_list = postgresql_util(sql, limit=10000)

    all_data_list = []
    all_data_label_list = []
    # Each database row is one table.
    for i, table in enumerate(result_list, start=1):
        if i % 100 == 0:
            print("Loop", i)

        pre_label = _parse_literal(table[1])
        post_label = _parse_literal(table[2])
        _id = table[3]

        # table_text needs special handling and may be unparsable.
        try:
            table_text = _parse_table_text(table[0])
        except _PARSE_ERRORS:
            print("无法识别table_text", _id)
            continue

        # Prefer the reviewed labels; tables with a single row are skipped.
        if len(post_label) >= 2:
            labels = post_label
        elif len(pre_label) >= 2:
            labels = pre_label
        else:
            continue

        data_list, data_label_list = table_process(table_text, labels, _id)
        for data in data_list:
            # Map characters to dictionary indices at a fixed length.
            data[0] = _encode_cell(data[0], dim)
            data[1] = _encode_cell(data[1], dim)

        all_data_list += data_list
        all_data_label_list += data_label_list
    return all_data_list, all_data_label_list


def table_process(text_list, label_list, _id):
    """Pair every cell with its right and lower neighbour.

    A pair is emitted only when the neighbour exists, is non-empty and its
    text differs from the cell's own text.  For each cell the right
    neighbour is emitted before the lower one.

    Args:
        text_list: Table as a list of rows of cell strings.
        label_list: Labels with the same layout as *text_list*.
        _id: Table identifier, used only in diagnostics.

    Returns:
        ``(data_list, data_label_list)``: ``[cell, neighbour]`` text pairs and
        the matching ``[int(label), int(neighbour_label)]`` pairs.  Both are
        empty when the labels do not cover the table.
    """
    if len(text_list) != len(label_list):
        print("文字单元格与标注单元格数量不匹配!", _id)
        print("len(text_list)", len(text_list), "len(label_list)", len(label_list))
        return [], []

    # Every text cell needs a label, otherwise the pairing below would index
    # past the end of a label row.
    for row, row_label in zip(text_list, label_list):
        if len(row_label) < len(row):
            print("文字单元格与标注单元格数量不匹配!", _id)
            print("len(text_list)", len(row), "len(label_list)", len(row_label))
            return [], []

    data_list = []
    data_label_list = []

    def add_pair(cell, cell_label, other, other_label):
        # Empty neighbours and identical texts carry no training signal.
        if other and cell != other:
            data_list.append([cell, other])
            data_label_list.append([int(cell_label), int(other_label)])

    row_count = len(text_list)
    for i, (row, row_label) in enumerate(zip(text_list, label_list)):
        has_next_row = i + 1 < row_count
        next_row = text_list[i + 1] if has_next_row else []
        next_row_label = label_list[i + 1] if has_next_row else []

        for j, (col, col_label) in enumerate(zip(row, row_label)):
            if j + 1 < len(row):
                add_pair(col, col_label, row[j + 1], row_label[j + 1])
            # In a ragged table the next row may be shorter; then there is no
            # cell below this one.
            if j < len(next_row):
                add_pair(col, col_label, next_row[j], next_row_label[j])
    return data_list, data_label_list


def get_data_from_file(data_path='train_data/data.txt',
                       data_label_path='train_data/data_label.txt'):
    """Load the training pairs previously written by processed_save_to_txt.

    Args:
        data_path: File with one encoded pair literal per line.
        data_label_path: File with one label pair literal per line.

    Returns:
        ``(data_list, data_label_list)`` parsed from the two files.

    Raises:
        ValueError: If the two files do not hold the same number of lines.
    """
    with open(data_path, 'r', encoding='utf-8') as f:
        data_list = [_parse_literal(line) for line in f]
    with open(data_label_path, 'r', encoding='utf-8') as f:
        data_label_list = [_parse_literal(line) for line in f]

    if len(data_list) != len(data_label_list):
        raise ValueError(
            "data and label files differ in length: %d vs %d"
            % (len(data_list), len(data_label_list))
        )

    print(len(data_list))
    return data_list, data_label_list


def processed_save_to_txt(data_path='train_data/data.txt',
                          data_label_path='train_data/data_label.txt'):
    """Fetch the training pairs from the database and cache them as text.

    Args:
        data_path: Output file for the encoded pairs, one per line.
        data_label_path: Output file for the label pairs, one per line.
    """
    data_list, data_label_list = get_data_from_sql()
    with open(data_path, 'w', encoding='utf-8') as f:
        f.writelines(str(line) + "\n" for line in data_list)
    with open(data_label_path, 'w', encoding='utf-8') as f:
        f.writelines(str(line) + "\n" for line in data_label_list)


if __name__ == '__main__':
    get_data_from_file()