import codecs import re from bs4 import BeautifulSoup from BiddingKG.dl.table_head.models.model import get_model from BiddingKG.dl.table_head.post_process import table_post_process from BiddingKG.dl.table_head.pre_process import my_data_loader, table_pre_process from BiddingKG.dl.interface.Preprocessing import tableToText, segment # init model input_shape = (6, 10, 60) output_shape = (1,) model = get_model(input_shape, output_shape) # load weights model_path = "checkpoints/best.hdf5" model.load_weights(model_path) def predict(table_text_list): # 表格单元格数*2 即为单次预测batch_size data_list = table_pre_process(table_text_list, [], 0, is_train=False) batch_size = len(data_list) # print("batch_size", batch_size) # 数据预处理 predict_x = my_data_loader(data_list, [], batch_size, is_train=False) # 预测 predict_result = model.predict_generator(predict_x, steps=1) # print("predict_result", predict_result.shape) # 数据后处理 table_label_list = table_post_process(table_text_list, predict_result) return table_label_list def predict_html(): def get_trs(tbody): #获取所有的tr trs = [] objs = tbody.find_all(recursive=False) for obj in objs: if obj.name=="tr": trs.append(obj) if obj.name=="tbody": for tr in obj.find_all("tr",recursive=False): trs.append(tr) return trs def get_table(tbody): trs = get_trs(tbody) inner_table = [] for tr in trs: tr_line = [] tds = tr.findChildren(['td', 'th'], recursive=False) if len(tds) == 0: tr_line.append(re.sub('\xa0', '', segment(tr,final=False))) # 2021/12/21 修复部分表格没有td 造成数据丢失 for td in tds: tr_line.append(re.sub('\xa0', '', segment(td,final=False))) inner_table.append(tr_line) return inner_table def fix_table(inner_table, fix_value=""): maxWidth = 0 for item in inner_table: if len(item)>maxWidth: maxWidth = len(item) for i in range(len(inner_table)): if len(inner_table[i])