Przeglądaj źródła

Merge remote-tracking branch 'origin/master'

luojiehua 2 lat temu
rodzic
commit
db98d33d2d
1 zmienionych plików z 4 dodań i 1 usunięć
  1. 4 1
      BiddingKG/dl/table_head/predict.py

+ 4 - 1
BiddingKG/dl/table_head/predict.py

@@ -26,7 +26,10 @@ else:
     output_shape = (None, None)
 keras_model_path = os.path.abspath(os.path.dirname(__file__)) + "/best.hdf5"
 # keras模型加载预测都使用同一个session、同一个graph,即可多进程推理
-sess = tf.Session(graph=tf.Graph())
+session_conf = tf.ConfigProto(
+    intra_op_parallelism_threads=5,
+    inter_op_parallelism_threads=5)
+sess = tf.Session(graph=tf.Graph(), config=session_conf)
 # graph = tf.get_default_graph()
 
 # tf_model_path = os.path.abspath(os.path.dirname(__file__)) + '/best_pb/1'