ソースを参照

不同TensorFlow版本合适的batch_size不同,表头的改为256

fangjiasheng 2 年 前
コミット
e22471d035
1 ファイル変更4 行追加2 行削除
  1. 4 2
      BiddingKG/dl/table_head/predict.py

+ 4 - 2
BiddingKG/dl/table_head/predict.py

@@ -76,13 +76,15 @@ def predict(table_text_list, model_id=1):
         predict_x = my_data_loader_2(data_list, [], 1, is_train=False)
 
     # 预测
+    # start_time = time.time()
     with sess.as_default():
         with sess.graph.as_default():
-            # predict_result = model.predict_generator(predict_x, steps=steps)
+            # predict_result = model.predict_generator(predict_x, steps=1)
             # 设置batch size为1最快,默认为32很慢
             predict_result = model.predict([predict_x[0], predict_x[1], predict_x[2],
                                             predict_x[3], predict_x[4], predict_x[5]],
-                                           batch_size=1)
+                                           batch_size=256)
+    # print("table head predict time", time.time()-start_time, predict_x.shape)
 
     # 数据后处理
     if model_id == 1: