@@ -1,6 +1,7 @@
 #coding:utf-8
 import copy
 import json
+import math
 import os
 import sys
 import time
@@ -11,7 +12,7 @@ from flask import Flask
 sys.path.append(os.path.abspath(os.path.dirname(__file__)))
 from models.model import get_model
 from post_process import table_post_process, table_post_process_2
-from pre_process import my_data_loader, table_pre_process, table_pre_process_2, my_data_loader_2
+from pre_process import my_data_loader, table_pre_process, table_pre_process_2, my_data_loader_2, my_data_loader_predict
 
 # from BiddingKG.dl.interface.Preprocessing import tableToText, segment
 
@@ -48,8 +49,6 @@ sess = tf.Session(graph=tf.Graph())
 
 
 def predict(table_text_list, model_id=1):
-    start_time = time.time()
-
     if globals().get("model") is None:
         print("="*15, "init table_head model", "="*15)
         with sess.as_default():
@@ -58,7 +57,6 @@ def predict(table_text_list, model_id=1):
                 # load weights
                 model.load_weights(keras_model_path)
                 globals()["model"] = model
-        # print("="*15, "finish init", "="*15)
     else:
         model = globals().get("model")
 
@@ -69,32 +67,29 @@ def predict(table_text_list, model_id=1):
         data_list = table_pre_process(table_text_list_copy, [], 0, is_train=False)
     else:
         data_list = table_pre_process_2(table_text_list_copy, [], 0, is_train=False, padding=False)
-    batch_size = len(data_list)
-    # print("batch_size", batch_size)
-    # print("data_list", data_list)
 
     # data pre-processing
+    batch_size = len(data_list)
     if model_id == 1:
-        predict_x = my_data_loader(data_list, [], batch_size, is_train=False)
+        predict_x = my_data_loader_predict(data_list, [], batch_size)
     else:
         predict_x = my_data_loader_2(data_list, [], 1, is_train=False)
-    # print(time.time()-start_time)
 
     # prediction
-    # with graph.as_default():
     with sess.as_default():
         with sess.graph.as_default():
-            predict_result = model.predict_generator(predict_x, steps=1)
-            # print("predict_result", predict_result)
-            # print(time.time()-start_time)
-            # print("predict_result", predict_result.shape)
+            # predict_result = model.predict_generator(predict_x, steps=steps)
+            # batch_size=1 is fastest here; the default batch_size of 32 is much slower
+            predict_result = model.predict([predict_x[0], predict_x[1], predict_x[2],
+                                            predict_x[3], predict_x[4], predict_x[5]],
+                                           batch_size=1)
 
     # data post-processing
     if model_id == 1:
         table_label_list = table_post_process(table_text_list_copy, predict_result)
     else:
         table_label_list = table_post_process_2(table_text_list_copy, predict_result)
-    # print(time.time()-start_time)
+
     # print and save the result structure
     # save_print_result(table_text_list, table_label_list)
     return table_label_list
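
For context on the core change: the old code fed a generator to model.predict_generator(predict_x, steps=1), while the new code loads all six input arrays up front via my_data_loader_predict and calls model.predict directly with batch_size=1, which the added comment reports as much faster than the default batch size of 32. Below is a minimal sketch of the two call styles with a toy six-input Keras model and placeholder shapes, independent of the tf.Session handling in this file; the real model and feature shapes come from models.model.get_model and pre_process.py, which are not part of this diff.

import numpy as np
import tensorflow as tf

# Toy stand-in for the table-head model: six inputs, one output.
inputs = [tf.keras.Input(shape=(8,)) for _ in range(6)]
merged = tf.keras.layers.Concatenate()(inputs)
model = tf.keras.Model(inputs, tf.keras.layers.Dense(1, activation="sigmoid")(merged))

# Placeholder batch standing in for the output of my_data_loader_predict:
# a list of six arrays, one per model input.
predict_x = [np.zeros((4, 8), dtype="float32") for _ in range(6)]

# Old path (removed by the diff): wrap the batch in a generator and pull one step.
# predict_result = model.predict_generator(iter([predict_x]), steps=1)

# New path: pass the list of input arrays directly; batch_size=1 mirrors the diff,
# whose comment notes the default batch size of 32 was much slower for this workload.
predict_result = model.predict(predict_x, batch_size=1)
print(predict_result.shape)  # (4, 1) for this toy setup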
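
my_data_loader_predict itself is not shown in the diff; from its call site it takes (data_list, [], batch_size) and returns something indexable as predict_x[0] through predict_x[5]. The following is a purely hypothetical stand-in consistent with that usage, not the repository's implementation; the real feature layout lives in pre_process.py.

import numpy as np

def my_data_loader_predict(data_list, label_list, batch_size):
    # Hypothetical sketch: assume each item of data_list is a tuple of six
    # per-sample feature arrays, and stack them into one array per model input
    # so the whole batch can go through model.predict in a single call.
    columns = list(zip(*data_list[:batch_size]))
    return [np.asarray(col) for col in columns]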