فهرست منبع

表头识别推理优化,更改batch_size加速

fangjiasheng 2 سال پیش
والد
کامیت
dd1f417ddf
3 فایل‌های تغییر یافته به همراه 42 افزوده شده و 31 حذف شده
  1. +1 -8
      BiddingKG/dl/interface/Preprocessing.py
  2. +31 -8
      BiddingKG/dl/table_head/pre_process.py
  3. +10 -15
      BiddingKG/dl/table_head/predict.py

+ 1 - 8
BiddingKG/dl/interface/Preprocessing.py

@@ -430,15 +430,8 @@ def tableToText(soup):
 
         # 模型预测表头
         predict_list = predict(inner_table)
-        # with open(r"C:\Users\Administrator\Desktop\table_head_test.txt", "a") as f:
-        #     for i in range(len(predict_list)):
-        #         f.write(str(i) + " " + str(inner_table[i]) + "\n")
-        #         f.write(str(i) + " " + str(predict_list[i]) + "\n")
-        #     f.write("\n")
-
-        # print("table_list", inner_table)
-        # print("predict_list", predict_list)
 
+        # 组合结果
         for i in range(len(inner_table)):
             for j in range(len(inner_table[i])):
                 inner_table[i][j] = [inner_table[i][j], int(predict_list[i][j])]

+ 31 - 8
BiddingKG/dl/table_head/pre_process.py

@@ -441,9 +441,7 @@ def my_data_loader(data_list, data_label_list, batch_size, is_train=True):
 
     else:
         new_data_list = []
-        for j in range(batch_size):
-            if i >= data_num:
-                i = 0
+        for j in range(len(data_list)):
             # 中文字符映射为Embedding
             data = data_list[i]
             data = embedding_word(data, output_shape)
@@ -451,11 +449,36 @@ def my_data_loader(data_list, data_label_list, batch_size, is_train=True):
                 new_data_list.append(data)
             i += 1
 
-        new_data_list = np.array(new_data_list)
-        X = new_data_list
-        X = np.transpose(X, (1, 0, 2, 3))
-        yield {'input_1': X[0], 'input_2': X[1], 'input_3': X[2],
-               'input_4': X[3], 'input_5': X[4], 'input_6': X[5], }
+        for j in range(0, len(data_list), batch_size):
+            sub_data_list = np.array(new_data_list[j: j+batch_size])
+            X = sub_data_list
+            X = np.transpose(X, (1, 0, 2, 3))
+            # print(X)
+            # return X
+            yield {'input_1': X[0], 'input_2': X[1], 'input_3': X[2],
+                   'input_4': X[3], 'input_5': X[4], 'input_6': X[5], }
+
+
+def my_data_loader_predict(data_list, data_label_list, batch_size):
+    data_num = len(data_list)
+
+    # 定义Embedding输出
+    output_shape = (6, 20, 60)
+
+    i = 0
+    new_data_list = []
+    for j in range(len(data_list)):
+        # 中文字符映射为Embedding
+        data = data_list[i]
+        data = embedding_word(data, output_shape)
+        if data.shape == output_shape:
+            new_data_list.append(data)
+        i += 1
+
+    sub_data_list = np.array(new_data_list)
+    X = sub_data_list
+    X = np.transpose(X, (1, 0, 2, 3))
+    return X
 
 
 def my_data_loader_2(table_list, table_label_list, batch_size, is_train=True):

+ 10 - 15
BiddingKG/dl/table_head/predict.py

@@ -1,6 +1,7 @@
 #coding:utf-8
 import copy
 import json
+import math
 import os
 import sys
 import time
@@ -11,7 +12,7 @@ from flask import Flask
 sys.path.append(os.path.abspath(os.path.dirname(__file__)))
 from models.model import get_model
 from post_process import table_post_process, table_post_process_2
-from pre_process import my_data_loader, table_pre_process, table_pre_process_2, my_data_loader_2
+from pre_process import my_data_loader, table_pre_process, table_pre_process_2, my_data_loader_2, my_data_loader_predict
 
 # from BiddingKG.dl.interface.Preprocessing import tableToText, segment
 
@@ -48,8 +49,6 @@ sess = tf.Session(graph=tf.Graph())
 
 
 def predict(table_text_list, model_id=1):
-    start_time = time.time()
-
     if globals().get("model") is None:
         print("="*15, "init table_head model", "="*15)
         with sess.as_default():
@@ -58,7 +57,6 @@ def predict(table_text_list, model_id=1):
                 # load weights
                 model.load_weights(keras_model_path)
         globals()["model"] = model
-        # print("="*15, "finish init", "="*15)
     else:
         model = globals().get("model")
 
@@ -69,32 +67,29 @@ def predict(table_text_list, model_id=1):
         data_list = table_pre_process(table_text_list_copy, [], 0, is_train=False)
     else:
         data_list = table_pre_process_2(table_text_list_copy, [], 0, is_train=False, padding=False)
-    batch_size = len(data_list)
-    # print("batch_size", batch_size)
-    # print("data_list", data_list)
 
     # 数据预处理
+    batch_size = len(data_list)
     if model_id == 1:
-        predict_x = my_data_loader(data_list, [], batch_size, is_train=False)
+        predict_x = my_data_loader_predict(data_list, [], batch_size)
     else:
         predict_x = my_data_loader_2(data_list, [], 1, is_train=False)
-    # print(time.time()-start_time)
 
     # 预测
-    # with graph.as_default():
     with sess.as_default():
         with sess.graph.as_default():
-            predict_result = model.predict_generator(predict_x, steps=1)
-    # print("predict_result", predict_result)
-    # print(time.time()-start_time)
-    # print("predict_result", predict_result.shape)
+            # predict_result = model.predict_generator(predict_x, steps=steps)
+            # 设置batch size为1最快,默认为32很慢
+            predict_result = model.predict([predict_x[0], predict_x[1], predict_x[2],
+                                            predict_x[3], predict_x[4], predict_x[5]],
+                                           batch_size=1)
 
     # 数据后处理
     if model_id == 1:
         table_label_list = table_post_process(table_text_list_copy, predict_result)
     else:
         table_label_list = table_post_process_2(table_text_list_copy, predict_result)
-    # print(time.time()-start_time)
+
     # 打印保存结构
     # save_print_result(table_text_list, table_label_list)
     return table_label_list