#coding:utf-8
"""Table-head prediction: load the trained model once and label table cells.

Importing this module builds the network and loads its weights from
``best.hdf5`` located next to this file.
"""
import codecs
import copy
import os
import re
import sys
from bs4 import BeautifulSoup

# Resolve the module directory as an absolute path so that both the local
# imports below and the weights file are found even when ``__file__`` has no
# directory component (in which case ``os.path.dirname(__file__)`` is "").
_BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_BASE_DIR)

from models.model import get_model
from post_process import table_post_process, table_post_process_2
from pre_process import my_data_loader, table_pre_process, table_pre_process_2, my_data_loader_2
# from BiddingKG.dl.interface.Preprocessing import tableToText, segment

# init model
model_id = 1
if model_id == 1:
    input_shape = (6, 20, 60)
    output_shape = (1,)
else:
    input_shape = (None, None, 20, 60)
    output_shape = (None, None)
model = get_model(input_shape, output_shape, model_id=model_id)

# load weights
model_path = os.path.join(_BASE_DIR, "best.hdf5")
model.load_weights(model_path)


def predict(table_text_list, model_id=1):
    """Run the module-level model on a table and return its post-processed labels.

    The input is deep-copied before pre-processing, so the caller's object is
    never modified by this function.

    NOTE(review): ``model`` is built once at import time from the module-level
    ``model_id``. The ``model_id`` argument here only selects which
    pre-processing, data loader and post-processing functions are used; it
    does not rebuild the network. Confirm callers never pass a value that
    differs from the module-level one.

    Args:
        table_text_list: Table content handed to the pre-processing helpers
            (presumably rows of cell texts -- confirm against ``pre_process``).
        model_id: ``1`` selects ``table_pre_process`` / ``my_data_loader`` /
            ``table_post_process``; any other value selects the ``*_2``
            variants.

    Returns:
        Whatever the selected post-processing function returns for the copied
        table and the raw prediction.
    """
    table_text_list_copy = copy.deepcopy(table_text_list)

    if model_id == 1:
        data_list = table_pre_process(table_text_list_copy, [], 0, is_train=False)
        # Original author's note: (number of table cells * 2) is the batch
        # size of a single prediction, hence the whole list is one batch.
        predict_x = my_data_loader(data_list, [], len(data_list), is_train=False)
        post_process = table_post_process
    else:
        data_list = table_pre_process_2(table_text_list_copy, [], 0,
                                        is_train=False, padding=False)
        predict_x = my_data_loader_2(data_list, [], 1, is_train=False)
        post_process = table_post_process_2

    # Predict: a single step consumes the one batch produced by the loader.
    predict_result = model.predict_generator(predict_x, steps=1)

    # Post-process the raw output into labels.
    table_label_list = post_process(table_text_list_copy, predict_result)

    # To print and save the result for debugging, enable:
    # save_print_result(table_text_list, table_label_list)
    return table_label_list


def save_print_result(table_text_list, table_label_list,
                      file_path="C:\\Users\\Administrator\\Desktop\\table_head_test.txt"):
    """Print each table row with its predicted labels and append both to a file.

    For every index ``k`` of ``table_label_list`` two lines are emitted,
    ``"<k> <text row>"`` and ``"<k> <label row>"``, followed by a blank line.

    Args:
        table_text_list: Indexable collection of table rows; must have at
            least as many entries as ``table_label_list``.
        table_label_list: Predicted labels, one entry per row.
        file_path: File to append to. Defaults to the original hard-coded
            debug location.
    """
    print("table_head predict")
    write_list = []
    for k, label_row in enumerate(table_label_list):
        text_row = table_text_list[k]
        print(k, text_row)
        print(k, label_row)
        write_list.append(str(k) + " " + str(text_row) + "\n")
        write_list.append(str(k) + " " + str(label_row) + "\n")
        write_list.append("\n")

    # The table text is non-ASCII, so write with an explicit encoding instead
    # of relying on the platform default.
    with open(file_path, "a+", encoding="utf-8") as f:
        f.writelines(write_list)


# def predict_html(_path):
#     def get_trs(tbody):
#         # get all tr elements
#         trs = []
#         objs = tbody.find_all(recursive=False)
#         for obj in objs:
#             if obj.name=="tr":
#                 trs.append(obj)
#             if obj.name=="tbody":
#                 for tr in obj.find_all("tr",recursive=False):
#                     trs.append(tr)
#         return trs
#
#     def get_table(tbody):
#         trs = get_trs(tbody)
#         inner_table = []
#         for tr in trs:
#             tr_line = []
#             tds = tr.findChildren(['td', 'th'], recursive=False)
#             if len(tds) == 0:
#                 tr_line.append(re.sub('\xa0', '', segment(tr,final=False)))  # 2021/12/21 fix: some tables have no td, which caused data loss
#             for td in tds:
#                 tr_line.append(re.sub('\xa0', '', segment(td,final=False)))
#             inner_table.append(tr_line)
#         return inner_table
#
#     def fix_table(inner_table, fix_value=""):
#         maxWidth = 0
#         for item in inner_table:
#             if len(item)>maxWidth:
#                 maxWidth = len(item)
#         for i in range(len(inner_table)):
#             if len(inner_table[i])