|
@@ -76,13 +76,15 @@ def predict(table_text_list, model_id=1):
|
|
|
predict_x = my_data_loader_2(data_list, [], 1, is_train=False)
|
|
|
|
|
|
# 预测
|
|
|
+ # start_time = time.time()
|
|
|
with sess.as_default():
|
|
|
with sess.graph.as_default():
|
|
|
- # predict_result = model.predict_generator(predict_x, steps=steps)
|
|
|
+ # predict_result = model.predict_generator(predict_x, steps=1)
|
|
|
# 设置batch size为1最快,默认为32很慢
|
|
|
predict_result = model.predict([predict_x[0], predict_x[1], predict_x[2],
|
|
|
predict_x[3], predict_x[4], predict_x[5]],
|
|
|
- batch_size=1)
|
|
|
+ batch_size=256)
|
|
|
+ # print("table head predict time", time.time()-start_time, predict_x.shape)
|
|
|
|
|
|
# 数据后处理
|
|
|
if model_id == 1:
|