
update otr model

fangjiasheng 3 years ago
parent
commit
17423f117c

+ 14 - 6
format_convert/convert.py

@@ -358,7 +358,7 @@ def cut_str(text_list, only_text_list, max_bytes_length=2000000):
 
 
 @memory_decorator
-def convert(data, ocr_model, otr_model):
+def convert_maxcompute(data, ocr_model, otr_model):
     """
     Interface return values:
     {[str], 1}: processing succeeded
@@ -481,14 +481,22 @@ def _convert():
     _md5 = _global.get("md5")
     _type = None
     try:
-        if not request.form:
+        _time = time.time()
+        data = request.form
+        if not data:
             log("convert no data!")
             raise ConnectionError
 
-        data = request.form
-        stream = base64.b64decode(data.get("file"))
+        file_path = data.get("file_path")
+        if file_path is None:
+            stream = base64.b64decode(data.get("file"))
+            log("get bytes from file " + str(time.time()-_time))
+        # If a file path is given, open the file from it directly
+        else:
+            with open(file_path, "rb") as f:
+                stream = f.read()
+            log("get bytes from file_path " + str(time.time()-_time))
         _type = data.get("type")
-
         _md5 = get_md5_from_bytes(stream)
         _md5 = _md5[0]
         _global.update({"md5": _md5})
@@ -650,7 +658,7 @@ def test_one(p, from_remote=False):
     else:
         ocr_model = ocr_interface.OcrModels().get_model()
         otr_model = otr_interface.OtrModels().get_model()
-        result = convert(data, ocr_model, otr_model)
+        result = convert_maxcompute(data, ocr_model, otr_model)
 
     print("result_text", result.get("result_text")[0][:20])
     print("is_success", result.get("is_success"))

+ 14 - 2
format_convert/convert_image.py

@@ -138,6 +138,18 @@ def image_process(image_np, image_path, is_from_pdf=False, is_from_docx=False, u
         # # pred = np.load(decompressed_array, allow_pickle=True)['arr_0']
         # # log("inputs.shape" + str(pred.shape))
         #
+
+        # Call GPU processing via shared memory
+        # _dict = {"inputs": inputs, "md5": _global.get("md5")}
+        # result = from_gpu_share_memory(_dict, model_type="otr", predictor_type="")
+        # if judge_error_code(result):
+        #     logging.error("from_gpu_interface failed! " + str(result))
+        #     raise requests.exceptions.RequestException
+        #
+        # pred = result.get("preds")
+        # gpu_time = result.get("gpu_time")
+        # log("otr model predict time " + str(gpu_time))
+        #
         # # Post-processing
         # list_line = table_postprocess(img_new, pred, prob)
         # log("len(list_line) " + str(len(list_line)))
@@ -159,8 +171,8 @@ def image_process(image_np, image_path, is_from_pdf=False, is_from_docx=False, u
         if image_np.shape[0] >= threshold or image_np.shape[1] >= threshold:
             best_h, best_w = get_best_predict_size2(image_np, threshold)
             image_resize = cv2.resize(image_np, (best_w, best_h), interpolation=cv2.INTER_AREA)
-            # image_resize_path = image_path.split(".")[0] + "_resize_ocr." + image_path.split(".")[-1]
-            # cv2.imwrite(image_resize_path, image_resize)
+            image_resize_path = image_path.split(".")[0] + "_resize_ocr." + image_path.split(".")[-1]
+            cv2.imwrite(image_resize_path, image_resize)
         log("ocr resize before " + str(time.time()-start_time))
 
         # Call the OCR model interface
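
Note: the re-enabled lines persist a "_resize_ocr" copy of any image whose side reaches the threshold before OCR. get_best_predict_size2 is project-internal; a rough stand-in that caps the longer side while preserving aspect ratio might look like this (an assumption, not the actual implementation):

    def best_predict_size(image_np, threshold=2000):
        # Scale the longer side down to the threshold, keeping aspect ratio
        h, w = image_np.shape[:2]
        scale = threshold / max(h, w)
        return int(h * scale), int(w * scale)

    # best_h, best_w = best_predict_size(image_np)
    # image_resize = cv2.resize(image_np, (best_w, best_h), interpolation=cv2.INTER_AREA)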

+ 42 - 0
format_convert/convert_mp.py

@@ -0,0 +1,42 @@
+import os
+import sys
+from flask import Flask
+sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
+from multiprocessing import Process, Queue
+from format_convert.convert import _convert
+from format_convert.utils import get_ip_port
+
+
+ip_port_dict = get_ip_port()
+ip = "http://127.0.0.1"
+convert_port_list = ip_port_dict.get(ip).get("convert")
+ocr_port_list = ip_port_dict.get(ip).get("ocr")
+otr_port_list = ip_port_dict.get(ip).get("otr")
+
+
+attachment_queue = Queue()
+
+class A:
+    def __init__(self):
+        print(A)
+a = A()
+
+
+def convert_start_mp():
+    pass
+
+
+app = Flask(__name__)
+
+
+@app.route('/test', methods=['POST'])
+def test():
+    print(str(attachment_queue))
+    print(str(a))
+    attachment_queue.put(1)
+    print(attachment_queue.qsize())
+    # A Flask view must return a response, not None
+    return str(attachment_queue.qsize())
+
+
+if __name__ == '__main__':
+    # app.run(host='0.0.0.0', port=15011, processes=3, threaded=False, debug=False)
+    app.run()
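
Note: convert_start_mp() is still a stub; presumably worker processes will drain attachment_queue. A hypothetical sketch of that pattern (the consume helper is an assumption):

    def consume(queue):
        # Worker loop: block on the queue and handle each attachment
        while True:
            item = queue.get()
            print("consumed", item)

    def convert_start_mp(n_workers=3):
        for _ in range(n_workers):
            Process(target=consume, args=(attachment_queue,), daemon=True).start()

Sharing the queue across independently spawned gunicorn workers only works if they fork from the process that created it.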

+ 173 - 8
format_convert/convert_need_interface.py

@@ -3,6 +3,7 @@ import base64
 import inspect
 import json
 import logging
+import multiprocessing
 import os
 import pickle
 import random
@@ -10,6 +11,8 @@ import sys
 import time
 import uuid
 import zlib
+from queue import Queue
+
 import redis
 from werkzeug.exceptions import NotFound
 sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
@@ -17,10 +20,11 @@ import traceback
 import requests
 from format_convert import _global
 from format_convert.utils import get_platform, get_sequential_data, judge_error_code, request_post, get_ip_port, \
-    get_intranet_ip, get_logger, log, memory_decorator
+    get_intranet_ip, get_logger, log, get_args_from_config
 from ocr.ocr_interface import ocr, OcrModels
 from otr.otr_interface import otr, OtrModels
 from format_convert.libreoffice_interface import office_convert
+import numpy as np
 
 
 if get_platform() == "Windows":
@@ -43,6 +47,12 @@ if get_platform() == "Windows":
 else:
     FROM_REMOTE = True
 
+# ip_port_dict = get_ip_port()
+# ip = 'http://127.0.0.1'
+# ocr_port_list = ip_port_dict.get(ip).get("ocr")
+# otr_port_list = ip_port_dict.get(ip).get("otr")
+
+lock = multiprocessing.RLock()
 
 # Connect to the redis database
 # redis_db = redis.StrictRedis(host='192.168.2.103', port='6379',
@@ -50,7 +60,7 @@ else:
 redis_db = None
 
 
-def _interface(_dict, time_out=60, retry_times=3, use_zlib=False):
+def _interface(_dict, time_out=60, retry_times=3):
     try:
         # Retry
         model_type = _dict.get("model_type")
@@ -59,10 +69,9 @@ def _interface(_dict, time_out=60, retry_times=3, use_zlib=False):
             if judge_error_code(ip_port):
                 return ip_port
             _url = ip_port + "/" + model_type
-            base64_stream = base64.b64encode(pickle.dumps(_dict))
-            # if use_zlib:
-            #     base64_stream = zlib.compress(base64_stream)
-            r = pickle.loads(base64.b64decode(request_post(_url, {"data": base64_stream, "model_type": model_type}, time_out=time_out, use_zlib=use_zlib)))
+            # base64_stream = base64.b64encode(pickle.dumps(_dict))
+            r = json.loads(request_post(_url, {"data": json.dumps(_dict),
+                                                 "model_type": model_type}, time_out=time_out))
             log("get _interface return")
             if type(r) == list:
                 # Interface unreachable, retry on another port
@@ -301,6 +310,71 @@ def from_gpu_interface_redis(_dict, model_type, predictor_type):
         return [-2]
 
 
+# def from_gpu_flask_sm(_dict, model_type, predictor_type):
+#     log("into from_gpu_share_memory")
+#     start_time = time.time()
+#     shm = None
+#     try:
+#         # Put into shared memory
+#         _time = time.time()
+#         np_data = _dict.get("inputs")
+#         shm = to_share_memory(np_data)
+#         log("data into share memory " + str(shm.name) + " " + str(time.time()-_time))
+#
+#         # Call the interface
+#         _time = time.time()
+#         _dict.pop("inputs")
+#         _dict.update({"predictor_type": predictor_type, "model_type": model_type,
+#                       "sm_name": shm.name, "sm_shape": np_data.shape,
+#                       "sm_dtype": str(np_data.dtype)})
+#         result = _interface(_dict, time_out=30, retry_times=2)
+#         log("_interface cost " + str(time.time()-_time))
+#
+#         # Read from shared memory
+#         _time = time.time()
+#         sm_name = result.get("sm_name")
+#         sm_shape = result.get("sm_shape")
+#         sm_dtype = result.get("sm_dtype")
+#         sm_dtype = get_np_type(sm_dtype)
+#         if sm_name:
+#             outputs = from_share_memory(sm_name, sm_shape, sm_dtype)
+#         else:
+#             log("from_share_memory failed!")
+#             raise Exception
+#         log("data from share memory " + sm_name + " " + str(time.time()-_time))
+#
+#         log("from_gpu_interface finish - size " + str(sys.getsizeof(_dict)) + " - time " + str(time.time()-start_time))
+#         return {"preds": outputs, "gpu_time": result.get("gpu_time")}
+#     except Exception as e:
+#         log("from_gpu_interface failed " + str(time.time()-start_time))
+#         traceback.print_exc()
+#         return [-2]
+#     finally:
+#         # del b  # Unnecessary; merely emphasizing the array is no longer used
+#         if shm:
+#             try:
+#                 shm.close()
+#                 shm.unlink()
+#             except FileNotFoundError:
+#                 log("share memory " + shm.name + " not exists!")
+#             except Exception:
+#                 traceback.print_exc()
+#
+#
+# def from_gpu_share_memory(_dict, model_type, predictor_type):
+#     log("into from_gpu_share_memory")
+#     start_time = time.time()
+#     try:
+#         _dict.update({"model_type": model_type, "predictor_type": predictor_type})
+#         outputs, gpu_time = share_memory_pool(_dict)
+#         log("from_gpu_share_memory finish - size " + str(sys.getsizeof(_dict)) + " - time " + str(time.time()-start_time))
+#         return {"preds": outputs, "gpu_time": float(gpu_time)}
+#     except Exception as e:
+#         log("from_gpu_interface failed " + str(time.time()-start_time))
+#         traceback.print_exc()
+#         return [-2]
+
+
 def from_otr_interface2(image_stream):
     log("into from_otr_interface")
     try:
@@ -485,7 +559,7 @@ def interface_pool_gunicorn(interface_type):
         if interface_type == "office":
             _ip = "http://127.0.0.1"
             # Pick a port
-            port_list = ip_port_dict.get(_ip).get(interface_type)
+            port_list = ip_port_dict.get(_ip).get("MASTER").get(interface_type)
             ip_type_cnt = ip_port_flag_dict.get(_ip).get(interface_type)
             if ip_type_cnt == 0:
                 _port = port_list[random.randint(0, len(port_list)-1)]
@@ -506,8 +580,20 @@ def interface_pool_gunicorn(interface_type):
             else:
                 ip_index = 0
             _ip = ip_flag_list[ip_index][0]
+            if "master" in _ip:
+                port_index = 1
+            else:
+                port_index = 0
+            _ip = _ip.split("_")[0]
             # Pick a port; with gunicorn just take the first one
-            _port = ip_port_dict.get(_ip).get(interface_type)[0]
+            # _port = ip_port_dict.get(_ip).get("MASTER").get(interface_type)[0]
+            log("_ip " + _ip)
+            log("interface_type " + interface_type)
+            port_list = get_args_from_config(ip_port_dict, _ip, interface_type)
+            log("port_list " + str(port_list))
+            if port_index >= len(port_list):
+                port_index = 0
+            _port = get_args_from_config(ip_port_dict, _ip, interface_type)[port_index][0]
             # Update the flag
             if ip_flag_list[ip_index][1] >= 10000:
                 ip_flag_list[ip_index][1] = 0
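
Note: the suffix logic above mirrors the new config shape: get_args_from_config returns one port list per node type for the chosen IP, and the "_master"/"_slave" tag on the flagged IP selects which list to index. A condensed sketch of the selection exactly as written in the diff (port values illustrative):

    port_lists = [["18000", "18001"],            # node type 0
                  ["18005", "18006", "18007"]]   # node type 1

    _ip = "http://127.0.0.1_master"
    port_index = 1 if "master" in _ip else 0
    _ip = _ip.split("_")[0]
    if port_index >= len(port_lists):
        port_index = 0
    _port = port_lists[port_index][0]  # gunicorn: always take the first port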
@@ -526,6 +612,85 @@ def interface_pool_gunicorn(interface_type):
         return [-1]
 
 
+# def share_memory_pool(args_dict):
+#     np_data = args_dict.get("inputs")
+#     _type = args_dict.get("model_type")
+#     args_dict.update({"sm_shape": np_data.shape, "sm_dtype": str(np_data.dtype)})
+#
+#     if _type == 'ocr':
+#         port_list = ocr_port_list
+#     elif _type == 'otr':
+#         port_list = otr_port_list
+#     else:
+#         log("type error! only support ocr otr")
+#         raise Exception
+#
+#     # Loop until an empty share memory slot is available
+#     empty_sm_list = None
+#     sm_list_name = ""
+#     while empty_sm_list is None:
+#         for p in port_list:
+#             sm_list_name = "sml_"+_type+"_"+str(p)
+#             sm_list = get_share_memory_list(sm_list_name)
+#             if sm_list[0] == "0":
+#                 lock.acquire(timeout=0.1)
+#                 if sm_list[0] == "0":
+#                     sm_list[0] = "1"
+#                     sm_list[-1] = "0"
+#                     empty_sm_list = sm_list
+#                     break
+#                 else:
+#                     continue
+#                 lock.release()
+#
+#     log(str(os.getppid()) + " empty_sm_list " + sm_list_name)
+#
+#     # Put the numpy array into shared memory
+#     _time = time.time()
+#     release_share_memory(get_share_memory("psm_" + str(os.getpid())))
+#     shm = to_share_memory(np_data)
+#     log("data into share memory " + str(shm.name) + " " + str(time.time()-_time))
+#
+#     # Put the arguments into the shared memory list
+#     empty_sm_list[1] = args_dict.get("md5")
+#     empty_sm_list[2] = args_dict.get("model_type")
+#     empty_sm_list[3] = args_dict.get("predictor_type")
+#     empty_sm_list[4] = args_dict.get("args")
+#     empty_sm_list[5] = str(shm.name)
+#     empty_sm_list[6] = str(args_dict.get("sm_shape"))
+#     empty_sm_list[7] = args_dict.get("sm_dtype")
+#     empty_sm_list[-1] = "1"
+#     # log("empty_sm_list[7] " + empty_sm_list[7])
+#     close_share_memory_list(empty_sm_list)
+#
+#     # Loop until processing is finished
+#     finish_sm_list = get_share_memory_list(sm_list_name)
+#     while True:
+#         if finish_sm_list[-1] == "0":
+#             break
+#
+#     # Read from shared memory
+#     _time = time.time()
+#     sm_name = finish_sm_list[5]
+#     sm_shape = finish_sm_list[6]
+#     sm_shape = eval(sm_shape)
+#     sm_dtype = finish_sm_list[7]
+#     gpu_time = finish_sm_list[8]
+#     sm_dtype = get_np_type(sm_dtype)
+#     outputs = from_share_memory(sm_name, sm_shape, sm_dtype)
+#     log(args_dict.get("model_type") + " " + args_dict.get("predictor_type") + " outputs " + str(outputs.shape))
+#     log("data from share memory " + sm_name + " " + str(time.time()-_time))
+#
+#     # Release
+#     release_share_memory(get_share_memory(sm_name))
+#
+#     # Reset the share memory list
+#     finish_sm_list[-1] = "0"
+#     finish_sm_list[0] = "0"
+#
+#     close_share_memory_list(finish_sm_list)
+#     return outputs, gpu_time
+
 
 # def interface_pool(interface_type):
 #     try:

+ 8 - 9
format_convert/convert_test.py

@@ -4,8 +4,7 @@ import os
 import random
 import sys
 import time
-from multiprocessing.context import Process
-
+from multiprocessing import Process
 sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
 from format_convert.utils import get_platform, request_post, get_md5_from_bytes
 
@@ -20,9 +19,10 @@ def test_one(p, from_remote=False):
 
     data = {"file": file_base64, "type": p.split(".")[-1], "filemd5": 100}
     if from_remote:
-        _url = 'http://121.46.18.113:15010/convert'
-        # _url = 'http://192.168.2.103:15010/convert'
+        # _url = 'http://121.46.18.113:15010/convert'
+        # _url = 'http://192.168.2.102:15010/convert'
         # _url = 'http://172.16.160.65:15010/convert'
+        _url = 'http://127.0.0.1:15010/convert'
         result = json.loads(request_post(_url, data, time_out=10000))
     else:
         print("only support remote!")
@@ -36,7 +36,7 @@ def test_one(p, from_remote=False):
 def test_duplicate(path_list, process_no=None):
     start_time = time.time()
     # random.shuffle(path_list)
-    for i in range(1):
+    for i in range(10):
         if i % 10 == 0:
             if process_no is not None:
                 print("Process", process_no, i*len(path_list), time.time()-start_time)
@@ -48,10 +48,10 @@ def test_duplicate(path_list, process_no=None):
 
 if __name__ == '__main__':
     if get_platform() == "Windows":
-        # file_path = "C:/Users/Administrator/Desktop/error2.swf"
+        file_path = "C:/Users/Administrator/Desktop/test_image/error1.png"
         # file_path = "D:/BIDI_DOC/比地_文档/2022/Test_Interface/20210609202634853485.xlsx"
         # file_path = "D:/BIDI_DOC/比地_文档/2022/Test_ODPS/1624325845476.pdf"
-        file_path = "C:/Users/Administrator/Downloads/1652672734044.jpg"
+        # file_path = "C:/Users/Administrator/Downloads/1653547877897.pdf"
     else:
         file_path = "test1.doc"
     test_one(file_path, from_remote=True)
@@ -63,8 +63,7 @@ if __name__ == '__main__':
     #     # file_path_list = ["D:/BIDI_DOC/比地_文档/2022/Test_Interface/1623328459080.doc",
     #     #                   "D:/BIDI_DOC/比地_文档/2022/Test_Interface/94961e1987d1090e.xls"]
     #     # file_path_list = ["D:/BIDI_DOC/比地_文档/2022/Test_Interface/1623423836610.pdf"]
-    #     file_path_list = ["C:/Users/Administrator/Downloads/广东中检达元检测技术有限公司.pdf",
-    #                       "C:/Users/Administrator/Desktop/error16.jpg"]
+    #     file_path_list = ["C:/Users/Administrator/Desktop/error16.jpg"]
     # else:
     #     file_path_list = ["1623423836610.pdf"]
     # start_time = time.time()

+ 76 - 5
format_convert/convert_xlsx.py

@@ -7,6 +7,7 @@ import logging
 import traceback
 import pandas
 import numpy as np
+import xlrd
 from format_convert.utils import get_logger, log, memory_decorator
 from format_convert.wrapt_timeout_decorator import timeout
 
@@ -51,13 +52,16 @@ class XlsxConvert:
 
     @timeout(30, timeout_exception=TimeoutError, use_signals=False)
     def read(self):
+        # pandas
         df = pandas.read_excel(self.path, header=None, keep_default_na=False, sheet_name=None)
-        return df
+        # xlrd, in order to read merged cells
+        workbook = xlrd.open_workbook(self.path)
+        return df, workbook
 
     def init_package(self):
         # Initialize each package
         try:
-            self.df = self.read()
+            self.df, self.workbook = self.read()
             self.sheet_list = [sheet for sheet in self.df.values()]
 
             # Avoid reading too many empty rows and columns
@@ -88,7 +92,7 @@ class XlsxConvert:
             if self.re_read > 0:
                 self.sheet_list = [sheet for sheet in self.df.values()]
 
-            print(self.sheet_list[0].shape)
+            # print(self.sheet_list[0].shape)
         except:
             log("cannot open xlsx!")
             traceback.print_exc()
@@ -102,14 +106,14 @@ class XlsxConvert:
         sheet_no = 0
         for sheet in self.sheet_list:
             self._page = _Page(None, sheet_no)
-            self.convert_page(sheet)
+            self.convert_page(sheet, sheet_no)
 
             if self._doc.error_code is None and self._page.error_code is not None:
                 self._doc.error_code = self._page.error_code
             self._doc.add_child(self._page)
             sheet_no += 1
 
-    def convert_page(self, sheet):
+    def convert_page2(self, sheet):
         text = '<table border="1">' + "\n"
 
         # Strip redundant empty columns
@@ -148,6 +152,73 @@ class XlsxConvert:
         _table = _Table(text, (0, 0, 0, 0), is_html=True)
         self._page.add_child(_table)
 
+    def convert_page(self, sheet, sheet_no):
+        # Strip redundant empty rows and columns
+        max_row_len = 0
+        max_col_len = 0
+        if self.re_read:
+            for index, row in sheet.iterrows():
+                col_len = 0
+                row_empty_flag = 1
+                for i in range(len(row)):
+                    if row[i] not in [None, "", np.nan]:
+                        row_empty_flag = 0
+                        col_len = i
+
+                if self.re_read == 3 or self.re_read == 1:
+                    if col_len > max_col_len:
+                        max_col_len = col_len
+
+                if self.re_read == 3 or self.re_read == 2:
+                    if row_empty_flag == 0:
+                        max_row_len = index
+
+        row_list = []
+        for index, row in sheet.iterrows():
+            if self.re_read == 3 or self.re_read == 2:
+                if index > max_row_len:
+                    break
+            if self.re_read == 3 or self.re_read == 1:
+                row = row[:max_col_len+1]
+            col_list = []
+            for r in row:
+                col_list.append(str(r))
+            row_list.append(col_list)
+
+        # Get merged cell positions via xlrd
+        sheet_xlrd = self.workbook.sheet_by_index(sheet_no)
+        merged_cell_list = sheet_xlrd.merged_cells
+        merged_cell_list.sort(key=lambda x: (x[0], x[1], x[2], x[3]))
+
+        # Fill merged cells by copying the top-left value
+        for row_start, row_end, col_start, col_end in merged_cell_list:
+            if row_start >= len(row_list) or row_end > len(row_list):
+                continue
+            if col_start >= len(row_list) or col_end > len(row_list):
+                continue
+            copy_cell = row_list[row_start][col_start]
+            for i in range(row_start, row_end):
+                row = row_list[i]
+                # The first row fills one cell fewer; the other rows fill one more
+                if i == row_start:
+                    col_start_real = col_start+1
+                else:
+                    col_start_real = col_start
+                for j in range(col_start_real, col_end):
+                    if row[j] == "":
+                        row[j] = copy_cell
+
+        # Assemble the HTML table
+        text = '<table border="1">' + "\n"
+        for row in row_list:
+            text = text + "<tr>" + "\n"
+            for col in row:
+                text = text + "<td>" + str(col) + "</td>" + "\n"
+            text = text + "</tr>" + "\n"
+        text = text + "</table>" + "\n"
+
+        _table = _Table(text, (0, 0, 0, 0), is_html=True)
+        self._page.add_child(_table)
+
     def get_html(self):
         try:
             self.convert()
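
Note: xlrd exposes merged regions as (row_start, row_end, col_start, col_end) tuples with half-open ends, which is exactly what the fill loop in convert_page iterates. A minimal standalone sketch (file name hypothetical; requires an xlrd version that still reads .xlsx):

    import xlrd

    workbook = xlrd.open_workbook("demo.xlsx")  # hypothetical file
    sheet = workbook.sheet_by_index(0)
    for row_start, row_end, col_start, col_end in sheet.merged_cells:
        # The top-left cell holds the value for the whole merged block
        value = sheet.cell_value(row_start, col_start)
        print((row_start, row_end, col_start, col_end), value)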

+ 20 - 26
format_convert/interface.yml

@@ -5,46 +5,40 @@ MASTER:
 #  local-102: 'http://192.168.2.102'
 #  local-103: 'http://192.168.2.103'
 #  local 'http://127.0.0.1'
-  ip:
+  ip: ['http://127.0.0.1']
 
   PATH:
-#  65: /root/miniconda3/bin/python
-#  102: /home/python/anaconda3/envs/convert/bin/python
-#  103: /home/yons/anaconda3/envs/tf1.5/bin/python
-#  251: /data/anaconda3/envs/convert/bin/python
-    python: '/data/anaconda3/envs/convert3/bin/gunicorn'
-    #  65: /data/format_conversion_maxcompute/
-#  102: /data/fangjiasheng/format_conversion_maxcompute/
-#  103: /data/python/fangjiasheng/format_conversion_maxcompute/
-#  251: /data/fangjiasheng/format_conversion_maxcompute/
-    project: '/data/fangjiasheng/format_conversion_maxcompute/'
+    python: ['/data/anaconda3/envs/convert3/bin/python']
+    gunicorn: ['/data/anaconda3/envs/convert3/bin/gunicorn']
+    project: ['/data/fangjiasheng/format_conversion_maxcompute/']
 
   CONVERT:
-    port: 15010
-    processes: 20
+    port: [15010]
+    processes: [30]
 
   SCHEDULE:
     port:
 
   OCR:
-    port_start: 17000
-    port_no: 3
+    port_start: [17000]
+    port_no: [4]
 
   OTR:
-    port_start: 18000
-    port_no: 3
+    port_start: [18000]
+    port_no: [2]
 
   OFFICE:
-    port_start: 16000
-    port_no: 25
+    port_start: [16000]
+    port_no: [25]
 
 
 SLAVE:
-  ip: ['http://127.0.0.1']
+  ip: ['http://127.0.0.1', 'http://192.168.0.114']
 
   PATH:
-    python: '/data/anaconda3/envs/tf2/bin/gunicorn'
-    project: '/data/fangjiasheng/format_conversion_maxcompute/'
+    python: ['/data/anaconda3/envs/convert/bin/python', '/data/anaconda3/envs/tf2/bin/python']
+    gunicorn: ['/data/anaconda3/envs/convert/bin/gunicorn', '/data/anaconda3/envs/tf2/bin/gunicorn']
+    project: ['/data/fangjiasheng/format_conversion_maxcompute/', '/data/fangjiasheng/format_conversion_maxcompute/']
 
   CONVERT:
     port:
@@ -54,12 +48,12 @@ SLAVE:
     port:
 
   OCR:
-    port_start: 17000
-    port_no: 2
+    port_start: [17005, 17000]
+    port_no: [0, 2]
 
   OTR:
-    port_start: 18000
-    port_no: 4
+    port_start: [18005, 18000]
+    port_no: [3, 4]
 
   OFFICE:
     port_start:
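
Note: every field is now a list aligned with the ip list, so index j of python/gunicorn/project/port_start/port_no all describe the j-th IP. A sketch of how the SLAVE block above expands (values copied from the YAML):

    ip = ['http://127.0.0.1', 'http://192.168.0.114']
    ocr_port_start = [17005, 17000]
    ocr_port_no = [0, 2]

    for j in range(len(ip)):
        ports = [str(p) for p in range(ocr_port_start[j], ocr_port_start[j] + ocr_port_no[j])]
        print(ip[j], ports)
    # http://127.0.0.1 []
    # http://192.168.0.114 ['17000', '17001']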

+ 3 - 7
format_convert/kill_all.py

@@ -7,10 +7,10 @@ from format_convert.utils import get_ip_port, get_intranet_ip
 
 
 ip_port_dict = get_ip_port()
-# ip = "http://" + get_intranet_ip()
 ip = "http://127.0.0.1"
-python_path = ip_port_dict.get(ip).get("python_path")
-project_path = ip_port_dict.get(ip).get("project_path")
+python_path = ip_port_dict.get(ip).get("MASTER").get("python_path")
+gunicorn_path = ip_port_dict.get(ip).get("MASTER").get("gunicorn_path")
+project_path = ip_port_dict.get(ip).get("MASTER").get("project_path")
 
 
 def kill():
@@ -30,10 +30,6 @@ def kill():
                 comm = "kill -9 " + str(pid)
                 print(comm, process_cmd)
                 os.system(comm)
-            if re.search(project_path, process_cmd):
-                comm = "kill -9 " + str(pid)
-                print(comm, process_cmd)
-                os.system(comm)
     else:
         print("cannot kill! checkout config...")
         print(ip_port_dict)

+ 77 - 44
format_convert/monitor_process_config.py

@@ -6,49 +6,80 @@ import sys
 import time
 import psutil
 sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
-from format_convert.utils import get_ip_port, get_intranet_ip
+from format_convert.utils import get_ip_port, get_intranet_ip, get_args_from_config
+
+# ip_port_dict = get_ip_port()
+# ip = "http://" + get_intranet_ip()
+# ip = "http://127.0.0.1"
+# convert_port_list = ip_port_dict.get(ip).get("convert")
+# ocr_port_list = ip_port_dict.get(ip).get("ocr")
+# otr_port_list = ip_port_dict.get(ip).get("otr")
+# soffice_port_list = ip_port_dict.get(ip).get("office")
+# if not convert_port_list:
+#     convert_port_list = []
+# if not ocr_port_list:
+#     ocr_port_list = []
+# if not otr_port_list:
+#     otr_port_list = []
+# if not soffice_port_list:
+#     soffice_port_list = []
+#
+# # schedule_port_list = ip_port_dict.get(ip).get("schedule")
+# # python_path = ip_port_dict.get(ip).get("python_path")
+# # project_path = ip_port_dict.get(ip).get("project_path")
+#
+# interface_path = project_path[:-1]
+# std_out = " >>/convert.out 2>&1 &"
+# std_out_gpu = " >>/gpu.out 2>&1 &"
+# std_out_schedule = " >>/schedule.out 2>&1 &"
+# # convert_comm = "nohup " + python_path + " " + interface_path + "/format_convert/convert.py #" + std_out
+# # ocr_comm = "nohup " + python_path + " " + interface_path + "/ocr/ocr_gpu_interface.py # 0" + std_out_gpu
+# # otr_comm = "nohup " + python_path + " " + interface_path + "/otr/otr_gpu_interface.py # 0" + std_out_gpu
+# schedule_comm = "nohup " + python_path + " " + interface_path + "/format_convert/schedule_interface.py #" + std_out_schedule
+# soffice_comm = "docker run --init -itd --log-opt max-size=10m --log-opt max-file=3 -p #:16000 soffice:v2 bash"
+#
+#
+# gunicorn_path = python_path
+# # print("convert_port_list", len(convert_port_list))
+# convert_comm = "nohup " + gunicorn_path + " -w " + str(len(convert_port_list)) + " -t 300 -b 0.0.0.0:# --chdir " \
+#                + interface_path + "/format_convert convert:app" + std_out
+# ocr_comm = "nohup " + gunicorn_path + " -w " + str(len(ocr_port_list)) + " -t 300 --keep-alive 600 -b 0.0.0.0:# --chdir " \
+#            + interface_path + "/ocr ocr_gpu_interface:app" + std_out_gpu
+# otr_comm = "nohup " + gunicorn_path + " -w " + str(len(otr_port_list)) + " -t 300 --keep-alive 600 -b 0.0.0.0:# --chdir " \
+#            + interface_path + "/otr otr_gpu_interface:app" + std_out_gpu
 
 
 ip_port_dict = get_ip_port()
-# ip = "http://" + get_intranet_ip()
 ip = "http://127.0.0.1"
-convert_port_list = ip_port_dict.get(ip).get("convert")
-ocr_port_list = ip_port_dict.get(ip).get("ocr")
-otr_port_list = ip_port_dict.get(ip).get("otr")
-soffice_port_list = ip_port_dict.get(ip).get("office")
-if not convert_port_list:
-    convert_port_list = []
-if not ocr_port_list:
-    ocr_port_list = []
-if not otr_port_list:
-    otr_port_list = []
-if not soffice_port_list:
-    soffice_port_list = []
-
-# schedule_port_list = ip_port_dict.get(ip).get("schedule")
-python_path = ip_port_dict.get(ip).get("python_path")
-project_path = ip_port_dict.get(ip).get("project_path")
-
-
-interface_path = project_path[:-1]
+convert_port_list = get_args_from_config(ip_port_dict, ip, "convert", "MASTER")[0]
+ocr_port_list = get_args_from_config(ip_port_dict, ip, "ocr")
+otr_port_list = get_args_from_config(ip_port_dict, ip, "otr")
+soffice_port_list = get_args_from_config(ip_port_dict, ip, "office", "MASTER")[0]
+python_path_list = get_args_from_config(ip_port_dict, ip, "python_path")
+project_path_list = get_args_from_config(ip_port_dict, ip, "project_path")
+gunicorn_path_list = get_args_from_config(ip_port_dict, ip, "gunicorn_path")
 std_out = " >>/convert.out 2>&1 &"
 std_out_gpu = " >>/gpu.out 2>&1 &"
 std_out_schedule = " >>/schedule.out 2>&1 &"
-# convert_comm = "nohup " + python_path + " " + interface_path + "/format_convert/convert.py #" + std_out
-# ocr_comm = "nohup " + python_path + " " + interface_path + "/ocr/ocr_gpu_interface.py # 0" + std_out_gpu
-# otr_comm = "nohup " + python_path + " " + interface_path + "/otr/otr_gpu_interface.py # 0" + std_out_gpu
-schedule_comm = "nohup " + python_path + " " + interface_path + "/format_convert/schedule_interface.py #" + std_out_schedule
-soffice_comm = "docker run --init -itd --log-opt max-size=10m --log-opt max-file=3 -p #:16000 soffice:v2 bash"
 
-
-gunicorn_path = python_path
-print("convert_port_list", len(convert_port_list))
-convert_comm = "nohup " + gunicorn_path + " -w " + str(len(convert_port_list)) + " -t 300 -b 0.0.0.0:# --chdir " \
-               + interface_path + "/format_convert convert:app" + std_out
-ocr_comm = "nohup " + gunicorn_path + " -w " + str(len(ocr_port_list)) + " -t 300 --keep-alive 600 -b 0.0.0.0:# --chdir " \
-           + interface_path + "/ocr ocr_interface:app" + std_out_gpu
-otr_comm = "nohup " + gunicorn_path + " -w " + str(len(otr_port_list)) + " -t 300 --keep-alive 600 -b 0.0.0.0:# --chdir " \
-           + interface_path + "/otr otr_interface:app" + std_out_gpu
+print("ocr_port_list", ocr_port_list)
+print("otr_port_list", otr_port_list)
+print("soffice_port_list", soffice_port_list)
+
+ocr_comm_list = []
+otr_comm_list = []
+for i in range(len(ocr_port_list)):
+    ocr_comm_list.append("nohup " + gunicorn_path_list[i] + " -w " + str(len(ocr_port_list[i]))
+                         + " -t 300 --keep-alive 600 -b 0.0.0.0:# --chdir "
+                         + project_path_list[i] + "/ocr ocr_interface:app" + std_out_gpu)
+for i in range(len(otr_port_list)):
+    otr_comm_list.append("nohup " + gunicorn_path_list[i] + " -w " + str(len(otr_port_list[i]))
+                         + " -t 300 --keep-alive 600 -b 0.0.0.0:# --chdir "
+                         + project_path_list[i] + "/otr otr_interface:app" + std_out_gpu)
+convert_comm = "nohup " + gunicorn_path_list[0] + " -w " + str(len(convert_port_list)) + " -t 300 -b 0.0.0.0:# --chdir " \
+               + project_path_list[0] + "/format_convert convert:app" + std_out
+schedule_comm = "nohup " + python_path_list[0] + " " + project_path_list[0] + "/format_convert/schedule_interface.py #" + std_out_schedule
+soffice_comm = "docker run --init -itd --log-opt max-size=10m --log-opt max-file=3 -p #:16000 soffice:v2 bash"
 
 
 def get_port():
@@ -61,13 +92,13 @@ def get_port():
     return current_port_list
 
 
-def restart(process_type, port):
+def restart(process_type, port, index=0):
     if process_type == "convert":
         _comm = re.sub("#", port, convert_comm)
     elif process_type == "ocr":
-        _comm = re.sub("#", port, ocr_comm)
+        _comm = re.sub("#", port, ocr_comm_list[index])
     elif process_type == "otr":
-        _comm = re.sub("#", port, otr_comm)
+        _comm = re.sub("#", port, otr_comm_list[index])
     elif process_type == "soffice":
         _comm = re.sub("#", port, soffice_comm)
     elif process_type == "schedule":
@@ -145,14 +176,16 @@ def monitor():
                 restart("convert", p)
 
     if ocr_port_list:
-        for p in ocr_port_list[:1]:
-            if p not in current_port_list:
-                restart("ocr", p)
+        for j in range(len(ocr_port_list)):
+            for p in ocr_port_list[j][:1]:
+                if p not in current_port_list:
+                    restart("ocr", p, index=j)
 
     if otr_port_list:
-        for p in otr_port_list[:1]:
-            if p not in current_port_list:
-                restart("otr", p)
+        for j in range(len(otr_port_list)):
+            for p in otr_port_list[j][:1]:
+                if p not in current_port_list:
+                    restart("otr", p, index=j)
 
     if soffice_port_list:
         for p in soffice_port_list:
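
Note: get_port() (used by monitor() above) gathers the ports currently listening so anything missing can be restarted. A plausible sketch using the psutil import at the top of the file (an assumption; the real implementation may differ):

    def get_port():
        # Collect all locally listening TCP ports as strings
        current_port_list = []
        for conn in psutil.net_connections(kind="inet"):
            if conn.status == psutil.CONN_LISTEN and conn.laddr:
                current_port_list.append(str(conn.laddr.port))
        return current_port_list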

+ 278 - 85
format_convert/utils.py

@@ -1,4 +1,6 @@
 # -*- coding:utf-8 -*-
+import argparse
+import copy
 import hashlib
 import inspect
 import json
@@ -24,6 +26,7 @@ from format_convert import _global
 from functools import wraps
 import psutil
 import time
+import numpy as np
 from format_convert.judge_platform import get_platform
 if get_platform() == "Linux":
     import resource
@@ -58,7 +61,7 @@ def add_div(text):
         return text
 
     text = "<div>" + text + "\n"
-    text = re.sub("\n", "</div>\n<div>", text)
+    text = re.sub("\n", "</div><div>", text)
     # text += "</div>"
     if text[-5:] == "<div>":
         # print("add_div has cut", text[-30:])
@@ -522,7 +525,6 @@ class LineTable:
         self.list_line = list_line
         self.list_crosspoints = self.recognize_crosspoints(list_line)
 
-  
         # Clustering
         cluster_crosspoints = []
         for _point in self.list_crosspoints:
@@ -1076,8 +1078,10 @@ class LineTable:
         # for _textbox in list_textbox:
         #     print(_textbox.get_text())
         # print("textbox:======>>>>>>>>>>>>>")
-
-
+        # for c in clusters_rects:
+        #     print("+"*30)
+        #     for cc in c:
+        #         print("rect", cc.)
         # cul spans
         for _line in clusters_rects:
             for _rect in _line:
@@ -1142,19 +1146,69 @@ class LineTable:
         list_textbox.sort(key=lambda x:x.bbox[0])
         list_textbox.sort(key=lambda x:x.bbox[3],reverse=sourceP_LB)
         # print("list_textbox", list_textbox)
-        for textbox in list_textbox:
-            (x0,y0,x1,y1) = textbox.bbox
-            _text = textbox.get_text()
-            _find = False
-            for table_line in _table:
-                for _cell in table_line:
+        # for textbox in list_textbox:
+        #     (x0,y0,x1,y1) = textbox.bbox
+        #     _text = textbox.get_text()
+        #     _find = False
+        #     for table_line in _table:
+        #         for _cell in table_line:
+        #             if self.inbox(textbox.bbox, _cell["bbox"], textbox.get_text()):
+        #                 _cell["text"] += _text
+        #                 in_objs.add(textbox)
+        #                 _find = True
+        #                 break
+        #         if _find:
+        #             break
+        has_matched_box_list = []
+        for table_line in _table:
+            for _cell in table_line:
+                inbox_textbox_list = []
+                for textbox in list_textbox:
+                    if textbox in in_objs:
+                        continue
                     if self.inbox(textbox.bbox, _cell["bbox"], textbox.get_text()):
-                        _cell["text"] += _text
+                        inbox_textbox_list.append(textbox)
                         in_objs.add(textbox)
-                        _find = True
-                        break
-                if _find:
-                    break
+
+                # Split into rows based on y overlap
+                all_match_box_list = []
+                for i in range(len(inbox_textbox_list)):
+                    match_box_list = []
+                    box1 = inbox_textbox_list[i]
+                    if box1 in has_matched_box_list:
+                        continue
+
+                    min_y1 = box1.bbox[1] + 1/3 * abs(box1.bbox[3]-box1.bbox[1])
+                    max_y1 = box1.bbox[3] - 1/3 * abs(box1.bbox[3]-box1.bbox[1])
+                    match_box_list.append([box1.get_text(), box1.bbox[0], box1.bbox[1], box1.bbox[2], box1.bbox[3]])
+                    has_matched_box_list.append(box1)
+                    for j in range(i+1, len(inbox_textbox_list)):
+                        box2 = inbox_textbox_list[j]
+                        if box2 in has_matched_box_list:
+                            continue
+
+                        # print(min_y1, box2.bbox[1], box2.bbox[3], max_y1)
+                        # print(min_y2, box1.bbox[3], max_y2)
+                        if min_y1 <= box2.bbox[1] <= max_y1 or \
+                                min_y1 <= box2.bbox[3] <= max_y1 or \
+                                box2.bbox[1] <= min_y1 <= max_y1 <= box2.bbox[3]:
+                            match_box_list.append([box2.get_text(), box2.bbox[0], box2.bbox[1], box2.bbox[2], box2.bbox[3]])
+                            has_matched_box_list.append(box2)
+                    match_box_list.sort(key=lambda x: x[1])
+                    all_match_box_list.append(match_box_list)
+
+                # print("match_box_list", all_match_box_list)
+                for box_list in all_match_box_list:
+                    for box in box_list:
+                        _cell["text"] += box[0]
+
+        # print("------------")
+        # for _line in _table:
+        #     for _cell in _line:
+        #         print(_cell["text"])
+        #         print("\n")
+        # print("------------")
+
         if fixspan:
             for _line in _table:
                 for c_i in range(len(_line)):
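
Note: the rewritten matcher collects every textbox inside a cell, groups boxes into visual rows when their vertical extents overlap the middle third of an anchor box, then sorts each row left-to-right before concatenating. The grouping rule in isolation (bbox = (x0, y0, x1, y1)):

    def same_row(box1, box2):
        # box2 joins box1's row if it overlaps box1's middle third vertically
        y0, y1 = box1[1], box1[3]
        min_y = y0 + (y1 - y0) / 3
        max_y = y1 - (y1 - y0) / 3
        return (min_y <= box2[1] <= max_y
                or min_y <= box2[3] <= max_y
                or box2[1] <= min_y <= max_y <= box2[3])

    print(same_row((0, 10, 50, 22), (60, 12, 90, 20)))  # True: same text line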
@@ -1181,17 +1235,6 @@ class LineTable:
                                 # print(len(_table),l_i+i)
                                 _table[l_i+i].insert(c_i,n_cell)
 
-
-        # print("=======")
-        # for _line in _table:
-        #     for _cell in _line:
-        #         _text = _cell["text"][:2]+"_"+str(_cell["columnspan"])+"_"+str(_cell["rowspan"])
-        #         if _text=="":
-        #             _text = "=="
-        #         print(_text,end="\t")
-        #     print("\n")
-        # print("===========")
-
         if fixRect:
             for _line in _table:
                 extend_line = []
@@ -1230,23 +1273,22 @@ class LineTable:
                                      "columnspan": self.getspan(list_x, _bbox[0], _bbox[2], margin),
                                      "text": ""}
                             extend_line.append({"index":c_i+1,"cell":_cell})
-                extend_line.sort(key=lambda x:x["index"],reverse=True)
+                extend_line.sort(key=lambda x: x["index"],reverse=True)
 
                 for _tmp in extend_line:
                     _line.insert(_tmp["index"],_tmp["cell"])
 
-
-                list_textbox.sort(key=lambda x:x.bbox[0])
-                list_textbox.sort(key=lambda x:x.bbox[3],reverse=sourceP_LB)
+                list_textbox.sort(key=lambda x: x.bbox[0])
+                list_textbox.sort(key=lambda x: x.bbox[3], reverse=sourceP_LB)
                 for textbox in list_textbox:
                     if textbox in in_objs:
                         continue
-                    (x0,y0,x1,y1) = textbox.bbox
+                    x0, y0, x1, y1 = textbox.bbox
                     _text = textbox.get_text()
                     _find = False
                     for table_line in _table:
                         for _cell in table_line:
-                            if self.inbox(textbox.bbox,_cell["bbox"], textbox.get_text()):
+                            if self.inbox(textbox.bbox, _cell["bbox"], textbox.get_text()):
                                 _cell["text"] += _text
                                 in_objs.add(textbox)
                                 _find = True
@@ -1254,26 +1296,18 @@ class LineTable:
                         if _find:
                             break
 
-
-
-
-
+        table_bbox = (_table[0][0].get("bbox")[0],
+                      _table[0][0].get("bbox")[1],
+                      _table[-1][-1].get("bbox")[2],
+                      _table[-1][-1].get("bbox")[3])
 
         # print("=======")
         # for _line in _table:
         #     for _cell in _line:
-        #         _text = _cell["text"][:2]
-        #         if _text=="":
-        #             _text = "=="
-        #         print(_text,end="\t")
-        #     print("\n")
+        #         print(_cell["text"])
+        #         print("\n")
         # print("===========")
 
-        table_bbox = (_table[0][0].get("bbox")[0],
-                      _table[0][0].get("bbox")[1],
-                      _table[-1][-1].get("bbox")[2],
-                      _table[-1][-1].get("bbox")[3])
-
         ta = {"bbox": table_bbox, "table": _table}
         return ta
 
@@ -1337,6 +1371,8 @@ def get_table_html(table):
 
 def sort_object(obj_list, is_reverse=False):
     from format_convert.convert_tree import _Table, _Image, _Sentence, _Page
+    obj_list = combine_object(obj_list)
+
     if len(obj_list) == 0:
         return obj_list
     if isinstance(obj_list[0], (_Table, _Sentence, _Image)):
@@ -1349,6 +1385,38 @@ def sort_object(obj_list, is_reverse=False):
         return obj_list
 
 
+def combine_object(obj_list, threshold=5):
+    from format_convert.convert_tree import _Sentence
+    sentence_list = []
+    for obj in obj_list:
+        if isinstance(obj, _Sentence):
+            obj.content = re.sub("\s", "", obj.content)
+            sentence_list.append(obj)
+    sentence_list.sort(key=lambda x: (x.y, x.x))
+    for sen in sentence_list:
+        obj_list.remove(sen)
+
+    delete_list = []
+    for i in range(1, len(sentence_list)):
+        sen1 = sentence_list[i-1]
+        sen2 = sentence_list[i]
+        if abs(sen2.y - sen1.y) <= threshold:
+            if sen2.x > sen1.x:
+                sen2.x = sen1.x
+                sen2.content = sen1.content + sen2.content
+            else:
+                sen2.content = sen2.content + sen1.content
+            if sen2.y > sen1.y:
+                sen2.y = sen1.y
+            delete_list.append(sen1)
+
+    for sen in delete_list:
+        sentence_list.remove(sen)
+    for sen in sentence_list:
+        obj_list.append(sen)
+    return obj_list
+
+
 session_ocr = requests.Session()
 session_otr = requests.Session()
 session_all = requests.Session()
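
Note: combine_object merges _Sentence objects whose y coordinates differ by at most threshold pixels, so a visual line split by the extractor reads back as one sentence. The merge rule in isolation, using dicts in place of _Sentence objects:

    threshold = 5
    s1 = {"content": "Item", "x": 10, "y": 100}
    s2 = {"content": "Price", "x": 80, "y": 102}
    if abs(s2["y"] - s1["y"]) <= threshold:
        # The later fragment absorbs the earlier one, keeping the smaller x/y
        s2["content"] = s1["content"] + s2["content"]
        s2["x"] = min(s1["x"], s2["x"])
        s2["y"] = min(s1["y"], s2["y"])
    print(s2)  # {'content': 'ItemPrice', 'x': 10, 'y': 100}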
@@ -1444,48 +1512,60 @@ def get_ip_port(node_type=None, interface_type=None):
 
     ip_port_dict = {}
     params = parse_yaml()
+    # Loop over master and slave
     for type1 in node_type_list:
         node_type = type1.upper()
         ip_list = params.get(node_type).get("ip")
-        for type2 in interface_type_list:
-            interface_type = type2.upper()
-            processes = 0
-            python_path = None
-            project_path = None
-            if interface_type in ["convert".upper()]:
-                _port = params.get(node_type).get(interface_type).get("port")
-                if _port is None:
-                    port_list = []
-                else:
-                    if interface_type == "convert".upper():
-                        processes = params.get(node_type).get(interface_type).get("processes")
-                    port_list = [str(_port)]*int(processes)
-                    # port_list = [str(_port)]
-            elif interface_type == "path".upper():
-                python_path = params.get(node_type).get(interface_type).get("python")
-                project_path = params.get(node_type).get(interface_type).get("project")
+
+        # Loop over multiple IPs
+        for j in range(len(ip_list)):
+            _ip = ip_list[j]
+            if ip_port_dict.get(_ip):
+                ip_port_dict.get(_ip).update({node_type: {}})
             else:
-                port_start = params.get(node_type).get(interface_type).get("port_start")
-                port_no = params.get(node_type).get(interface_type).get("port_no")
-                if port_start is None or port_no is None:
-                    port_list = []
+                ip_port_dict.update({_ip: {node_type: {}}})
+
+            # For this IP, loop over each parameter type
+            for type2 in interface_type_list:
+                python_path = None
+                project_path = None
+                gunicorn_path = None
+                processes = 0
+                port_list = []
+                interface_type = type2.upper()
+                if interface_type in ["convert".upper()]:
+                    _port = params.get(node_type).get(interface_type).get("port")
+                    if _port is None:
+                        port_list = []
+                    else:
+                        if interface_type == "convert".upper():
+                            processes = params.get(node_type).get(interface_type).get("processes")[j]
+                        port_list = [str(_port[j])]*int(processes)
+                        # port_list = [str(_port)]
+                elif interface_type == "path".upper():
+                    python_path = params.get(node_type).get(interface_type).get("python")[j]
+                    project_path = params.get(node_type).get(interface_type).get("project")[j]
+                    gunicorn_path = params.get(node_type).get(interface_type).get("gunicorn")[j]
                 else:
-                    port_list = [str(x) for x in range(port_start, port_start+port_no, 1)]
-            if ip_list:
-                for _ip in ip_list:
-                    if _ip is None:
-                        continue
-                    if _ip in ip_port_dict.keys():
-                        if port_list:
-                            ip_port_dict.get(_ip).update({interface_type.lower(): port_list})
+                    port_start = params.get(node_type).get(interface_type).get("port_start")
+                    port_no = params.get(node_type).get(interface_type).get("port_no")
+                    if port_start is None or port_no is None:
+                        port_list = []
                     else:
-                        if port_list:
-                            ip_port_dict[_ip] = {interface_type.lower(): port_list}
-                    if processes:
-                        ip_port_dict.get(_ip).update({interface_type.lower()+"_processes": processes})
-                    if project_path and python_path:
-                        ip_port_dict.get(_ip).update({"project_path": project_path,
-                                                      "python_path": python_path})
+                        port_list = [str(x) for x in range(port_start[j], port_start[j]+port_no[j], 1)]
+                # if ip_list:
+                #     for i in range(len(ip_list)):
+
+                # Put the parameters into the dict
+                if port_list:
+                    ip_port_dict.get(_ip).get(node_type).update({interface_type.lower(): port_list})
+                if processes:
+                    ip_port_dict.get(_ip).get(node_type).update({interface_type.lower()+"_processes": processes})
+                if project_path and python_path and gunicorn_path:
+                    ip_port_dict.get(_ip).get(node_type).update({"project_path": project_path,
+                                                                 "python_path": python_path,
+                                                                 "gunicorn_path": gunicorn_path})
+                # print("ip_port_dict", ip_port_dict)
     return ip_port_dict
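
Note: with the per-IP loop, the returned structure gains a node-type level. The shape implied by the code and the interface.yml above (illustrative, abbreviated):

    {
        "http://127.0.0.1": {
            "MASTER": {
                "convert": ["15010"] * 30,  # the port repeated once per process
                "ocr": ["17000", "17001", "17002", "17003"],
                "otr": ["18000", "18001"],
                "office": [str(p) for p in range(16000, 16025)],
                "python_path": "/data/anaconda3/envs/convert3/bin/python",
                "gunicorn_path": "/data/anaconda3/envs/convert3/bin/gunicorn",
                "project_path": "/data/fangjiasheng/format_conversion_maxcompute/",
            },
            "SLAVE": {
                "otr": ["18005", "18006", "18007"],  # ocr port_no is 0, so no "ocr" key
                # ... plus the paths configured for this node type
            },
        },
        # "http://192.168.0.114": {"SLAVE": {...}}, and so on per configured IP
    }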
 
 
@@ -1582,7 +1662,10 @@ def set_flask_global():
                                   "convert": 0,
                                   "office": 0
                                   }})
-        ip_flag.append([_k, 0])
+        if ip_port_dict.get(_k).get("MASTER"):
+            ip_flag.append([_k+"_master", 0])
+        if ip_port_dict.get(_k).get("SLAVE"):
+            ip_flag.append([_k+"_slave", 0])
     _global.update({"ip_port_flag": ip_port_flag})
     _global.update({"ip_port": ip_port_dict})
     _global.update({"ip_flag": ip_flag})
@@ -1611,6 +1694,116 @@ def get_md5_from_bytes(_bytes):
         return None, _length
 
 
+# def to_share_memory(np_data, name=None):
+#     # from multiprocessing.resource_tracker import unregister
+#     from multiprocessing import shared_memory
+#     if name is None:
+#         sm_name = "psm_" + str(os.getpid())
+#     else:
+#         sm_name = name
+#     logging.info("into from_share_memory sm_name " + sm_name)
+#     shm = shared_memory.SharedMemory(name=sm_name, create=True, size=np_data.nbytes)
+#     # unregister(sm_name, 'shared_memory')
+#     sm_data = np.ndarray(np_data.shape, dtype=np_data.dtype, buffer=shm.buf)
+#     sm_data[:] = np_data[:]  # Copy the original data into shared memory
+#
+#     shm.close()
+#     del sm_data
+#     return shm
+
+
+# def from_share_memory(sm_name, _shape, _dtype, if_close=True):
+#     from multiprocessing import shared_memory
+#     logging.info("into from_share_memory sm_name " + sm_name)
+#     shm = shared_memory.SharedMemory(name=sm_name, create=False)
+#     b = np.ndarray(_shape, dtype=_dtype, buffer=shm.buf)
+#     sm_data = copy.deepcopy(b)
+#     b[::] = 0
+#
+#     if if_close:
+#         try:
+#             shm.close()
+#             shm.unlink()
+#         except Exception:
+#             log("file not found! " + sm_name)
+#     return sm_data
+
+
+# def get_share_memory(sm_name):
+#     try:
+#         from multiprocessing import shared_memory
+#         shm = shared_memory.SharedMemory(name=sm_name, create=False)
+#         return shm
+#     except:
+#         return None
+
+
+# def release_share_memory(shm):
+#     try:
+#         if shm is None:
+#             return
+#         shm.close()
+#         shm.unlink()
+#         log(str(shm.name) + " release successfully!")
+#     except FileNotFoundError:
+#         log(str(shm.name) + " has released!")
+#     except Exception as e:
+#         traceback.print_exc()
+
+
+# def get_share_memory_list(sm_list_name, list_size=None):
+#     # from multiprocessing.resource_tracker import unregister
+#     from multiprocessing import shared_memory
+#     if list_size is None:
+#         sm_list = shared_memory.ShareableList(name=sm_list_name)
+#     else:
+#         sm_list = shared_memory.ShareableList(name=sm_list_name, sequence=["0"]+[' '*2048]*(list_size-2)+["0"])
+#         # unregister(sm_list_name, 'shared_memory')
+#     return sm_list
+
+
+# def close_share_memory_list(sm_list):
+#     try:
+#         sm_list.shm.close()
+#     except Exception:
+#         traceback.print_exc()
+
+
+def get_np_type(_str):
+    _dtype = None
+    if _str == 'uint8':
+        _dtype = np.uint8
+    elif _str == 'float16':
+        _dtype = np.float16
+    elif _str == 'float32':
+        _dtype = np.float32
+    logging.info("get_np_type " + _str + " " + str(_dtype))
+    return _dtype
+
+
+def namespace_to_dict(args_or_dict, reverse=False):
+    if reverse:
+        args_or_dict = argparse.Namespace(**args_or_dict)
+    else:
+        args_or_dict = vars(args_or_dict)
+    return args_or_dict
+
+
+def get_args_from_config(ip_port_dict, ip, arg_type, node_type=None):
+    if node_type is None:
+        node_type = ["MASTER", "SLAVE"]
+    else:
+        node_type = [node_type]
+
+    arg_list = []
+    for _type in node_type:
+        if ip_port_dict.get(ip).get(_type):
+            if ip_port_dict.get(ip).get(_type).get(arg_type):
+                arg_list.append(ip_port_dict.get(ip).get(_type).get(arg_type))
+
+    return arg_list
+
+
 if __name__ == "__main__":
     # strs = r"D:\Project\temp\04384fcc9e8911ecbd2844f971944973\043876ca9e8911eca5e144f971944973_rar\1624114035529.jpeg"
     # print(slash_replace(strs))
@@ -1639,6 +1832,6 @@ if __name__ == "__main__":
 
     # print(parse_yaml())
 
-    print(get_ip_port(node_type='slave'))
-
+    print(get_ip_port())
+    print(get_args_from_config(get_ip_port(), "http://127.0.0.1", "gunicorn_path"))
     # print(get_intranet_ip())
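
Note: get_args_from_config flattens that structure back out: for a given ip and arg_type it returns one entry per node type that defines the key, MASTER first. Continuing the illustrative shape above:

    ip_port_dict = get_ip_port()
    get_args_from_config(ip_port_dict, "http://127.0.0.1", "otr")
    # -> [['18000', '18001'], ['18005', '18006', '18007']]
    get_args_from_config(ip_port_dict, "http://127.0.0.1", "gunicorn_path")
    # -> ['/data/anaconda3/envs/convert3/bin/gunicorn', '/data/anaconda3/envs/convert/bin/gunicorn']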

+ 2 - 1
format_convert/wrapt_timeout_decorator/wrapt_timeout_decorator.py

@@ -15,6 +15,7 @@ import wrapt  # type: ignore
 
 # OWN
 sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
+from format_convert.utils import get_platform
 from wrapt_timeout_decorator.wrap_helper import WrapHelper, detect_unpickable_objects_and_reraise
 from wrapt_timeout_decorator.wrap_function_multiprocess import Timeout
 
@@ -127,7 +128,7 @@ def timeout(
         wrap_helper = WrapHelper(
             dec_timeout, use_signals, timeout_exception, exception_message, dec_allow_eval, dec_hard_timeout, wrapped, instance, args, kwargs
         )
-        if not wrap_helper.dec_timeout_float:
+        if (not wrap_helper.dec_timeout_float) or (get_platform() == "Windows"):
             return wrapped(*wrap_helper.args, **wrap_helper.kwargs)
         else:
             return wrapped_with_timeout(wrap_helper)
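
Note: the added check makes timeout() a no-op on Windows, presumably to dodge the pickling requirements of the multiprocess watchdog there. In effect:

    @timeout(30, timeout_exception=TimeoutError, use_signals=False)
    def read():
        ...

    # Linux: read() is aborted with TimeoutError after 30s.
    # Windows: read() now runs to completion with no watchdog.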

+ 311 - 6
ocr/ocr_gpu_interface.py

@@ -1,4 +1,5 @@
 # encoding=utf8
+import argparse
 import base64
 import io
 import json
@@ -6,7 +7,7 @@ import pickle
 import sys
 import os
 import threading
-import zlib
+
 import numpy as np
 import redis
 
@@ -14,7 +15,9 @@ sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
 import time
 import traceback
 os.environ['FLAGS_eager_delete_tensor_gb'] = '0'
-from format_convert.utils import request_post, test_gpu, get_intranet_ip, log, get_md5_from_bytes
+from format_convert.utils import request_post, test_gpu, get_intranet_ip, log, get_md5_from_bytes, \
+    to_share_memory, from_share_memory, get_np_type, namespace_to_dict, get_share_memory_list, get_ip_port, \
+    release_share_memory, get_share_memory, close_share_memory_list
 from flask import Flask, request
 from format_convert import _global
 from ocr.tools.infer import utility
@@ -25,12 +28,19 @@ logger = get_logger()
 # Interface configuration
 app = Flask(__name__)
 
+ocr_model_dir = os.path.dirname(os.path.abspath(__file__)) + "/model/2.0/"
 
 lock = threading.RLock()
 
+ip_port_dict = get_ip_port()
+ip = 'http://127.0.0.1'
+ocr_port_list = ip_port_dict.get(ip).get("ocr")
+otr_port_list = ip_port_dict.get(ip).get("otr")
+
 
-redis_db = redis.StrictRedis(host='127.0.0.1', port='6379',
-                             db=1, password='bidi123456', health_check_interval=300)
+# redis_db = redis.StrictRedis(host='127.0.0.1', port='6379',
+#                              db=1, password='bidi123456', health_check_interval=300)
+redis_db = None
 
 
 # @app.route('/ocr', methods=['POST'])
@@ -207,6 +217,299 @@ def _ocr_gpu_redis():
             traceback.print_exc()
 
 
+# @app.route('/ocr', methods=['POST'])
+def _ocr_gpu_flask_sm():
+    start_time = time.time()
+    log("into _ocr")
+    _global._init()
+    _global.update({"port": globals().get("port")})
+    log("into _ocr -> _global " + str(time.time()-start_time))
+    start_time = time.time()
+
+    try:
+        if not request.form:
+            log("ocr no data!")
+            return json.dumps({"text": str([-9]), "bbox": str([-9])})
+        log("judge request.form " + str(time.time()-start_time))
+
+        _time = time.time()
+        result = json.loads(request.form.get("data"))
+        predictor_type = result.get("predictor_type")
+        model_type = result.get("model_type")
+        args = result.get("args")
+        args = namespace_to_dict(args, reverse=True)
+        _md5 = result.get("md5")
+        sm_name = result.get("sm_name")
+        sm_shape = result.get("sm_shape")
+        sm_dtype = result.get("sm_dtype")
+        sm_dtype = get_np_type(sm_dtype)
+        _global.update({"md5": _md5})
+        log("read data " + str(time.time()-_time))
+
+        # Read from shared memory
+        _time = time.time()
+        if sm_name:
+            inputs = from_share_memory(sm_name, sm_shape, sm_dtype)
+        else:
+            log("from_share_memory failed!")
+            raise Exception
+        log("data from share memory " + sm_name + " " + str(time.time()-_time))
+
+        # Get the corresponding predictor
+        if globals().get(predictor_type) is None:
+            start_time1 = time.time()
+            log("=== init " + model_type + " " + predictor_type + " model ===")
+            predictor, input_tensor, output_tensors = \
+                utility.create_predictor(args, predictor_type, logger)
+            globals().update({predictor_type: {"predictor": predictor,
+                                               "input_tensor": input_tensor,
+                                               "output_tensors": output_tensors}})
+            log("=== init " + model_type + " " + predictor_type + " model " + str(round(time.time()-start_time1, 2)) + " ===")
+        else:
+            predictor = globals().get(predictor_type).get("predictor")
+            input_tensor = globals().get(predictor_type).get("input_tensor")
+            output_tensors = globals().get(predictor_type).get("output_tensors")
+
+        _time = time.time()
+        with lock:
+            # Set the model input
+            input_tensor.copy_from_cpu(inputs)
+            # Run inference
+            predictor.run()
+            # Fetch the model outputs
+            outputs = []
+            for output_tensor in output_tensors:
+                output = output_tensor.copy_to_cpu()
+                outputs.append(output)
+            preds = outputs[0]
+        gpu_time = round(float(time.time()-_time), 2)
+        log("gpu_time " + str(gpu_time))
+
+        _shape = preds.shape
+        _dtype = str(preds.dtype)
+
+        # wait until the previous result has been read
+        _time = time.time()
+        while True:
+            shm = globals().get("shm")
+            if shm is None:
+                break
+            last_shape = globals().get("last_shape")
+            sm_data = np.ndarray(last_shape, dtype=sm_dtype, buffer=shm.buf)
+            if (sm_data == np.zeros(last_shape)).all():
+                try:
+                    _time1 = time.time()
+                    shm.close()
+                    shm.unlink()
+                    log("release share memory " + str(time.time()-_time1))
+                except FileNotFoundError:
+                    log("share memory " + shm.name + " not exists!")
+                break
+        log("wait for share memory being read " + str(time.time()-_time))
+
+        # put the result into shared memory
+        _time = time.time()
+        shm = to_share_memory(preds)
+        globals().update({"shm": shm})
+        globals().update({"last_shape": _shape})
+        log("data into share memory " + str(shm.name) + " " + str(time.time()-_time))
+
+        # free predictor memory
+        _time = time.time()
+        predictor.clear_intermediate_tensor()
+        predictor.try_shrink_memory()
+        log("ocr shrink memory " + str(time.time()-_time))
+
+        finish_time = round(float(time.time()-start_time), 2)
+        log("ocr model predict time - " + str(predictor_type) + " - " + str(gpu_time) + " " + str(finish_time))
+        return json.dumps({"gpu_time": gpu_time, "elapse": finish_time,
+                           "sm_name": shm.name, "sm_shape": _shape, "sm_dtype": _dtype})
+    except Exception as e:
+        finish_time = round(float(time.time()-start_time), 2)
+        traceback.print_exc()
+        return json.dumps({"gpu_time": gpu_time, "elapse": finish_time,
+                           "sm_name": None, "sm_shape": None, "sm_dtype": None})
+
+
+def _ocr():
+    start_time = time.time()
+    log("into _ocr")
+    _global._init()
+    _global.update({"port": globals().get("port")})
+    log("into _ocr -> _global " + str(time.time()-start_time))
+    start_time = time.time()
+    sm_list_name = "sml_ocr_"+str(port)
+    try:
+        # initialize the det/cls/rec predictors
+        for predictor_type in ["det", "cls", "rec"]:
+
+            args = init_ocr_args()
+            predictor, input_tensor, output_tensors = \
+                utility.create_predictor(args, predictor_type, logger)
+            globals().update({predictor_type: {"predictor": predictor,
+                                               "input_tensor": input_tensor,
+                                               "output_tensors": output_tensors}})
+            if predictor == "det":
+                inputs = np.zeros((1, 3, 1024, 1024), dtype=np.float32)
+            else:
+                inputs = np.zeros((30, 3, 32, 64), dtype=np.float32)
+
+            # init model by running once
+            input_tensor.copy_from_cpu(inputs)
+            predictor.run()
+            outputs = []
+            for output_tensor in output_tensors:
+                output = output_tensor.copy_to_cpu()
+                outputs.append(output)
+            log("finish init predictor " + predictor_type)
+
+        # poll the shared-memory list for new work
+        # full_sm_list = globals().get("sm_list")
+        while True:
+            try:
+                full_sm_list = get_share_memory_list(sm_list_name=sm_list_name)
+            except FileNotFoundError:
+                full_sm_list = get_share_memory_list(sm_list_name=sm_list_name, list_size=10)
+
+            try:
+                if full_sm_list[0] == "1" and full_sm_list[-1] == "1":
+                    log("empty_sm_list[0] " + full_sm_list[0])
+                    log("empty_sm_list[-1] " + full_sm_list[-1])
+                    log("empty_sm_list[1] " + full_sm_list[1])
+                    log("wait for " + str(time.time()-start_time))
+                    break
+            except ValueError:
+                continue
+
+        start_time = time.time()
+        _time = time.time()
+        _md5 = full_sm_list[1]
+        model_type = full_sm_list[2]
+        predictor_type = full_sm_list[3]
+        args = full_sm_list[4]
+        args = namespace_to_dict(eval(args), reverse=True)
+        sm_name = full_sm_list[5]
+        sm_shape = full_sm_list[6]
+        sm_shape = eval(sm_shape)
+        sm_dtype = full_sm_list[7]
+        sm_dtype = get_np_type(sm_dtype)
+        _global.update({"md5": _md5})
+        log("read data " + str(time.time()-_time))
+
+        # read from shared memory
+        _time = time.time()
+        if sm_name:
+            inputs = from_share_memory(sm_name, sm_shape, sm_dtype)
+            log(predictor_type + " inputs shape " + str(inputs.shape))
+        else:
+            log("from_share_memory failed!")
+            raise Exception
+        log("data from share memory " + sm_name + " " + str(time.time()-_time))
+
+        # get the matching predictor
+        if globals().get(predictor_type) is None:
+            start_time1 = time.time()
+            log("=== init " + model_type + " " + predictor_type + " model ===")
+            predictor, input_tensor, output_tensors = \
+                utility.create_predictor(args, predictor_type, logger)
+            globals().update({predictor_type: {"predictor": predictor,
+                                               "input_tensor": input_tensor,
+                                               "output_tensors": output_tensors}})
+            log("=== init " + model_type + " " + predictor_type + " model " + str(round(time.time()-start_time1, 2)) + " ===")
+        else:
+            predictor = globals().get(predictor_type).get("predictor")
+            input_tensor = globals().get(predictor_type).get("input_tensor")
+            output_tensors = globals().get(predictor_type).get("output_tensors")
+
+        _time = time.time()
+        with lock:
+            # set model input
+            input_tensor.copy_from_cpu(inputs)
+            # run inference
+            predictor.run()
+            # fetch model outputs
+            outputs = []
+            for output_tensor in output_tensors:
+                output = output_tensor.copy_to_cpu()
+                outputs.append(output)
+            preds = outputs[0]
+        gpu_time = round(float(time.time()-_time), 2)
+        log("gpu_time " + str(gpu_time))
+
+        # put the result into shared memory
+        _time = time.time()
+        # release any previous share memory under the same name
+        release_share_memory(get_share_memory(sm_name))
+        # write to shared memory
+        shm = to_share_memory(preds, sm_name)
+        full_sm_list[5] = shm.name
+        full_sm_list[6] = str(preds.shape)
+        full_sm_list[7] = str(preds.dtype)
+        full_sm_list[8] = str(gpu_time)
+        full_sm_list[-1] = "0"
+        log("data into share memory " + str(shm.name) + " " + str(time.time()-_time))
+
+        # detach from the shared-memory list
+        close_share_memory_list(full_sm_list)
+
+        # free predictor memory
+        _time = time.time()
+        predictor.clear_intermediate_tensor()
+        predictor.try_shrink_memory()
+        log("ocr shrink memory " + str(time.time()-_time))
+
+        finish_time = round(float(time.time()-start_time), 2)
+        log("ocr model predict time - " + str(predictor_type) + " - " + str(gpu_time) + " " + str(finish_time))
+    except Exception as e:
+        finish_time = round(float(time.time()-start_time), 2)
+        traceback.print_exc()
+        raise
+
+
+def init_ocr_args():
+    return argparse.Namespace(
+        use_gpu=True,
+        ir_optim=True,
+        use_tensorrt=False,
+        gpu_mem=8000,
+        image_dir='',
+        det_algorithm='DB',
+        det_model_dir=ocr_model_dir+"det",
+        det_limit_side_len=1280,
+        det_limit_type='max',
+        det_db_thresh=0.1,
+        # lower det_db_box_thresh if text rows are missed
+        det_db_box_thresh=0.1,
+        # det_db_unclip_ratio controls how tightly boxes hug the text
+        det_db_unclip_ratio=2.5,
+        # dilate detected text regions
+        use_dilation=False,
+        det_east_score_thresh=0.8,
+        det_east_cover_thresh=0.1,
+        det_east_nms_thresh=0.2,
+        rec_algorithm='CRNN',
+        rec_model_dir=ocr_model_dir+"rec/ch",
+        rec_image_shape="3, 32, 1000",
+        rec_char_type='ch',
+        rec_batch_num=30,
+        max_text_length=128,
+        rec_char_dict_path='./ppocr/utils/ppocr_keys_v1.txt',
+        use_space_char=True,
+        drop_score=0.5,
+        cls_model_dir=ocr_model_dir+"cls",
+        cls_image_shape="3, 32, 1000",
+        label_list=['0', '180'],
+        cls_batch_num=30,
+        cls_thresh=0.9,
+        enable_mkldnn=False,
+        use_zero_copy_run=True,
+        use_pdserving=False,
+        lang='ch',
+        det=True,
+        rec=True,
+        use_angle_cls=False)
+
+
 if __name__ == '__main__':
     if len(sys.argv) == 2:
         port = int(sys.argv[1])
@@ -226,6 +529,8 @@ if __name__ == '__main__':
     os.environ['CUDA_VISIBLE_DEVICES'] = str(using_gpu_index)
 
     # app.run(host='0.0.0.0', port=port, processes=1, threaded=False, debug=False)
-    app.run()
+    # app.run()
     # log("OCR running "+str(port))
-    # _ocr()
+
+    while True:
+        _ocr()
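
For reference, here is a minimal client-side sketch of the ShareableList handshake that _ocr polls on. The slot layout is inferred from the reads above ([1] md5, [2] model_type, [3] predictor_type, [4] args, [5] sm_name, [6] shape, [7] dtype, [8] gpu_time, with the first and last slots used as flags); the md5 value and the port in the list name are illustrative, and the project's get_share_memory_list/to_share_memory helpers handle the slot sizing this sketch glosses over.

import numpy as np
from multiprocessing import shared_memory

def put_array(arr):
    # copy an array into a fresh SharedMemory block
    shm = shared_memory.SharedMemory(create=True, size=arr.nbytes)
    view = np.ndarray(arr.shape, dtype=arr.dtype, buffer=shm.buf)
    view[:] = arr[:]
    return shm

inputs = np.zeros((1, 3, 1024, 1024), dtype=np.float32)
shm = put_array(inputs)

# "1" in the first and last slots marks a pending request
sml = shared_memory.ShareableList(
    ["1",                                 # [0] request-ready flag
     "d41d8cd98f00b204e9800998ecf8427e",  # [1] md5 (illustrative)
     "ocr", "det", "{}",                  # [2] model_type, [3] predictor_type, [4] args
     shm.name,                            # [5] shared-memory block name
     str(inputs.shape),                   # [6] shape
     str(inputs.dtype),                   # [7] dtype
     "0.00",                              # [8] gpu_time, filled in by the worker
     "1"],                                # [9] result-pending flag
    name="sml_ocr_17000")                 # port 17000 is illustrative

while sml[-1] != "0":                     # worker sets this to "0" once preds are written
    pass
preds_shm = shared_memory.SharedMemory(name=sml[5])
preds = np.ndarray(eval(sml[6]), dtype=np.dtype(sml[7]), buffer=preds_shm.buf).copy()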

+ 3 - 2
ocr/ocr_interface.py

@@ -26,6 +26,7 @@ app = Flask(__name__)
 def _ocr():
     _global._init()
     _global.update({"port": globals().get("port")})
+    start_time = time.time()
 
     log("into ocr_interface _ocr")
     try:
@@ -37,13 +38,11 @@ def _ocr():
         # _md5 = get_md5_from_bytes(img_data)[0]
         _md5 = request.form.get("md5")
         _global.update({"md5": _md5})
-        log("====")
         ocr_model = globals().get("global_ocr_model")
         if ocr_model is None:
             log("----------- init ocr_model ------------")
             ocr_model = OcrModels().get_model()
             globals().update({"global_ocr_model": ocr_model})
-        log("----")
         text = picture2text(img_data, ocr_model)
         return json.dumps(text)
     except TimeoutError:
@@ -51,6 +50,8 @@ def _ocr():
     except:
         traceback.print_exc()
         return json.dumps({"text": str([-1]), "bbox": str([-1])})
+    finally:
+        log("ocr interface finish time " + str(time.time()-start_time))
 
 
 def ocr(data, ocr_model):
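
The explicit start_time/finally bookkeeping added above could also be factored into a decorator so every handler logs its wall time uniformly. A sketch, assuming only the log helper already imported from format_convert.utils:

import time
from functools import wraps
from format_convert.utils import log

def timed(name):
    # log how long the wrapped handler ran, even if it raised
    def deco(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            start = time.time()
            try:
                return func(*args, **kwargs)
            finally:
                log(name + " finish time " + str(time.time() - start))
        return wrapper
    return deco

Decorating _ocr with @timed("ocr interface") would then replace the added try/finally.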

+ 7 - 3
ocr/tools/infer/predict_cls.py

@@ -23,7 +23,7 @@ sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../../")
 import requests
 from format_convert import _global
 
-from format_convert.utils import judge_error_code, log
+from format_convert.utils import judge_error_code, log, namespace_to_dict
 
 # sys.path.append(__dir__)
 # sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
@@ -93,6 +93,7 @@ class TextClassifier(object):
         cls_res = [['', 0.0]] * img_num
         batch_num = self.cls_batch_num
         elapse = 0
+        all_gpu_time = 0
         for beg_img_no in range(0, img_num, batch_num):
             end_img_no = min(img_num, beg_img_no + batch_num)
             norm_img_batch = []
@@ -109,10 +110,12 @@ class TextClassifier(object):
             norm_img_batch = norm_img_batch.copy()
             starttime = time.time()
 
+            _time = time.time()
             self.input_tensor.copy_from_cpu(norm_img_batch)
             self.predictor.run()
-
             prob_out = self.output_tensors[0].copy_to_cpu()
+            gpu_time = time.time()-_time
+            all_gpu_time += gpu_time
             cls_result = self.postprocess_op(prob_out)
             elapse += time.time() - starttime
             for rno in range(len(cls_result)):
@@ -128,6 +131,7 @@ class TextClassifier(object):
             self.predictor.clear_intermediate_tensor()
             self.predictor.try_shrink_memory()
             # TextClassifier.shrink_memory_count += 1
+            log("ocr model predict time - cls - time " + str(all_gpu_time) + " - num " + str(img_num))
         return img_list, cls_res, elapse
 
 
@@ -206,7 +210,7 @@ class TextClassifier2(object):
             # norm_img_batch = compressed_array.read()
 
             # 调用GPU接口
-            _dict = {"inputs": norm_img_batch, "args": self.args, "md5": _global.get("md5")}
+            _dict = {"inputs": norm_img_batch, "args": str(namespace_to_dict(self.args)), "md5": _global.get("md5")}
             result = from_gpu_interface_redis(_dict, model_type="ocr", predictor_type="cls")
             if judge_error_code(result):
                 logging.error("from_gpu_interface failed! " + str(result))

+ 3 - 2
ocr/tools/infer/predict_det.py

@@ -23,7 +23,8 @@ sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../../")
 import requests
 from format_convert import _global
 
-from format_convert.utils import judge_error_code, log
+from format_convert.utils import judge_error_code, log, namespace_to_dict
+
 # sys.path.append(__dir__)
 # sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
 os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
@@ -320,7 +321,7 @@ class TextDetector2(object):
         # img = compressed_array.read()
 
         # 调用GPU接口
-        _dict = {"inputs": img, "args": self.args, "md5": _global.get("md5")}
+        _dict = {"inputs": img, "args": str(namespace_to_dict(self.args)), "md5": _global.get("md5")}
         result = from_gpu_interface_redis(_dict, model_type="ocr", predictor_type="det")
         if judge_error_code(result):
             logging.error("from_gpu_interface failed! " + str(result))

+ 3 - 2
ocr/tools/infer/predict_rec.py

@@ -22,6 +22,7 @@ import zlib
 import requests
 # sys.path.append(__dir__)
 # sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
+
 sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../../")
 os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
 import cv2
@@ -37,7 +38,7 @@ from ocr.ppocr.postprocess import build_post_process
 from ocr.ppocr.utils.logging import get_logger
 from ocr.ppocr.utils.utility import get_image_file_list, check_and_read_gif
 
-from format_convert.utils import judge_error_code, log
+from format_convert.utils import judge_error_code, log, namespace_to_dict
 from format_convert import _global
 
 
@@ -346,7 +347,7 @@ class TextRecognizer2(object):
             # norm_img_batch = compressed_array.read()
 
             # 调用GPU接口
-            _dict = {"inputs": norm_img_batch, "args": self.args, "md5": _global.get("md5")}
+            _dict = {"inputs": norm_img_batch, "args": str(namespace_to_dict(self.args)), "md5": _global.get("md5")}
             result = from_gpu_interface_redis(_dict, model_type="ocr", predictor_type="rec")
             if judge_error_code(result):
                 logging.error("from_gpu_interface failed! " + str(result))

Binary
otr/models/table-line-large.h5


Binary
otr/models/table-line.h5


+ 188 - 6
otr/otr_gpu_interface.py

@@ -27,7 +27,9 @@ import logging
 import cv2
 import numpy as np
 from flask import Flask, request
-from format_convert.utils import request_post, judge_error_code, get_intranet_ip, log, get_md5_from_bytes, get_platform
+from format_convert.utils import request_post, judge_error_code, get_intranet_ip, log, get_md5_from_bytes, get_platform, \
+    to_share_memory, from_share_memory, get_np_type, get_share_memory_list, release_share_memory, get_share_memory, \
+    close_share_memory_list
 from otr.table_line import table_net, table_line, table_preprocess, table_postprocess
 from format_convert import _global
 
@@ -39,8 +41,9 @@ app = Flask(__name__)
 lock = threading.RLock()
 
 
-redis_db = redis.StrictRedis(host='127.0.0.1', port='6379',
-                             db=1, password='bidi123456', health_check_interval=300)
+# redis_db = redis.StrictRedis(host='127.0.0.1', port='6379',
+#                              db=1, password='bidi123456', health_check_interval=300)
+redis_db = None
 
 
 # @app.route('/otr', methods=['POST'])
@@ -109,7 +112,7 @@ def _otr_flask():
         return base64.b64encode(pickle.dumps({"preds": None, "gpu_time": 0., "elapse": finish_time}))
 
 
-def _otr():
+def _otr_redis():
     start_time = time.time()
     log("into _otr")
     _global._init()
@@ -181,6 +184,183 @@ def _otr():
             traceback.print_exc()
 
 
+@app.route('/otr', methods=['POST'])
+def _otr_flask_sm():
+    start_time = time.time()
+    log("into _otr")
+    _global._init()
+    _global.update({"port": globals().get("port")})
+    log("into _otr -> _global " + str(time.time()-start_time))
+    start_time = time.time()
+
+    try:
+        if not request.form:
+            log("otr no data!")
+            return json.dumps({"list_line": str([-9])})
+        log("judge request.form " + str(time.time()-start_time))
+
+        _time = time.time()
+        result = json.loads(request.form.get("data"))
+        model_type = result.get("model_type")
+        args = result.get("args")
+        _md5 = result.get("md5")
+        sm_name = result.get("sm_name")
+        sm_shape = result.get("sm_shape")
+        sm_dtype = result.get("sm_dtype")
+        sm_dtype = get_np_type(sm_dtype)
+        _global.update({"md5": _md5})
+        log("read data " + str(time.time()-_time))
+
+        # read from shared memory
+        _time = time.time()
+        if sm_name:
+            inputs = from_share_memory(sm_name, sm_shape, sm_dtype)
+        else:
+            log("from_share_memory failed!")
+            raise Exception
+        log("data from share memory " + sm_name + " " + str(time.time()-_time))
+
+        # get the model
+        model = globals().get(model_type)
+        if model is None:
+            start_time1 = time.time()
+            log("=== init " + model_type + " model ===")
+            model = OtrModels().get_model()
+            globals().update({model_type: model})
+            log("=== init " + model_type + " model " + str(round(time.time()-start_time1, 2)) + " ===")
+
+        # run inference
+        _time = time.time()
+        with lock:
+            pred = model.predict(inputs)
+            pred = pred[0]
+            _shape = pred.shape
+            _dtype = str(pred.dtype)
+            log("pred.shape " + str(pred.shape))
+        gpu_time = round(float(time.time()-_time), 2)
+
+        # wait until the previous result has been read
+        _time = time.time()
+        while True:
+            shm = globals().get("shm")
+            if shm is None:
+                break
+            last_shape = globals().get("last_shape")
+            sm_data = np.ndarray(last_shape, dtype=sm_dtype, buffer=shm.buf)
+            if (sm_data == np.zeros(last_shape)).all():
+                try:
+                    _time1 = time.time()
+                    shm.close()
+                    shm.unlink()
+                    log("release share memory " + str(time.time()-_time1))
+                except FileNotFoundError:
+                    log("share memory " + shm.name + " not exists!")
+                break
+        log("wait for share memory being read " + str(time.time()-_time))
+
+        # put the result into shared memory
+        _time = time.time()
+        shm = to_share_memory(pred)
+        globals().update({"shm": shm})
+        globals().update({"last_shape": _shape})
+        log("data into share memory " + str(shm.name) + " " + str(time.time()-_time))
+
+        finish_time = round(float(time.time()-start_time), 2)
+        log("otr model predict time " + str(gpu_time) + " " + str(finish_time))
+        return json.dumps({"gpu_time": gpu_time, "elapse": finish_time,
+                           "sm_name": shm.name, "sm_shape": _shape, "sm_dtype": _dtype})
+    except Exception as e:
+        finish_time = round(float(time.time()-start_time), 2)
+        traceback.print_exc()
+        return json.dumps({"gpu_time": 0., "elapse": finish_time,
+                           "sm_name": None, "sm_shape": None, "sm_dtype": None})
+
+
+def _otr():
+    start_time = time.time()
+    log("into _ocr")
+    _global._init()
+    _global.update({"port": globals().get("port")})
+    log("into _ocr -> _global " + str(time.time()-start_time))
+    start_time = time.time()
+
+    try:
+        # poll the shared-memory list for new work
+        while True:
+            try:
+                full_sm_list = get_share_memory_list(sm_list_name="sml_otr_"+str(globals().get("port")))
+            except FileNotFoundError:
+                full_sm_list = get_share_memory_list(sm_list_name="sml_otr_"+str(globals().get("port")), list_size=10)
+            try:
+                if full_sm_list[0] == "1" and full_sm_list[-1] == "1":
+                    log("empty_sm_list[0] " + full_sm_list[0])
+                    log("empty_sm_list[-1] " + full_sm_list[-1])
+                    log("empty_sm_list[1] " + full_sm_list[1])
+                    log("wait for " + str(time.time()-start_time))
+                    break
+            except ValueError:
+                continue
+
+        start_time = time.time()
+        _time = time.time()
+        _md5 = full_sm_list[1]
+        model_type = full_sm_list[2]
+        sm_name = full_sm_list[5]
+        sm_shape = full_sm_list[6]
+        sm_shape = eval(sm_shape)
+        sm_dtype = full_sm_list[7]
+        sm_dtype = get_np_type(sm_dtype)
+        _global.update({"md5": _md5})
+        log("read data " + str(time.time()-_time))
+
+        # read from shared memory
+        _time = time.time()
+        if sm_name:
+            inputs = from_share_memory(sm_name, sm_shape, sm_dtype)
+        else:
+            log("from_share_memory failed!")
+            raise Exception
+        log("data from share memory " + sm_name + " " + str(time.time()-_time))
+
+        # get the model
+        model = globals().get(model_type)
+        if model is None:
+            start_time1 = time.time()
+            log("=== init " + model_type + " model ===")
+            model = OtrModels().get_model()
+            globals().update({model_type: model})
+            log("=== init " + model_type + " model " + str(round(time.time()-start_time1, 2)) + " ===")
+
+        # run inference
+        _time = time.time()
+        with lock:
+            pred = model.predict(inputs)
+            preds = pred[0]
+            log("preds.shape " + str(preds.shape))
+        gpu_time = round(float(time.time()-_time), 2)
+
+        # put the result into shared memory
+        _time = time.time()
+        # release any previous share memory under the same name
+        release_share_memory(get_share_memory(sm_name))
+        # write to shared memory
+        shm = to_share_memory(preds, sm_name)
+        full_sm_list[5] = shm.name
+        full_sm_list[6] = str(preds.shape)
+        full_sm_list[7] = str(preds.dtype)
+        full_sm_list[8] = str(gpu_time)
+        full_sm_list[-1] = "0"
+        log("data into share memory " + str(shm.name) + " " + str(time.time()-_time))
+        close_share_memory_list(full_sm_list)
+
+        finish_time = round(float(time.time()-start_time), 2)
+        log("otr model predict time " + str(gpu_time) + " " + str(finish_time))
+    except Exception as e:
+        finish_time = round(float(time.time()-start_time), 2)
+        traceback.print_exc()
+        raise
+
+
 class OtrModels:
     def __init__(self):
         # directory containing this Python file
@@ -206,9 +386,11 @@ if __name__ == '__main__':
 
     # _global._init()
     # _global.update({"port": str(port)})
-    # globals().update({"port": str(port)})
+    globals().update({"port": str(port)})
 
     # app.run(host='0.0.0.0', port=port, processes=1, threaded=False, debug=False)
     # app.run()
     # log("OTR running "+str(port))
-    _otr()
+
+    while True:
+        _otr()
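
For completeness, a sketch of the consumer side of the "wait until the previous result has been read" loop in _ocr_gpu_flask_sm and _otr_flask_sm above: the client copies the prediction out of shared memory and zeroes the block, which is exactly the all-zeros condition the server polls for before unlinking the segment. The function name is illustrative.

import numpy as np
from multiprocessing import shared_memory

def read_and_release(sm_name, shape, dtype):
    shm = shared_memory.SharedMemory(name=sm_name)
    view = np.ndarray(shape, dtype=dtype, buffer=shm.buf)
    preds = view.copy()   # take a private copy before signalling
    view[:] = 0           # all zeros tells the server the block was read
    shm.close()           # detach; the server unlinks it
    return preds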

+ 12 - 8
otr/otr_interface.py

@@ -14,14 +14,15 @@ try:
             [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=2048)])
 except:
     traceback.print_exc()
-    pass
-# for gpu in gpus:  # when using multiple GPUs
-#     tf.config.experimental.set_memory_growth(gpu, True)
-
-# os.environ['CUDA_CACHE_MAXSIZE'] = str(2147483648)
-# os.environ['CUDA_CACHE_DISABLE'] = str(0)
-# gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.6)
-# sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
+    # pass
+    # gpus = tf.config.list_physical_devices('GPU')
+    # for gpu in gpus:  # when using multiple GPUs
+    #     tf.config.experimental.set_memory_growth(gpu, True)
+
+    os.environ['CUDA_CACHE_MAXSIZE'] = str(2147483648)
+    os.environ['CUDA_CACHE_DISABLE'] = str(0)
+    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.6)
+    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
 import sys
 sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
 import time
@@ -46,6 +47,7 @@ app = Flask(__name__)
 def _otr():
     _global._init()
     _global.update({"port": globals().get("port")})
+    start_time = time.time()
 
     log("into otr_interface _otr")
     try:
@@ -74,6 +76,8 @@ def _otr():
     except:
         traceback.print_exc()
         return json.dumps({"list_line": str([-1])})
+    finally:
+        log("otr interface finish time " + str(time.time()-start_time))
 
 
 def otr(data, otr_model, is_from_pdf):

+ 85 - 1
otr/table_line.py

@@ -50,7 +50,7 @@ def focal_loss(gamma=3., alpha=.5):
     return focal_loss_fixed
 
 
-def table_net(input_shape=(1152, 896, 3), num_classes=1):
+def table_net_large(input_shape=(1152, 896, 3), num_classes=1):
     inputs = Input(shape=input_shape)
     # 512
     use_bias = False
@@ -199,6 +199,90 @@ def table_net(input_shape=(1152, 896, 3), num_classes=1):
     return model
 
 
+def table_net(input_shape=(1152, 896, 3), num_classes=1):
+    inputs = Input(shape=input_shape)
+    # 512
+    use_bias = False
+    down0a = Conv2D(16, (3, 3), padding='same', use_bias=use_bias)(inputs)
+    down0a = BatchNormalization()(down0a)
+    down0a = LeakyReLU(alpha=0.1)(down0a)
+    down0a = Conv2D(16, (3, 3), padding='same', use_bias=use_bias)(down0a)
+    down0a = BatchNormalization()(down0a)
+    down0a = LeakyReLU(alpha=0.1)(down0a)
+    down0a_pool = MaxPooling2D((2, 2), strides=(2, 2))(down0a)
+
+    # 256
+    down0 = Conv2D(32, (3, 3), padding='same', use_bias=use_bias)(down0a_pool)
+    down0 = BatchNormalization()(down0)
+
+    down0 = LeakyReLU(alpha=0.1)(down0)
+    down0 = Conv2D(32, (3, 3), padding='same', use_bias=use_bias)(down0)
+    down0 = BatchNormalization()(down0)
+    down0 = LeakyReLU(alpha=0.1)(down0)
+    down0_pool = MaxPooling2D((2, 2), strides=(2, 2))(down0)
+    # 128
+
+    down1 = Conv2D(64, (3, 3), padding='same', use_bias=use_bias)(down0_pool)
+    down1 = BatchNormalization()(down1)
+    down1 = LeakyReLU(alpha=0.1)(down1)
+    down1 = Conv2D(64, (3, 3), padding='same', use_bias=use_bias)(down1)
+    down1 = BatchNormalization()(down1)
+    down1 = LeakyReLU(alpha=0.1)(down1)
+    down1_pool = MaxPooling2D((2, 2), strides=(2, 2))(down1)
+    # 64
+
+    down2 = Conv2D(128, (3, 3), padding='same', use_bias=use_bias)(down1_pool)
+    down2 = BatchNormalization()(down2)
+    down2 = LeakyReLU(alpha=0.1)(down2)
+    down2 = Conv2D(128, (3, 3), padding='same', use_bias=use_bias)(down2)
+    down2 = BatchNormalization()(down2)
+    down2 = LeakyReLU(alpha=0.1)(down2)
+    down2_pool = MaxPooling2D((2, 2), strides=(2, 2))(down2)
+    # 32
+
+    up1 = UpSampling2D((2, 2))(down2)
+    up1 = concatenate([down1, up1], axis=3)
+    up1 = Conv2D(64, (3, 3), padding='same', use_bias=use_bias)(up1)
+    up1 = BatchNormalization()(up1)
+    up1 = LeakyReLU(alpha=0.1)(up1)
+    up1 = Conv2D(64, (3, 3), padding='same', use_bias=use_bias)(up1)
+    up1 = BatchNormalization()(up1)
+    up1 = LeakyReLU(alpha=0.1)(up1)
+    up1 = Conv2D(64, (3, 3), padding='same', use_bias=use_bias)(up1)
+    up1 = BatchNormalization()(up1)
+    up1 = LeakyReLU(alpha=0.1)(up1)
+    # 128
+
+    up0 = UpSampling2D((2, 2))(up1)
+    up0 = concatenate([down0, up0], axis=3)
+    up0 = Conv2D(32, (3, 3), padding='same', use_bias=use_bias)(up0)
+    up0 = BatchNormalization()(up0)
+    up0 = LeakyReLU(alpha=0.1)(up0)
+    up0 = Conv2D(32, (3, 3), padding='same', use_bias=use_bias)(up0)
+    up0 = BatchNormalization()(up0)
+    up0 = LeakyReLU(alpha=0.1)(up0)
+    up0 = Conv2D(32, (3, 3), padding='same', use_bias=use_bias)(up0)
+    up0 = BatchNormalization()(up0)
+    up0 = LeakyReLU(alpha=0.1)(up0)
+    # 256
+
+    up0a = UpSampling2D((2, 2))(up0)
+    up0a = concatenate([down0a, up0a], axis=3)
+    up0a = Conv2D(16, (3, 3), padding='same', use_bias=use_bias)(up0a)
+    up0a = BatchNormalization()(up0a)
+    up0a = LeakyReLU(alpha=0.1)(up0a)
+    up0a = Conv2D(16, (3, 3), padding='same', use_bias=use_bias)(up0a)
+    up0a = BatchNormalization()(up0a)
+    up0a = LeakyReLU(alpha=0.1)(up0a)
+    up0a = Conv2D(16, (3, 3), padding='same', use_bias=use_bias)(up0a)
+    up0a = BatchNormalization()(up0a)
+    up0a = LeakyReLU(alpha=0.1)(up0a)
+    # 512
+    classify = Conv2D(num_classes, (1, 1), activation='sigmoid')(up0a)
+    model = Model(inputs=inputs, outputs=classify)
+    return model
+
+
 model = table_net((None, None, 3), 2)
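
A short usage sketch for the smaller table_net above; the weights path is illustrative. Only three of the four pooling stages feed the decoder (down2_pool is computed but never used), so input height and width just need to be divisible by 8; with num_classes=2 the two sigmoid output channels are presumably the horizontal and vertical table-line masks.

import numpy as np

model = table_net((None, None, 3), 2)               # as instantiated above
model.load_weights("otr/models/table-line.h5")      # illustrative path
img = np.zeros((1, 512, 512, 3), dtype=np.float32)  # H and W divisible by 8
pred = model.predict(img)                           # -> (1, 512, 512, 2)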
 
 

+ 73 - 1
result.html

@@ -1 +1,73 @@
-<!DOCTYPE HTML><head><meta charset="UTF-8"></head><body></body>
+<!DOCTYPE HTML><head><meta charset="UTF-8"></head><body><div>首政采法规购买服务监督检查信息公告GPA专栏PPP频道</div><div>当前位置:首页》政采公告》地方公告》竞争性碳商公告</div><div>海南师范大学食堂厨房设备采购及餐桌购置项目(二次采购)竞争性碳商</div><div>2021年12月21日16:53来源:中国政府采购网【打印】</div><div>项目概况</div><div>海南师范大学食堂厨房设备采购及餐桌购置项目(二次采购)采购项目的潜在供应商应在海南省海口市滨海大</div><div>道67号黄金海景大酒店1102房获取采购文件,并于2022年01月04日09点00分(北京时间)前提交响应文</div><div>件。</div><div>、项目基本情况</div><div>项目编号:HFCC20212063H-A</div><div>项目名称:海南师范大学食堂厨房设备采购及餐桌购置项目(二次采购)</div><div>采购方式:竞争性碳商</div><div>预算金额:48.1711000万元(人民币)</div><div>最高限价(如有):48.1711000万元(人民币)</div><div>采购需求:</div><table border="1">
+<tr>
+<td colspan=1 rowspan=1>序号</td>
+<td colspan=1 rowspan=1>采购品目名称</td>
+<td colspan=1 rowspan=1>数量</td>
+<td colspan=1 rowspan=1>单位</td>
+</tr>
+<tr>
+<td colspan=3 rowspan=1>标包号:A</td>
+<td colspan=1 rowspan=1></td>
+</tr>
+<tr>
+<td colspan=1 rowspan=1>1</td>
+<td colspan=1 rowspan=1>留样冰柜180L</td>
+<td colspan=1 rowspan=1>2</td>
+<td colspan=1 rowspan=1>台</td>
+</tr>
+<tr>
+<td colspan=1 rowspan=1>2</td>
+<td colspan=1 rowspan=1>四门冰箱860L</td>
+<td colspan=1 rowspan=1>3</td>
+<td colspan=1 rowspan=1>台</td>
+</tr>
+<tr>
+<td colspan=1 rowspan=1>3</td>
+<td colspan=1 rowspan=1>双通打荷台</td>
+<td colspan=1 rowspan=1>4</td>
+<td colspan=1 rowspan=1>台</td>
+</tr>
+<tr>
+<td colspan=1 rowspan=1>4</td>
+<td colspan=1 rowspan=1>四层货架</td>
+<td colspan=1 rowspan=1>22</td>
+<td colspan=1 rowspan=1>台</td>
+</tr>
+<tr>
+<td colspan=1 rowspan=1>5</td>
+<td colspan=1 rowspan=1>双门蒸柜24盘</td>
+<td colspan=1 rowspan=1>2</td>
+<td colspan=1 rowspan=1>台</td>
+</tr>
+<tr>
+<td colspan=1 rowspan=1>6</td>
+<td colspan=1 rowspan=1>不锈钢双层工作台</td>
+<td colspan=1 rowspan=1>8</td>
+<td colspan=1 rowspan=1>台</td>
+</tr>
+<tr>
+<td colspan=1 rowspan=1>7</td>
+<td colspan=1 rowspan=1>双星盆</td>
+<td colspan=1 rowspan=1>1</td>
+<td colspan=1 rowspan=1>台灣</td>
+</tr>
+<tr>
+<td colspan=1 rowspan=1>8</td>
+<td colspan=1 rowspan=1>靠背式单星洗米盆</td>
+<td colspan=1 rowspan=1>2</td>
+<td colspan=1 rowspan=1>台</td>
+</tr>
+<tr>
+<td colspan=1 rowspan=1>9</td>
+<td colspan=1 rowspan=1>六门冰柜</td>
+<td colspan=1 rowspan=1>2</td>
+<td colspan=1 rowspan=1>台</td>
+</tr>
+<tr>
+<td colspan=1 rowspan=1>10</td>
+<td colspan=1 rowspan=1>靠背式单星解冻盆</td>
+<td colspan=1 rowspan=1>1</td>
+<td colspan=1 rowspan=1>台</td>
+</tr>
+</table>
+</body>