fangjiasheng 3 years ago
parent
commit
7178b6138e

+ 4 - 0
format_convert/_global.py

@@ -7,6 +7,10 @@ def _init():
     global_dict = {}
 
 
+def get_dict():
+    # Return the module-level shared-state dict for read access.
+    # NOTE(review): `global_dict` is only created by _init() (see hunk above) —
+    # calling this before _init() raises NameError; confirm callers init first.
+    return global_dict
+
+
 def update(_dict):
     # 定义一个全局变量
     try:

+ 11 - 12
format_convert/convert.py

@@ -3,9 +3,6 @@ import gc
 import json
 import sys
 import os
-import tracemalloc
-from io import BytesIO
-import objgraph
 sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
 # 强制tf使用cpu
 os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
@@ -523,7 +520,8 @@ def _convert():
                 is_success = 0
             log("md5: " + str(_md5)
                          + " finished result: " + str(text)
-                         + " is_success: " + str(is_success)
+                         + " is_success: " + str(is_success) + " "
+                         + str(_type) + " "
                          + " " + str(time.time() - start_time))
             return json.dumps({"result_html": [str(text[0])], "result_text": [str(text[0])],
                                "is_success": is_success, "swf_images": str(swf_images)})
@@ -568,8 +566,8 @@ def _convert():
                            "is_success": 1, "swf_images": str(swf_images)})
 
     except ConnectionError:
-        log("convert post has no data!" + " failed result: [-2] is_success: 0 " +
-            str(time.time() - start_time))
+        log("convert post has no data!" + " failed result: [-2] is_success: 0 "
+            + str(time.time() - start_time))
         return json.dumps({"result_html": ["-2"], "result_text": ["-2"],
                            "is_success": 0, "swf_images": str([])})
     except Exception as e:
@@ -689,14 +687,15 @@ if __name__ == '__main__':
 
     globals().update({"md5": "1"+"0"*15})
     globals().update({"port": str(port)})
-    _global._init()
-    _global.update({"md5": "1"+"0"*15})
-    _global.update({"port": str(port)})
+    # _global._init()
+    # _global.update({"md5": "1"+"0"*15})
+    # _global.update({"port": str(port)})
 
-    ip = get_intranet_ip()
-    log("my ip"+str(ip))
+    # ip = get_intranet_ip()
+    # log("my ip"+str(ip))
+    # ip = "http://" + ip
     ip_port_dict = get_ip_port()
-    ip = "http://" + ip
+    ip = "http://127.0.0.1"
     processes = ip_port_dict.get(ip).get("convert_processes")
 
     set_flask_global()

+ 65 - 4
format_convert/convert_image.py

@@ -1,15 +1,23 @@
+# encoding=utf8
 import inspect
+import io
 import logging
 import os
 import sys
+import time
+import requests
+import numpy as np
+from ocr.paddleocr import PaddleOCR
+
 sys.path.append(os.path.dirname(__file__) + "/../")
 from pdfminer.layout import LTLine
 import traceback
 import cv2
-from format_convert import get_memory_info
+from format_convert import get_memory_info, _global
 from format_convert.utils import judge_error_code, add_div, LineTable, get_table_html, get_logger, log, memory_decorator
 from format_convert.table_correct import get_rotated_image
-from format_convert.convert_need_interface import from_otr_interface, from_ocr_interface
+from format_convert.convert_need_interface import from_otr_interface, from_ocr_interface, from_gpu_interface_redis
+from otr.table_line import table_preprocess, table_postprocess
 
 
 def image_process(image_np, image_path, is_from_pdf=False, is_from_docx=False, use_ocr=True):
@@ -100,20 +108,60 @@ def image_process(image_np, image_path, is_from_pdf=False, is_from_docx=False, u
         if judge_error_code(list_line):
             return list_line
 
+        # # 预处理
+        # if is_from_pdf:
+        #     prob = 0.2
+        # else:
+        #     prob = 0.5
+        # with open(image_resize_path, "rb") as f:
+        #     image_bytes = f.read()
+        # img_new, inputs = table_preprocess(image_bytes, prob)
+        # if type(img_new) is list and judge_error_code(img_new):
+        #     return img_new
+        # log("img_new.shape " + str(img_new.shape))
+        #
+        # # 调用模型运行接口
+        # _dict = {"inputs": inputs, "md5": _global.get("md5")}
+        # result = from_gpu_interface(_dict, model_type="otr", predictor_type="")
+        # if judge_error_code(result):
+        #     logging.error("from_gpu_interface failed! " + str(result))
+        #     raise requests.exceptions.RequestException
+        #
+        # pred = result.get("preds")
+        # gpu_time = result.get("gpu_time")
+        # log("otr model predict time " + str(gpu_time))
+        #
+        # # # 解压numpy
+        # # decompressed_array = io.BytesIO()
+        # # decompressed_array.write(pred)
+        # # decompressed_array.seek(0)
+        # # pred = np.load(decompressed_array, allow_pickle=True)['arr_0']
+        # # log("inputs.shape" + str(pred.shape))
+        #
+        # # 后处理
+        # list_line = table_postprocess(img_new, pred, prob)
+        # log("len(list_line) " + str(len(list_line)))
+        # if judge_error_code(list_line):
+        #     return list_line
+
         # otr resize后得到的bbox根据比例还原
+        start_time = time.time()
         ratio = (image_np.shape[0]/best_h, image_np.shape[1]/best_w)
         for i in range(len(list_line)):
             point = list_line[i]
             list_line[i] = [int(point[0]*ratio[1]), int(point[1]*ratio[0]),
                             int(point[2]*ratio[1]), int(point[3]*ratio[0])]
+        log("otr resize bbox recover " + str(time.time()-start_time))
 
         # ocr图片过大内存溢出,需resize
+        start_time = time.time()
         threshold = 3000
         if image_np.shape[0] >= threshold or image_np.shape[1] >= threshold:
             best_h, best_w = get_best_predict_size2(image_np, threshold)
             image_resize = cv2.resize(image_np, (best_w, best_h), interpolation=cv2.INTER_AREA)
-            image_resize_path = image_path.split(".")[0] + "_resize_ocr." + image_path.split(".")[-1]
-            cv2.imwrite(image_resize_path, image_resize)
+            # image_resize_path = image_path.split(".")[0] + "_resize_ocr." + image_path.split(".")[-1]
+            # cv2.imwrite(image_resize_path, image_resize)
+        log("ocr resize before " + str(time.time()-start_time))
 
         # 调用ocr模型接口
         with open(image_resize_path, "rb") as f:
@@ -122,6 +170,19 @@ def image_process(image_np, image_path, is_from_pdf=False, is_from_docx=False, u
         if judge_error_code(text_list):
             return text_list
 
+        # # PaddleOCR内部包括预处理,调用模型运行接口,后处理
+        # paddle_ocr = PaddleOCR(use_angle_cls=True, lang="ch")
+        # results = paddle_ocr.ocr(image_resize, det=True, rec=True, cls=True)
+        # # 循环每张图片识别结果
+        # text_list = []
+        # bbox_list = []
+        # for line in results:
+        #     # print("ocr_interface line", line)
+        #     text_list.append(line[-1][0])
+        #     bbox_list.append(line[0])
+        # if len(text_list) == 0:
+        #     return []
+
         # ocr resize后的bbox还原
         ratio = (image_np.shape[0]/best_h, image_np.shape[1]/best_w)
         for i in range(len(bbox_list)):

+ 191 - 37
format_convert/convert_need_interface.py

@@ -1,12 +1,16 @@
+# encoding=utf8
 import base64
 import inspect
 import json
 import logging
 import os
+import pickle
 import random
 import sys
 import time
-
+import uuid
+import zlib
+import redis
 from werkzeug.exceptions import NotFound
 sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
 import traceback
@@ -19,38 +23,64 @@ from otr.otr_interface import otr, OtrModels
 from format_convert.libreoffice_interface import office_convert
 
 
-# 远程GPU接口
-# # interface_ip_list = ['http://192.168.2.102', 'http://192.168.2.103']
-# # interface_ip_list = ['http://172.16.160.65', 'http://172.16.160.64', 'http://172.16.160.66', 'http://172.16.160.67']
-# interface_ip_list = ['http://172.16.160.65', 'http://172.16.160.65']
-# # ocr_port_list = ["15011", "15013", "15015"]
-# # ocr_port_list = ["15011", "15013", "15015", "15017", "15019"]
-# # otr_port_list = ["15012", "15014", "15016", "15018", "15020"]
-# ocr_port_list = ["15011", "15013", "15015"]
-# otr_port_list = ["15012", "15014", "15016"]
-# # ocr_port_list = ["15011", "15013", "15015", "15017", "15019", "15021"]
-# # otr_port_list = ["15012", "15014", "15016", "15018", "15020", "15022"]
-# soffice_port_list = ["16000", "16001", "16002", "16003", "16004", "16005",
-#                      "16006", "16007", "16008", "16009"]
-# # ocr_port_list = ["15011", "15013"]
-# # otr_port_list = ["15012"]
-
 if get_platform() == "Windows":
     FROM_REMOTE = False
+    only_test_ocr = False
+
+    if only_test_ocr:
+        ip_port_flag = {}
+        ip_port_dict = get_ip_port()
+        for _k in ip_port_dict.keys():
+            ip_port_flag.update({_k: {"ocr": 0,
+                                      "otr": 0,
+                                      "convert": 0,
+                                      "office": 0
+                                      }})
+        _global.update({"ip_port_flag": ip_port_flag})
+        ip_port_dict["http://127.0.0.1"]["ocr"] = ["17000"]
+        ip_port_dict["http://127.0.0.1"]["otr"] = ["18000"]
+        _global.update({"ip_port": ip_port_dict})
 else:
     FROM_REMOTE = True
 
-# _global = {}
-# ip_port_flag = {}
-# ip_port_dict = get_ip_port()
-# for _k in ip_port_dict.keys():
-#     ip_port_flag.update({_k: {"ocr": 0,
-#                               "otr": 0,
-#                               "convert": 0,
-#                               "office": 0
-#                               }})
-# _global.update({"ip_port_flag": ip_port_flag})
-# _global.update({"ip_port": ip_port_dict})
+
+# 连接redis数据库
+# redis_db = redis.StrictRedis(host='192.168.2.103', port='6379',
+#                              db=1, password='bidi123456', health_check_interval=300)
+redis_db = None
+
+
+def _interface(_dict, time_out=60, retry_times=3, use_zlib=False):
+    """POST a pickled request dict to a model service, retrying on failure.
+
+    :param _dict: request payload; must contain "model_type" (e.g. "ocr"/"otr")
+    :param time_out: per-request timeout in seconds
+    :param retry_times: how many ports to try before giving up
+    :param use_zlib: forwarded to request_post (payload compression flag)
+    :return: unpickled service response, or an error-code list ([-2]/[-5])
+    """
+    try:
+        model_type = _dict.get("model_type")
+        # retry loop: each attempt picks a (possibly different) port from the pool
+        while retry_times:
+            ip_port = interface_pool(model_type)
+            if judge_error_code(ip_port):
+                return ip_port
+            _url = ip_port + "/" + model_type
+            base64_stream = base64.b64encode(pickle.dumps(_dict))
+            # if use_zlib:
+            #     base64_stream = zlib.compress(base64_stream)
+            r = pickle.loads(base64.b64decode(request_post(_url, {"data": base64_stream, "model_type": model_type},
+                                                           time_out=time_out, use_zlib=use_zlib)))
+            log("get _interface return")
+            if type(r) == list:
+                # interface unreachable: retry on another port, or give up on last try
+                if retry_times <= 1:
+                    return r
+                retry_times -= 1
+                log("retry post _interface... left times " + str(retry_times) + " " + model_type)
+                continue
+            # bugfix(review): removed redundant `if judge_error_code(r): return r`
+            # (identical to the fall-through) and the unreachable `break` after it
+            return r
+        # bugfix(review): previously fell off the loop returning None when
+        # retry_times was exhausted or passed as 0; return an explicit error code
+        return [-2]
+
+    except TimeoutError:
+        return [-5]
+    except requests.exceptions.ConnectionError as e:
+        return [-2]
 
 
 def from_office_interface(src_path, dest_path, target_format, retry_times=1, from_remote=FROM_REMOTE):
@@ -78,7 +108,7 @@ def from_office_interface(src_path, dest_path, target_format, retry_times=1, fro
                 # _port = "16002"
                 # _ip, _port = interface_pool("soffice")
                 # ip_port = from_schedule_interface("office")
-                ip_port = interface_pool("office")
+                ip_port = interface_pool_gunicorn("office")
                 if judge_error_code(ip_port):
                     return ip_port
                 _url = ip_port + "/soffice"
@@ -92,6 +122,7 @@ def from_office_interface(src_path, dest_path, target_format, retry_times=1, fro
                                                    "file": base64_stream,
                                                    "target_format": target_format,
                                                    "retry_times": retry_times}, time_out=25))
+                log("get interface return")
                 log("office use time " + str(time.time()-start_time))
                 if type(r) == list:
                     # 接口连不上换个端口重试
@@ -149,11 +180,14 @@ def from_ocr_interface(image_stream, is_table=False, from_remote=FROM_REMOTE):
                     # _ip, _port = interface_pool("ocr")
                     # ip_port = _ip + ":" + _port
                     # ip_port = from_schedule_interface("ocr")
-                    ip_port = interface_pool("ocr")
+                    ip_port = interface_pool_gunicorn("ocr")
                     if judge_error_code(ip_port):
                         return ip_port
                     _url = ip_port + "/ocr"
-                    r = json.loads(request_post(_url, {"data": base64_stream}, time_out=60))
+                    r = json.loads(request_post(_url, {"data": base64_stream,
+                                                       "md5": _global.get("md5")},
+                                                time_out=60))
+                    log("get interface return")
                     if type(r) == list:
                         # 接口连不上换个端口重试
                         if retry_times_1 <= 1:
@@ -170,8 +204,8 @@ def from_ocr_interface(image_stream, is_table=False, from_remote=FROM_REMOTE):
                     break
             else:
                 if globals().get("global_ocr_model") is None:
-                    globals().update({"global_ocr_model": OcrModels().get_model()})
                     print("=========== init ocr model ===========")
+                    globals().update({"global_ocr_model": OcrModels().get_model()})
                 r = ocr(data=base64_stream, ocr_model=globals().get("global_ocr_model"))
         except TimeoutError:
             if is_table:
@@ -211,6 +245,62 @@ def from_ocr_interface(image_stream, is_table=False, from_remote=FROM_REMOTE):
             return [-1]
 
 
+def from_gpu_interface_flask(_dict, model_type, predictor_type):
+    """Call the GPU model service over HTTP via _interface.
+
+    :param _dict: request payload (inputs, md5, ...) forwarded to the service
+    :param model_type: "ocr" or "otr"; selects the target route and zlib use
+    :param predictor_type: merged into the payload for the remote predictor
+    :return: the service result, or [-2] on any exception
+    """
+    log("into from_gpu_interface")
+    start_time = time.time()
+    try:
+        # call the interface; only OCR payloads are large enough to warrant zlib
+        _dict.update({"predictor_type": predictor_type, "model_type": model_type})
+        if model_type == "ocr":
+            use_zlib = True
+        else:
+            use_zlib = False
+        result = _interface(_dict, time_out=30, retry_times=2, use_zlib=use_zlib)
+        # NOTE(review): sys.getsizeof is shallow — logged "size" excludes contained objects
+        log("from_gpu_interface finish size " + str(sys.getsizeof(_dict)) + " time " + str(time.time()-start_time))
+        return result
+    except Exception as e:
+        log("from_gpu_interface error!")
+        log("from_gpu_interface failed " + str(time.time()-start_time))
+        traceback.print_exc()
+        return [-2]
+
+
+def from_gpu_interface_redis(_dict, model_type, predictor_type):
+    """Submit a GPU job through redis queues and poll for its result.
+
+    Pushes the pickled payload onto list "producer_<model_type>", then polls
+    hash "consumer_<model_type>" for an entry keyed by a fresh uuid.
+
+    :param _dict: request payload forwarded to the GPU worker
+    :param model_type: "ocr" or "otr"; selects the queue pair
+    :param predictor_type: merged into the payload for the remote predictor
+    :return: unpickled worker result, or [-2] on any exception
+    """
+    # NOTE(review): module-level `redis_db` is currently None (connection line is
+    # commented out) — every call will hit the except branch; confirm config.
+    log("into from_gpu_interface")
+    start_time = time.time()
+    try:
+        # enqueue the job under a unique id so the reply can be matched
+        _uuid = uuid.uuid1().hex
+        _dict.update({"predictor_type": predictor_type, "model_type": model_type,
+                      "uuid": _uuid})
+        _time = time.time()
+        log("pickle.dumps(_dict)" + str(_dict))
+        redis_db.rpush("producer_"+model_type, pickle.dumps(_dict))
+        log("producer_" + model_type + " len " + str(redis_db.llen("producer_" + model_type)))
+        log("to producer_" + model_type + " time " + str(time.time()-_time))
+
+        # poll for the worker's reply; give up after time_out seconds
+        _time = time.time()
+        time_out = 300
+        while True:
+            time.sleep(0.2)
+            if time.time() - _time > time_out:
+                # bugfix(review): was a bare `raise Exception` with no message
+                raise TimeoutError("gpu result for " + model_type + " uuid " + _uuid
+                                   + " not ready within " + str(time_out) + "s")
+            if redis_db.hexists("consumer_"+model_type, _uuid):
+                time1 = time.time()
+                result = redis_db.hget("consumer_"+model_type, _uuid)
+                # bugfix(review): delete the consumed entry so the hash does not
+                # grow without bound
+                redis_db.hdel("consumer_"+model_type, _uuid)
+                log("from consumer_"+model_type + " time " + str(time.time()-time1))
+                break
+
+        result = pickle.loads(result)
+        log("from_gpu_interface finish - size " + str(sys.getsizeof(_dict)) + " - time " + str(time.time()-start_time))
+        return result
+    except Exception as e:
+        log("from_gpu_interface error!")
+        log("from_gpu_interface failed " + str(time.time()-start_time))
+        traceback.print_exc()
+        return [-2]
+
+
 def from_otr_interface2(image_stream):
     log("into from_otr_interface")
     try:
@@ -269,12 +359,15 @@ def from_otr_interface(image_stream, is_from_pdf=False, from_remote=FROM_REMOTE)
                     # _port = port_pool("otr", _random=True)
                     # if _ip == interface_ip_list[1]:
                     #     _port = otr_port_list[0]
-                    ip_port = interface_pool("otr")
+                    ip_port = interface_pool_gunicorn("otr")
                     # ip_port = from_schedule_interface("otr")
                     if judge_error_code(ip_port):
                         return ip_port
                     _url = ip_port + "/otr"
-                    r = json.loads(request_post(_url, {"data": base64_stream, "is_from_pdf": is_from_pdf}, time_out=60))
+                    r = json.loads(request_post(_url, {"data": base64_stream,
+                                                       "is_from_pdf": is_from_pdf,
+                                                       "md5": _global.get("md5")}, time_out=60))
+                    log("get interface return")
                     if type(r) == list:
                         # 接口连不上换个端口重试
                         if retry_times_1 <= 1:
@@ -288,9 +381,10 @@ def from_otr_interface(image_stream, is_from_pdf=False, from_remote=FROM_REMOTE)
                     break
             else:
                 if globals().get("global_otr_model") is None:
-                    globals().update({"global_otr_model": OtrModels().get_model()})
                     print("=========== init otr model ===========")
+                    globals().update({"global_otr_model": OtrModels().get_model()})
                 r = otr(data=base64_stream, otr_model=globals().get("global_otr_model"), is_from_pdf=is_from_pdf)
+                # r = otr(data=base64_stream, otr_model=None, is_from_pdf=is_from_pdf)
         except TimeoutError:
             return [-5]
         except requests.exceptions.ConnectionError as e:
@@ -329,10 +423,17 @@ def from_otr_interface(image_stream, is_from_pdf=False, from_remote=FROM_REMOTE)
 #         return [-1]
 
 
-def interface_pool(interface_type):
+def interface_pool(interface_type, use_gunicorn=True):
     ip_port_flag = _global.get("ip_port_flag")
     ip_port_dict = _global.get("ip_port")
     try:
+        if use_gunicorn:
+            _ip = "http://127.0.0.1"
+            _port = ip_port_dict.get(_ip).get(interface_type)[0]
+            ip_port = _ip + ":" + str(_port)
+            log(ip_port)
+            return ip_port
+
         # 负载均衡, 选取ip
         interface_load_list = []
         for _ip in ip_port_flag.keys():
@@ -360,7 +461,7 @@ def interface_pool(interface_type):
         else:
             ip_port_flag[_ip][interface_type] = current_flag + 1
         _global.update({"ip_port_flag": ip_port_flag})
-        log(str(_global.get("ip_port_flag")))
+        # log(str(_global.get("ip_port_flag")))
 
         ip_port = _ip + ":" + str(_port)
         log(ip_port)
@@ -373,6 +474,59 @@ def interface_pool(interface_type):
         return [-1]
 
 
+def interface_pool_gunicorn(interface_type):
+    """Pick an "ip:port" endpoint for the given interface type.
+
+    For "office" the ip is fixed to localhost and a port is rotated from the
+    configured list; for other types the least-loaded ip is chosen and the
+    first port is used (gunicorn multiplexes workers behind one port).
+
+    :param interface_type: "office", "ocr", "otr", ...
+    :return: "http://ip:port" string, [-2] if config globals are missing,
+             [-1] on any other error
+    """
+    ip_flag_list = _global.get("ip_flag")
+    ip_port_flag_dict = _global.get("ip_port_flag")
+    ip_port_dict = _global.get("ip_port")
+    try:
+        if ip_flag_list is None or ip_port_dict is None or ip_port_flag_dict is None:
+            raise NotFound
+
+        if interface_type == "office":
+            _ip = "http://127.0.0.1"
+            # pick a port: random on first use, then round-robin by call count
+            port_list = ip_port_dict.get(_ip).get(interface_type)
+            ip_type_cnt = ip_port_flag_dict.get(_ip).get(interface_type)
+            if ip_type_cnt == 0:
+                _port = port_list[random.randint(0, len(port_list)-1)]
+            else:
+                _port = port_list[ip_type_cnt % len(port_list)]
+            # update the call counter, wrapping to keep it bounded
+            if ip_port_flag_dict.get(_ip).get(interface_type) >= 10000:
+                ip_port_flag_dict[_ip][interface_type] = 0
+            else:
+                ip_port_flag_dict[_ip][interface_type] += 1
+            _global.update({"ip_port_flag": ip_port_flag_dict})
+
+        else:
+            # load balance: pick the ip with the smallest usage flag
+            ip_flag_list.sort(key=lambda x: x[1])
+            if ip_flag_list[-1][1] == 0:
+                ip_index = random.randint(0, len(ip_flag_list)-1)
+            else:
+                ip_index = 0
+            _ip = ip_flag_list[ip_index][0]
+            # with gunicorn a single port fronts all workers: take the first
+            _port = ip_port_dict.get(_ip).get(interface_type)[0]
+            # update the usage flag, wrapping to keep it bounded
+            if ip_flag_list[ip_index][1] >= 10000:
+                ip_flag_list[ip_index][1] = 0
+            else:
+                # bugfix(review): was `+= + 1` (harmless unary plus typo)
+                ip_flag_list[ip_index][1] += 1
+            _global.update({"ip_flag": ip_flag_list})
+
+        ip_port = _ip + ":" + str(_port)
+        log(ip_port)
+        return ip_port
+    except NotFound:
+        log("ip_flag or ip_port_dict is None! checkout config")
+        return [-2]
+    # bugfix(review): bare `except:` also swallowed KeyboardInterrupt/SystemExit
+    except Exception:
+        traceback.print_exc()
+        return [-1]
+
+
+
 # def interface_pool(interface_type):
 #     try:
 #         ip_port_dict = _global.get("ip_port")

+ 14 - 10
format_convert/convert_test.py

@@ -1,6 +1,7 @@
 import base64
 import json
 import os
+import random
 import sys
 import time
 from multiprocessing.context import Process
@@ -19,8 +20,8 @@ def test_one(p, from_remote=False):
 
     data = {"file": file_base64, "type": p.split(".")[-1], "filemd5": 100}
     if from_remote:
-        # _url = 'http://121.46.18.113:15010/convert'
-        _url = 'http://192.168.2.102:15010/convert'
+        _url = 'http://121.46.18.113:15010/convert'
+        # _url = 'http://192.168.2.103:15010/convert'
         # _url = 'http://172.16.160.65:15010/convert'
         result = json.loads(request_post(_url, data, time_out=10000))
     else:
@@ -34,7 +35,8 @@ def test_one(p, from_remote=False):
 
 def test_duplicate(path_list, process_no=None):
     start_time = time.time()
-    for i in range(500):
+    # random.shuffle(path_list)
+    for i in range(1):
         if i % 10 == 0:
             if process_no is not None:
                 print("Process", process_no, i*len(path_list), time.time()-start_time)
@@ -46,7 +48,7 @@ def test_duplicate(path_list, process_no=None):
 
 if __name__ == '__main__':
     if get_platform() == "Windows":
-        # file_path = "C:/Users/Administrator/Desktop/error7.jpg"
+        # file_path = "C:/Users/Administrator/Desktop/error2.swf"
         # file_path = "D:/BIDI_DOC/比地_文档/2022/Test_Interface/20210609202634853485.xlsx"
         # file_path = "D:/BIDI_DOC/比地_文档/2022/Test_ODPS/1624325845476.pdf"
         file_path = "C:/Users/Administrator/Downloads/1652672734044.jpg"
@@ -62,13 +64,15 @@ if __name__ == '__main__':
     #     #                   "D:/BIDI_DOC/比地_文档/2022/Test_Interface/94961e1987d1090e.xls"]
     #     # file_path_list = ["D:/BIDI_DOC/比地_文档/2022/Test_Interface/1623423836610.pdf"]
     #     file_path_list = ["C:/Users/Administrator/Downloads/广东中检达元检测技术有限公司.pdf",
-    #                       "C:/Users/Administrator/Desktop/error11.pdf",
-    #                       "C:/Users/Administrator/Desktop/error9.pdf",
-    #                       "C:/Users/Administrator/Desktop/error16.jpg",
-    #                       "C:/Users/Administrator/Desktop/error9.jpg",]
+    #                       "C:/Users/Administrator/Desktop/error16.jpg"]
     # else:
     #     file_path_list = ["1623423836610.pdf"]
-    # for j in range(10):
+    # start_time = time.time()
+    # p_list = []
+    # for j in range(3):
     #     p = Process(target=test_duplicate, args=(file_path_list, j, ))
     #     p.start()
-    # p.join()
+    #     p_list.append(p)
+    # for p in p_list:
+    #     p.join()
+    # print("finish", time.time() - start_time)

+ 12 - 12
format_convert/interface.yml

@@ -5,15 +5,15 @@ MASTER:
 #  local-102: 'http://192.168.2.102'
 #  local-103: 'http://192.168.2.103'
 #  local 'http://127.0.0.1'
-  ip: ['http://172.20.1.251']
+  ip:
 
   PATH:
 #  65: /root/miniconda3/bin/python
 #  102: /home/python/anaconda3/envs/convert/bin/python
 #  103: /home/yons/anaconda3/envs/tf1.5/bin/python
 #  251: /data/anaconda3/envs/convert/bin/python
-    python: '/data/anaconda3/envs/convert/bin/python'
-#  65: /data/format_conversion_maxcompute/
+    python: '/data/anaconda3/envs/convert3/bin/gunicorn'
+    #  65: /data/format_conversion_maxcompute/
 #  102: /data/fangjiasheng/format_conversion_maxcompute/
 #  103: /data/python/fangjiasheng/format_conversion_maxcompute/
 #  251: /data/fangjiasheng/format_conversion_maxcompute/
@@ -21,30 +21,30 @@ MASTER:
 
   CONVERT:
     port: 15010
-    processes: 25
+    processes: 20
 
   SCHEDULE:
     port:
 
   OCR:
     port_start: 17000
-    port_no: 6
+    port_no: 3
 
   OTR:
     port_start: 18000
-    port_no: 6
+    port_no: 3
 
   OFFICE:
     port_start: 16000
-    port_no: 24
+    port_no: 25
 
 
 SLAVE:
-  ip:
+  ip: ['http://127.0.0.1']
 
   PATH:
-    python:
-    project:
+    python: '/data/anaconda3/envs/tf2/bin/gunicorn'
+    project: '/data/fangjiasheng/format_conversion_maxcompute/'
 
   CONVERT:
     port:
@@ -55,11 +55,11 @@ SLAVE:
 
   OCR:
     port_start: 17000
-    port_no: 1
+    port_no: 2
 
   OTR:
     port_start: 18000
-    port_no: 1
+    port_no: 4
 
   OFFICE:
     port_start:

+ 3 - 2
format_convert/kill_all.py

@@ -7,7 +7,8 @@ from format_convert.utils import get_ip_port, get_intranet_ip
 
 
 ip_port_dict = get_ip_port()
-ip = "http://" + get_intranet_ip()
+# ip = "http://" + get_intranet_ip()
+ip = "http://127.0.0.1"
 python_path = ip_port_dict.get(ip).get("python_path")
 project_path = ip_port_dict.get(ip).get("project_path")
 
@@ -29,7 +30,7 @@ def kill():
                 comm = "kill -9 " + str(pid)
                 print(comm, process_cmd)
                 os.system(comm)
-            if re.search("gunicorn", process_cmd):
+            if re.search(project_path, process_cmd):
                 comm = "kill -9 " + str(pid)
                 print(comm, process_cmd)
                 os.system(comm)

+ 12 - 3
format_convert/kill_office.py

@@ -6,19 +6,28 @@ import psutil
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 
 
-def kill_soffice(limit_sec=12):
+def kill_soffice(limit_sec=30):
     pid_list = psutil.pids()
     for pid in pid_list:
         process = psutil.Process(pid)
+
+        process_cmd = ''
+        for c in process.cmdline():
+            process_cmd += c + " "
+        if process_cmd.strip() == "":
+            continue
+
+        if process.status() == "zombie":
+            print("zombie cmd", process_cmd)
+
         if re.search("soffice", process.exe()):
             start_time = process.create_time()
             now_time = time.time()
             run_time = now_time-start_time
-            # logging.info("pid " + str(run_time))
             if run_time >= limit_sec:
                 comm = "kill -9 " + str(pid)
-                print("kill process ", str(pid), str(process.exe()), str(run_time), ">", limit_sec)
                 os.system("echo $(date +%F%n%T)")
+                print("kill process ", str(pid), str(process.exe()), str(run_time), ">", limit_sec)
                 os.system(comm)
 
 

+ 37 - 21
format_convert/monitor_process_config.py

@@ -1,3 +1,4 @@
+import datetime
 import logging
 import os
 import re
@@ -9,13 +10,22 @@ from format_convert.utils import get_ip_port, get_intranet_ip
 
 
 ip_port_dict = get_ip_port()
-ip = "http://" + get_intranet_ip()
-# ip = "http://127.0.0.1"
+# ip = "http://" + get_intranet_ip()
+ip = "http://127.0.0.1"
 convert_port_list = ip_port_dict.get(ip).get("convert")
 ocr_port_list = ip_port_dict.get(ip).get("ocr")
 otr_port_list = ip_port_dict.get(ip).get("otr")
 soffice_port_list = ip_port_dict.get(ip).get("office")
-schedule_port_list = ip_port_dict.get(ip).get("schedule")
+if not convert_port_list:
+    convert_port_list = []
+if not ocr_port_list:
+    ocr_port_list = []
+if not otr_port_list:
+    otr_port_list = []
+if not soffice_port_list:
+    soffice_port_list = []
+
+# schedule_port_list = ip_port_dict.get(ip).get("schedule")
 python_path = ip_port_dict.get(ip).get("python_path")
 project_path = ip_port_dict.get(ip).get("project_path")
 
@@ -24,13 +34,23 @@ interface_path = project_path[:-1]
 std_out = " >>/convert.out 2>&1 &"
 std_out_gpu = " >>/gpu.out 2>&1 &"
 std_out_schedule = " >>/schedule.out 2>&1 &"
-convert_comm = "nohup " + python_path + " " + interface_path + "/format_convert/convert.py #" + std_out
-ocr_comm = "nohup " + python_path + " " + interface_path + "/ocr/ocr_interface.py # 0" + std_out_gpu
-otr_comm = "nohup " + python_path + " " + interface_path + "/otr/otr_interface.py # 0" + std_out_gpu
+# convert_comm = "nohup " + python_path + " " + interface_path + "/format_convert/convert.py #" + std_out
+# ocr_comm = "nohup " + python_path + " " + interface_path + "/ocr/ocr_gpu_interface.py # 0" + std_out_gpu
+# otr_comm = "nohup " + python_path + " " + interface_path + "/otr/otr_gpu_interface.py # 0" + std_out_gpu
 schedule_comm = "nohup " + python_path + " " + interface_path + "/format_convert/schedule_interface.py #" + std_out_schedule
 soffice_comm = "docker run --init -itd --log-opt max-size=10m --log-opt max-file=3 -p #:16000 soffice:v2 bash"
 
 
+gunicorn_path = python_path
+print("convert_port_list", len(convert_port_list))
+convert_comm = "nohup " + gunicorn_path + " -w " + str(len(convert_port_list)) + " -t 300 -b 0.0.0.0:# --chdir " \
+               + interface_path + "/format_convert convert:app" + std_out
+ocr_comm = "nohup " + gunicorn_path + " -w " + str(len(ocr_port_list)) + " -t 300 --keep-alive 600 -b 0.0.0.0:# --chdir " \
+           + interface_path + "/ocr ocr_interface:app" + std_out_gpu
+otr_comm = "nohup " + gunicorn_path + " -w " + str(len(otr_port_list)) + " -t 300 --keep-alive 600 -b 0.0.0.0:# --chdir " \
+           + interface_path + "/otr otr_interface:app" + std_out_gpu
+
+
 def get_port():
     net_conn = psutil.net_connections()
     current_port_list = []
@@ -38,7 +58,6 @@ def get_port():
         current_port_list.append(str(conn.laddr.port))
     current_port_list = list(set(current_port_list))
     current_port_list.sort(key=lambda x: x)
-    # print(current_port_list)
     return current_port_list
 
 
@@ -56,9 +75,8 @@ def restart(process_type, port):
     else:
         _comm = "netstat -nltp"
         print("no process_type", process_type)
-    os.system("echo $(date +%F%n%T)")
-    print("restart comm", _comm)
-    # os.system("netstat -nltp")
+    # os.system("echo $(date +%F%n%T)")
+    print(datetime.datetime.now(), "restart comm", _comm)
     os.system(_comm)
 
 
@@ -82,8 +100,7 @@ def kill_soffice(limit_sec=30):
             run_time = now_time-start_time
             if run_time >= limit_sec:
                 comm = "kill -9 " + str(pid)
-                os.system("echo $(date +%F%n%T)")
-                print("kill process ", str(pid), str(process.exe()), str(run_time), ">", limit_sec)
+                print(datetime.datetime.now(), "kill process ", str(pid), str(process.exe()), str(run_time), ">", limit_sec)
                 os.system(comm)
 
 
@@ -99,7 +116,7 @@ def kill_nested_timeout_process():
         if process_cmd.strip() == "":
             continue
 
-        if re.search("convert\.py|gunicorn", process_cmd):
+        if re.search("convert:app", process_cmd):
             ppid = process.ppid()
             start_time = process.create_time()
             now_time = time.time()
@@ -108,7 +125,6 @@ def kill_nested_timeout_process():
                 suspect_pid_list.append([str(pid), float(run_time)])
 
     # 时间最久的父进程为1的不能杀,是接口主进程
-    # print("suspect_pid_list", str(suspect_pid_list))
     if len(suspect_pid_list) <= 1:
         return
     else:
@@ -116,25 +132,25 @@ def kill_nested_timeout_process():
         for pid, run_time in suspect_pid_list[1:]:
             # print("pid", pid, run_time)
             comm = "kill -9 " + str(pid)
-            print("kill process ", str(pid), "father is 1", process_cmd)
+            print(datetime.datetime.now(), "kill process ", str(pid), "father is 1", process_cmd)
             os.system(comm)
 
 
 def monitor():
     current_port_list = get_port()
 
-    # if convert_port_list:
-    #     for p in convert_port_list:
-    #         if p not in current_port_list:
-    #             restart("convert", p)
+    if convert_port_list:
+        for p in convert_port_list[:1]:
+            if p not in current_port_list:
+                restart("convert", p)
 
     if ocr_port_list:
-        for p in ocr_port_list:
+        for p in ocr_port_list[:1]:
             if p not in current_port_list:
                 restart("ocr", p)
 
     if otr_port_list:
-        for p in otr_port_list:
+        for p in otr_port_list[:1]:
             if p not in current_port_list:
                 restart("otr", p)
 

+ 22 - 5
format_convert/utils.py

@@ -1,3 +1,4 @@
+# -*- coding:utf-8 -*-
 import hashlib
 import inspect
 import json
@@ -823,7 +824,7 @@ class LineTable:
                 if dict_line_points[_line]["direct"]=="column":
                     _line = lines[1]
                 next_point = None
-                for p1 in  dict_line_points[_line]["points"]:
+                for p1 in dict_line_points[_line]["points"]:
                     if p1["buttom"]>=margin and p1["point"][0]>_point["point"][0]:
                         next_point = p1
                         break
@@ -1348,7 +1349,12 @@ def sort_object(obj_list, is_reverse=False):
         return obj_list
 
 
-def request_post(url, param, time_out=1000):
+session_ocr = requests.Session()
+session_otr = requests.Session()
+session_all = requests.Session()
+
+
+def request_post(url, param, time_out=1000, use_zlib=False):
     fails = 0
     text = json.dumps([-2])
     while True:
@@ -1357,7 +1363,14 @@ def request_post(url, param, time_out=1000):
                 break
 
             headers = {'content-type': 'application/json'}
-            result = requests.post(url, data=param, timeout=time_out)
+            # result = requests.post(url, data=param, timeout=time_out)
+            if param.get("model_type") == "ocr":
+                result = session_ocr.post(url, data=param, timeout=time_out)
+            elif param.get("model_type") == "otr":
+                result = session_otr.post(url, data=param, timeout=time_out)
+            else:
+                result = session_all.post(url, data=param, timeout=time_out)
+
             # print('result.status_code', result.status_code)
             # print('result.text', result.text)
 
@@ -1444,9 +1457,10 @@ def get_ip_port(node_type=None, interface_type=None):
                 if _port is None:
                     port_list = []
                 else:
-                    port_list = [str(_port)]
                     if interface_type == "convert".upper():
                         processes = params.get(node_type).get(interface_type).get("processes")
+                    port_list = [str(_port)]*int(processes)
+                    # port_list = [str(_port)]
             elif interface_type == "path".upper():
                 python_path = params.get(node_type).get(interface_type).get("python")
                 project_path = params.get(node_type).get(interface_type).get("project")
@@ -1560,6 +1574,7 @@ def get_logger(_name, _dict):
 def set_flask_global():
     # 接口轮询所需锁、参数
     ip_port_flag = {}
+    ip_flag = []
     ip_port_dict = get_ip_port()
     for _k in ip_port_dict.keys():
         ip_port_flag.update({_k: {"ocr": 0,
@@ -1567,8 +1582,10 @@ def set_flask_global():
                                   "convert": 0,
                                   "office": 0
                                   }})
+        ip_flag.append([_k, 0])
     _global.update({"ip_port_flag": ip_port_flag})
     _global.update({"ip_port": ip_port_dict})
+    _global.update({"ip_flag": ip_flag})
     # print(globals().get("ip_port"))
 
 
@@ -1622,6 +1639,6 @@ if __name__ == "__main__":
 
     # print(parse_yaml())
 
-    print(get_ip_port())
+    print(get_ip_port(node_type='slave'))
 
     # print(get_intranet_ip())

+ 231 - 0
ocr/ocr_gpu_interface.py

@@ -0,0 +1,231 @@
+# encoding=utf8
+import base64
+import io
+import json
+import pickle
+import sys
+import os
+import threading
+import zlib
+import numpy as np
+import redis
+
+sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
+import time
+import traceback
+os.environ['FLAGS_eager_delete_tensor_gb'] = '0'
+from format_convert.utils import request_post, test_gpu, get_intranet_ip, log, get_md5_from_bytes
+from flask import Flask, request
+from format_convert import _global
+from ocr.tools.infer import utility
+from ocr.ppocr.utils.logging import get_logger
+
+logger = get_logger()
+
+# 接口配置
+app = Flask(__name__)
+
+
+lock = threading.RLock()
+
+
+redis_db = redis.StrictRedis(host='127.0.0.1', port='6379',
+                             db=1, password='bidi123456', health_check_interval=300)
+
+
# @app.route('/ocr', methods=['POST'])
def _ocr_gpu_flask():
    """Flask handler: run one forward pass of an OCR sub-model on GPU.

    Reads form field "data": a base64-encoded pickle of a dict with keys
    "inputs" (npz-compressed numpy batch), "args" (predictor config),
    "predictor_type" ("det"/"cls"/"rec"), "model_type" and "md5".

    Returns a base64-encoded pickle of
    {"preds": <npz bytes>, "gpu_time": float, "elapse": float};
    "preds" is None on any failure.
    """
    start_time = time.time()
    log("into _ocr")
    _global._init()
    _global.update({"port": globals().get("port")})
    log("into _ocr -> _global " + str(time.time()-start_time))
    start_time = time.time()

    try:
        if not request.form:
            log("ocr no data!")
            return json.dumps({"text": str([-9]), "bbox": str([-9])})
        log("judge request.form " + str(time.time()-start_time))

        start_time1 = time.time()
        # SECURITY NOTE(review): pickle.loads on request data is only safe if
        # this endpoint is reachable exclusively by trusted intranet peers --
        # confirm deployment before exposing it.
        result = pickle.loads(base64.b64decode(request.form.get("data")))

        # Decompress the numpy input batch (shipped as npz bytes).
        inputs = result.get("inputs")
        decompressed_array = io.BytesIO()
        decompressed_array.write(inputs)
        decompressed_array.seek(0)
        inputs = np.load(decompressed_array, allow_pickle=True)['arr_0']
        log("inputs.shape" + str(inputs.shape))

        args = result.get("args")
        predictor_type = result.get("predictor_type")
        model_type = result.get("model_type")
        _md5 = result.get("md5")
        _global.update({"md5": _md5})
        log("read data " + str(time.time()-start_time1))

        # Lazily create and cache one predictor per predictor_type; creation
        # is expensive, so it is reused across requests via globals().
        if globals().get(predictor_type) is None:
            start_time1 = time.time()
            log("=== init " + model_type + " " + predictor_type + " model ===")
            predictor, input_tensor, output_tensors = \
                utility.create_predictor(args, predictor_type, logger)
            globals().update({predictor_type: {"predictor": predictor,
                                               "input_tensor": input_tensor,
                                               "output_tensors": output_tensors}})
            log("=== init " + model_type + " " + predictor_type + " model " + str(round(time.time()-start_time1, 2)) + " ===")
        else:
            predictor = globals().get(predictor_type).get("predictor")
            input_tensor = globals().get(predictor_type).get("input_tensor")
            output_tensors = globals().get(predictor_type).get("output_tensors")

        # The predictor and its tensors are shared between request threads, so
        # the whole copy-in / run / copy-out sequence must be serialized.
        # (Previously only predictor.run() was locked, so a concurrent request
        # could overwrite the input tensor before run() consumed it, or read
        # another request's outputs.)
        with lock:
            input_tensor.copy_from_cpu(inputs)
            start_time1 = time.time()
            predictor.run()
            gpu_time = round(float(time.time()-start_time1), 2)

            # Fetch model outputs while still holding the lock.
            outputs = []
            for output_tensor in output_tensors:
                output = output_tensor.copy_to_cpu()
                outputs.append(output)
            preds = outputs[0]

        # npz-compress the predictions for the response payload.
        compressed_array = io.BytesIO()
        np.savez_compressed(compressed_array, preds)
        compressed_array.seek(0)
        preds = compressed_array.read()

        # Ask the predictor to release intermediate GPU memory.
        predictor.clear_intermediate_tensor()
        predictor.try_shrink_memory()

        finish_time = round(float(time.time()-start_time), 2)
        log("ocr model predict time - " + str(predictor_type) + " - " + str(gpu_time) + " " + str(finish_time))
        return base64.b64encode(pickle.dumps({"preds": preds, "gpu_time": gpu_time, "elapse": finish_time}))
    except Exception:
        finish_time = round(float(time.time()-start_time), 2)
        traceback.print_exc()
        return base64.b64encode(pickle.dumps({"preds": None, "gpu_time": 0., "elapse": finish_time}))
+
+
def _ocr_gpu_redis():
    """Worker loop: consume OCR GPU jobs from the redis list "producer_ocr".

    Each job is a pickled dict with keys "inputs" (numpy batch), "args",
    "uuid", "predictor_type", "model_type" and "md5".  The prediction is
    written back to the redis hash "consumer_ocr" under the job's uuid as a
    pickled {"preds", "gpu_time", "elapse"} dict.  Runs forever; individual
    job failures are logged and the loop continues.
    """
    start_time = time.time()
    log("into _ocr")
    _global._init()
    _global.update({"port": globals().get("port")})
    log("into _ocr -> _global " + str(time.time()-start_time))

    while True:
        start_time = time.time()
        try:
            # Blocking pop with a timeout instead of the original
            # llen() + lpop() polling, which busy-waited at 100% CPU on an
            # empty queue and raced with other consumers between the length
            # check and the pop.
            item = redis_db.blpop("producer_ocr", timeout=1)
            if item is None:
                continue
            # blpop returns a (key, value) pair; the value is the pickled job.
            result = pickle.loads(item[1])
            log("from producer_ocr time " + str(time.time() - start_time))

            _time = time.time()
            inputs = result.get("inputs")

            args = result.get("args")
            _uuid = result.get("uuid")
            predictor_type = result.get("predictor_type")
            model_type = result.get("model_type")
            _md5 = result.get("md5")
            _global.update({"md5": _md5})
            log("read data " + str(time.time()-_time))

            # Lazily create and cache one predictor per predictor_type.
            if globals().get(predictor_type) is None:
                start_time1 = time.time()
                log("=== init " + model_type + " " + predictor_type + " model ===")
                predictor, input_tensor, output_tensors = \
                    utility.create_predictor(args, predictor_type, logger)
                globals().update({predictor_type: {"predictor": predictor,
                                                   "input_tensor": input_tensor,
                                                   "output_tensors": output_tensors}})
                log("=== init " + model_type + " " + predictor_type + " model " + str(round(time.time()-start_time1, 2)) + " ===")
            else:
                predictor = globals().get(predictor_type).get("predictor")
                input_tensor = globals().get(predictor_type).get("input_tensor")
                output_tensors = globals().get(predictor_type).get("output_tensors")

            # Single consumer loop, so no lock is needed around the run.
            input_tensor.copy_from_cpu(inputs)
            start_time1 = time.time()
            predictor.run()
            gpu_time = round(float(time.time()-start_time1), 2)

            # Fetch model outputs.
            _time = time.time()
            outputs = []
            for output_tensor in output_tensors:
                output = output_tensor.copy_to_cpu()
                outputs.append(output)
            preds = outputs[0]
            log("output_tensors " + str(time.time()-_time))

            # Hand the result back to the producer via redis.
            finish_time = round(float(time.time()-start_time), 2)
            _time = time.time()
            redis_db.hset("consumer_ocr", _uuid, pickle.dumps({"preds": preds, "gpu_time": gpu_time, "elapse": finish_time}))
            log("to consumer_ocr " + str(time.time()-_time))

            # Ask the predictor to release intermediate GPU memory.
            predictor.clear_intermediate_tensor()
            predictor.try_shrink_memory()
            log("ocr model predict time - " + str(predictor_type) + " - " + str(gpu_time) + " " + str(finish_time))
        except Exception:
            traceback.print_exc()
+
+
if __name__ == '__main__':
    # CLI: python ocr_gpu_interface.py [port] [gpu_index]
    # Any other argument count falls back to the defaults below.
    if len(sys.argv) == 2:
        port, using_gpu_index = int(sys.argv[1]), 0
    elif len(sys.argv) == 3:
        port, using_gpu_index = int(sys.argv[1]), int(sys.argv[2])
    else:
        port, using_gpu_index = 17000, 0

    # Publish the port both in the shared _global store and in this module.
    _global._init()
    _global.update({"port": str(port)})
    globals().update({"port": str(port)})

    ip = get_intranet_ip()

    # Pin this worker to the requested GPU.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(using_gpu_index)

    # NOTE(review): app.run() is called without host/port, so Flask uses its
    # defaults -- the parsed `port` only reaches the logs / _global store.
    # Presumably intentional (served behind another runner); confirm.
    app.run()

+ 30 - 19
ocr/ocr_interface.py

@@ -13,7 +13,6 @@ import requests
 import logging
 import numpy as np
 os.environ['FLAGS_eager_delete_tensor_gb'] = '0'
-from ocr.paddleocr import PaddleOCR
 from format_convert.utils import request_post, test_gpu, get_intranet_ip, log, get_md5_from_bytes
 from flask import Flask, request
 from format_convert import _global
@@ -33,16 +32,18 @@ def _ocr():
         if not request.form:
             log("ocr no data!")
             return json.dumps({"text": str([-9]), "bbox": str([-9])})
-
+        data = request.form.get("data")
+        img_data = base64.b64decode(data)
+        # _md5 = get_md5_from_bytes(img_data)[0]
+        _md5 = request.form.get("md5")
+        _global.update({"md5": _md5})
+        log("====")
         ocr_model = globals().get("global_ocr_model")
         if ocr_model is None:
+            log("----------- init ocr_model ------------")
             ocr_model = OcrModels().get_model()
             globals().update({"global_ocr_model": ocr_model})
-
-        data = request.form.get("data")
-        img_data = base64.b64decode(data)
-        _md5 = get_md5_from_bytes(img_data)[0]
-        _global.update({"md5": _md5})
+        log("----")
         text = picture2text(img_data, ocr_model)
         return json.dumps(text)
     except TimeoutError:
@@ -108,7 +109,7 @@ def picture2text(img_data, ocr_model):
         # cv2.imshow("bbox", img)
         # cv2.waitKey(0)
 
-        log("ocr model use time: " + str(time.time()-start_time))
+        # log("ocr model use time: " + str(time.time()-start_time))
         return {"text": str(text_list), "bbox": str(bbox_list)}
 
     except TimeoutError:
@@ -141,6 +142,7 @@ def get_best_predict_size(image_np):
 
 class OcrModels:
     def __init__(self):
+        from ocr.paddleocr import PaddleOCR
         try:
             self.ocr_model = PaddleOCR(use_angle_cls=True, lang="ch")
         except:
@@ -151,16 +153,25 @@ class OcrModels:
         return self.ocr_model
 
 
-def test_ocr_model():
+def test_ocr_model(from_remote=True):
     file_path = "C:/Users/Administrator/Desktop/error2.png"
     with open(file_path, "rb") as f:
         file_bytes = f.read()
     file_base64 = base64.b64encode(file_bytes)
-    file_json = {"data": file_base64}
+    _md5 = get_md5_from_bytes(file_bytes)[0]
 
-    # _url = "http://192.168.2.102:17000/ocr"
-    _url = "http://127.0.0.1:17000/ocr"
-    print(json.loads(request_post(_url, file_json)))
+    _global._init()
+    _global.update({"port": 15010, "md5": _md5})
+
+    if from_remote:
+        file_json = {"data": file_base64, "md5": _md5}
+        # _url = "http://192.168.2.102:17000/ocr"
+        _url = "http://127.0.0.1:17000/ocr"
+        print(json.loads(request_post(_url, file_json)))
+    else:
+        ocr_model = OcrModels().get_model()
+        result = ocr(file_base64, ocr_model)
+        print(result)
 
 
 if __name__ == '__main__':
@@ -176,18 +187,18 @@ if __name__ == '__main__':
     _global.update({"port": str(port)})
     globals().update({"port": str(port)})
 
-    ip = get_intranet_ip()
-    logging.basicConfig(level=logging.INFO,
-                        format='%(asctime)s - %(name)s - %(levelname)s - '
-                               + ip + ' - ' + str(port) + ' - %(message)s')
+    # ip = get_intranet_ip()
+    # logging.basicConfig(level=logging.INFO,
+    #                     format='%(asctime)s - %(name)s - %(levelname)s - '
+    #                            + ip + ' - ' + str(port) + ' - %(message)s')
 
     os.environ['CUDA_VISIBLE_DEVICES'] = str(using_gpu_index)
 
     # app.run(host='0.0.0.0', port=port, processes=1, threaded=False, debug=False)
-    app.run(port=port)
+    app.run()
     log("OCR running "+str(port))
 
-    # test_ocr_model()
+    # test_ocr_model(False)
     #
     # log("OCR running")
     # file_path = "C:/Users/Administrator/Desktop/error9.jpg"

+ 2 - 2
ocr/paddleocr.py

@@ -312,8 +312,8 @@ class PaddleOCR(predict_system.TextSystem):
             img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
         if det and rec:
             dt_boxes, rec_res = self.__call__(img)
-            print("paddleocr.py dt_boxes", len(dt_boxes))
-            print("paddleocr.py rec_res", len(rec_res))
+            # print("paddleocr.py dt_boxes", len(dt_boxes))
+            # print("paddleocr.py rec_res", len(rec_res))
             return [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
         elif det and not rec:
             dt_boxes, elapse = self.text_detector(img)

+ 119 - 4
ocr/tools/infer/predict_cls.py

@@ -1,3 +1,4 @@
+# encoding=utf8
 # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -11,15 +12,22 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import io
+import logging
 import os
 import sys
+# __dir__ = os.path.dirname(os.path.abspath(__file__))
+import zlib
 
-__dir__ = os.path.dirname(os.path.abspath(__file__))
-sys.path.append(__dir__)
-sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
+sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../../")
+import requests
+from format_convert import _global
 
-os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+from format_convert.utils import judge_error_code, log
 
+# sys.path.append(__dir__)
+# sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
+os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
 import cv2
 import copy
 import numpy as np
@@ -123,6 +131,113 @@ class TextClassifier(object):
         return img_list, cls_res, elapse
 
 
class TextClassifier2(object):
    """Text-direction classifier whose forward pass runs on a remote GPU
    worker (via ``from_gpu_interface_redis``) instead of a local predictor.

    Call contract matches ``TextClassifier``: takes a list of cropped
    text-line images and returns
    (possibly rotated images, [[label, score], ...], elapse_seconds).
    """
    shrink_memory_count = 0

    def __init__(self, args):
        self.cls_image_shape = [int(v) for v in args.cls_image_shape.split(",")]
        self.cls_batch_num = args.cls_batch_num
        self.cls_thresh = args.cls_thresh
        postprocess_params = {
            'name': 'ClsPostProcess',
            "label_list": args.label_list,
        }
        self.postprocess_op = build_post_process(postprocess_params)
        # Kept so every remote call can ship the predictor config.
        self.args = args

    def resize_norm_img(self, img):
        """Resize ``img`` to the classifier input shape, normalize to
        [-1, 1] and zero-pad on the right up to the fixed model width."""
        imgC, imgH, imgW = self.cls_image_shape
        h = img.shape[0]
        w = img.shape[1]
        ratio = w / float(h)
        # Keep aspect ratio, but never exceed the model's input width.
        if math.ceil(imgH * ratio) > imgW:
            resized_w = imgW
        else:
            resized_w = int(math.ceil(imgH * ratio))
        resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype('float32')
        if self.cls_image_shape[0] == 1:
            resized_image = resized_image / 255
            resized_image = resized_image[np.newaxis, :]
        else:
            resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    def __call__(self, img_list):
        from format_convert.convert_need_interface import from_gpu_interface_redis
        img_list = copy.deepcopy(img_list)
        img_num = len(img_list)
        # Sort by aspect ratio so similarly shaped crops share a batch,
        # which speeds up the cls stage.
        width_list = []
        for img in img_list:
            width_list.append(img.shape[1] / float(img.shape[0]))
        indices = np.argsort(np.array(width_list))

        cls_res = [['', 0.0]] * img_num
        batch_num = self.cls_batch_num
        elapse = 0
        all_gpu_time = 0
        for beg_img_no in range(0, img_num, batch_num):
            # Preprocess one batch.  (The original also computed a
            # max_wh_ratio here, but the cls input width is fixed so the
            # value was never used -- removed as dead code.)
            end_img_no = min(img_num, beg_img_no + batch_num)
            norm_img_batch = []
            for ino in range(beg_img_no, end_img_no):
                norm_img = self.resize_norm_img(img_list[indices[ino]])
                norm_img = norm_img[np.newaxis, :]
                norm_img_batch.append(norm_img)
            norm_img_batch = np.concatenate(norm_img_batch)
            norm_img_batch = norm_img_batch.copy()
            starttime = time.time()

            # Run the forward pass on the remote GPU worker.
            _dict = {"inputs": norm_img_batch, "args": self.args, "md5": _global.get("md5")}
            result = from_gpu_interface_redis(_dict, model_type="ocr", predictor_type="cls")
            if judge_error_code(result):
                logging.error("from_gpu_interface failed! " + str(result))
                raise requests.exceptions.RequestException

            preds = result.get("preds")
            gpu_time = result.get("gpu_time")
            all_gpu_time += round(gpu_time, 2)

            # Post-process: decode labels, and rotate crops classified as
            # upside-down ("180") above the confidence threshold.
            prob_out = preds
            cls_result = self.postprocess_op(prob_out)
            elapse += time.time() - starttime
            for rno in range(len(cls_result)):
                label, score = cls_result[rno]
                cls_res[indices[beg_img_no + rno]] = [label, score]
                if '180' in label and score > self.cls_thresh:
                    img_list[indices[beg_img_no + rno]] = cv2.rotate(
                        img_list[indices[beg_img_no + rno]], 1)

        log("ocr model predict time - cls - time " + str(all_gpu_time) + " - num " + str(img_num))
        return img_list, cls_res, elapse
+
+
 def main(args):
     image_file_list = get_image_file_list(args.image_dir)
     text_classifier = TextClassifier(args)

+ 154 - 4
ocr/tools/infer/predict_det.py

@@ -1,3 +1,4 @@
+# encoding=utf8
 # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -11,15 +12,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import io
+import logging
 import os
 import sys
+# __dir__ = os.path.dirname(os.path.abspath(__file__))
+import zlib
 
-__dir__ = os.path.dirname(os.path.abspath(__file__))
-sys.path.append(__dir__)
-sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
+sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../../")
+import requests
+from format_convert import _global
 
+from format_convert.utils import judge_error_code, log
+# sys.path.append(__dir__)
+# sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
 os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
-
 import cv2
 import numpy as np
 import time
@@ -202,6 +209,149 @@ class TextDetector(object):
         return dt_boxes, elapse
 
 
class TextDetector2(object):
    """Text detector whose DB forward pass is executed by a remote GPU
    worker (``from_gpu_interface_redis``); pre- and post-processing stay
    local.  Call contract matches ``TextDetector``:
    ``__call__(img) -> (dt_boxes, elapse_seconds)``.
    """
    shrink_memory_count = 0

    def __init__(self, args):
        self.args = args
        self.det_algorithm = args.det_algorithm
        pre_process_list = [{
            'DetResizeForTest': None
        }, {
            'NormalizeImage': {
                'std': [0.229, 0.224, 0.225],
                'mean': [0.485, 0.456, 0.406],
                'scale': '1./255.',
                'order': 'hwc'
            }
        }, {
            'ToCHWImage': None
        }, {
            'KeepKeys': {
                'keep_keys': ['image', 'shape']
            }
        }]
        # Only the DB detection algorithm is supported here.
        if self.det_algorithm != "DB":
            logger.info("unknown det_algorithm:{}".format(self.det_algorithm))
            sys.exit(0)
        postprocess_params = {
            'name': 'DBPostProcess',
            "thresh": args.det_db_thresh,
            "box_thresh": args.det_db_box_thresh,
            "max_candidates": 1000,
            "unclip_ratio": args.det_db_unclip_ratio,
            "use_dilation": args.use_dilation,
        }

        self.preprocess_op = create_operators(pre_process_list)
        self.postprocess_op = build_post_process(postprocess_params)

    def order_points_clockwise(self, pts):
        """Return the four points ordered [top-left, top-right,
        bottom-right, bottom-left] (same scheme as imutils'
        perspective ordering)."""
        by_x = pts[np.argsort(pts[:, 0]), :]

        # The two smallest x values form the left pair, the rest the right.
        left_pair = by_x[:2, :]
        right_pair = by_x[2:, :]

        # Within each pair the smaller y is the top point.
        left_pair = left_pair[np.argsort(left_pair[:, 1]), :]
        (top_left, bottom_left) = left_pair

        right_pair = right_pair[np.argsort(right_pair[:, 1]), :]
        (top_right, bottom_right) = right_pair

        return np.array([top_left, top_right, bottom_right, bottom_left],
                        dtype="float32")

    def clip_det_res(self, points, img_height, img_width):
        """Clamp every point into the image bounds, in place; return it."""
        for idx in range(points.shape[0]):
            points[idx, 0] = int(min(max(points[idx, 0], 0), img_width - 1))
            points[idx, 1] = int(min(max(points[idx, 1], 0), img_height - 1))
        return points

    def filter_tag_det_res(self, dt_boxes, image_shape):
        """Order and clip each box; drop boxes whose width or height is
        3 px or less."""
        img_height, img_width = image_shape[0:2]
        kept_boxes = []
        for box in dt_boxes:
            box = self.order_points_clockwise(box)
            box = self.clip_det_res(box, img_height, img_width)
            width = int(np.linalg.norm(box[0] - box[1]))
            height = int(np.linalg.norm(box[0] - box[3]))
            if width > 3 and height > 3:
                kept_boxes.append(box)
        return np.array(kept_boxes)

    def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
        """Clip each box to the image bounds without reordering/filtering."""
        img_height, img_width = image_shape[0:2]
        clipped = [self.clip_det_res(box, img_height, img_width)
                   for box in dt_boxes]
        return np.array(clipped)

    def __call__(self, img):
        from format_convert.convert_need_interface import from_gpu_interface_redis
        # Local preprocessing: resize / normalize / CHW / keep-keys.
        src_img = img.copy()
        data = transform({'image': img}, self.preprocess_op)
        img, shape_list = data
        if img is None:
            return None, 0
        img = np.expand_dims(img, axis=0)
        shape_list = np.expand_dims(shape_list, axis=0)
        img = img.copy()
        starttime = time.time()

        # Run the DB forward pass on the remote GPU worker.
        _dict = {"inputs": img, "args": self.args, "md5": _global.get("md5")}
        result = from_gpu_interface_redis(_dict, model_type="ocr", predictor_type="det")
        if judge_error_code(result):
            logging.error("from_gpu_interface failed! " + str(result))
            raise requests.exceptions.RequestException

        _preds = result.get("preds")
        gpu_time = result.get("gpu_time")

        # Local postprocessing: decode boxes from the probability map.
        if self.det_algorithm != 'DB':
            raise NotImplementedError
        post_result = self.postprocess_op({'maps': _preds}, shape_list)
        dt_boxes = post_result[0]['points']
        dt_boxes = self.filter_tag_det_res(dt_boxes, src_img.shape)
        elapse = time.time() - starttime

        log("ocr model predict time - det - time " + str(gpu_time))
        return dt_boxes, elapse
+
+
 if __name__ == "__main__":
     args = utility.parse_args()
     image_file_list = get_image_file_list(args.image_dir)

+ 122 - 5
ocr/tools/infer/predict_rec.py

@@ -1,3 +1,4 @@
+# encoding=utf8
 # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -11,15 +12,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import io
+import logging
 import os
 import sys
+# __dir__ = os.path.dirname(os.path.abspath(__file__))
+import zlib
 
-__dir__ = os.path.dirname(os.path.abspath(__file__))
-sys.path.append(__dir__)
-sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
-
+import requests
+# sys.path.append(__dir__)
+# sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
+sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../../")
 os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
-
 import cv2
 import numpy as np
 import math
@@ -33,6 +37,10 @@ from ocr.ppocr.postprocess import build_post_process
 from ocr.ppocr.utils.logging import get_logger
 from ocr.ppocr.utils.utility import get_image_file_list, check_and_read_gif
 
+from format_convert.utils import judge_error_code, log
+from format_convert import _global
+
+
 logger = get_logger()
 
 
@@ -234,7 +242,9 @@ class TextRecognizer(object):
             else:
                 starttime = time.time()
                 self.input_tensor.copy_from_cpu(norm_img_batch)
+                start_time = time.time()
                 self.predictor.run()
+                logging.info("ocr model predict time - rec" + str(time.time()-start_time))
 
                 outputs = []
                 for output_tensor in self.output_tensors:
@@ -256,6 +266,113 @@ class TextRecognizer(object):
         return rec_res, elapse
 
 
class TextRecognizer2(object):
    """Text recognizer that delegates the forward pass to a remote GPU worker
    (via ``from_gpu_interface_redis``) instead of a local Paddle predictor.

    Pre-processing (resize + normalize) and post-processing (CTC decode)
    run locally; only the batched tensor is shipped to the GPU service.
    """

    # Counter reserved for memory-shrink bookkeeping; not used in this class body.
    shrink_memory_count = 0

    def __init__(self, args):
        # rec_image_shape is a "C,H,W" string from the CLI args, e.g. "3,32,320".
        self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
        self.character_type = args.rec_char_type
        self.rec_batch_num = args.rec_batch_num
        self.rec_algorithm = args.rec_algorithm
        postprocess_params = {
            'name': 'CTCLabelDecode',
            "character_type": args.rec_char_type,
            "character_dict_path": args.rec_char_dict_path,
            "use_space_char": args.use_space_char
        }
        self.postprocess_op = build_post_process(postprocess_params)
        self.args = args
        # Local predictor deliberately not created: inference happens remotely.
        # self.predictor, self.input_tensor, self.output_tensors = \
        #     utility.create_predictor(args, 'rec', logger)

    def resize_norm_img(self, img, max_wh_ratio):
        """Resize *img* to the model height and normalize to [-1, 1],
        right-padding with zeros up to the batch-wide width.

        NOTE(review): assumes img is HxWxC with C == rec_image_shape[0] — the
        assert below enforces the channel count only.
        """
        imgC, imgH, imgW = self.rec_image_shape
        assert imgC == img.shape[2]
        if self.character_type == "ch":
            imgW = int((32 * max_wh_ratio))
        h, w = img.shape[:2]
        ratio = w / float(h)
        if math.ceil(imgH * ratio) > imgW:
            resized_w = imgW
        else:
            resized_w = int(math.ceil(imgH * ratio))
        # print("predict_rec.py resize_norm_img resize shape", (resized_w, imgH))
        resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype('float32')
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    def __call__(self, img_list):
        """Recognize text for every crop in *img_list*.

        Returns (rec_res, elapse): results in the original input order, and
        the locally measured elapsed seconds. Raises RequestException when
        the remote GPU interface reports an error code.
        """
        from format_convert.convert_need_interface import from_gpu_interface_redis
        img_num = len(img_list)
        # Calculate the aspect ratio of all text bars
        width_list = []
        for img in img_list:
            width_list.append(img.shape[1] / float(img.shape[0]))

        # Sorting can speed up the recognition process
        indices = np.argsort(np.array(width_list))

        rec_res = [['', 0.0]] * img_num
        batch_num = self.rec_batch_num
        elapse = 0
        all_gpu_time = 0
        for beg_img_no in range(0, img_num, batch_num):
            # Preprocess: find the widest aspect ratio in the batch, then
            # resize/normalize every crop to that shared width.
            end_img_no = min(img_num, beg_img_no + batch_num)
            norm_img_batch = []
            max_wh_ratio = 0
            for ino in range(beg_img_no, end_img_no):
                h, w = img_list[indices[ino]].shape[0:2]
                wh_ratio = w * 1.0 / h
                max_wh_ratio = max(max_wh_ratio, wh_ratio)
            for ino in range(beg_img_no, end_img_no):
                norm_img = self.resize_norm_img(img_list[indices[ino]],
                                                max_wh_ratio)
                norm_img = norm_img[np.newaxis, :]
                norm_img_batch.append(norm_img)
            norm_img_batch = np.concatenate(norm_img_batch)
            norm_img_batch = norm_img_batch.copy()
            starttime = time.time()

            # # Compress the numpy batch before transport (disabled)
            # compressed_array = io.BytesIO()
            # np.savez_compressed(compressed_array, norm_img_batch)
            # compressed_array.seek(0)
            # norm_img_batch = compressed_array.read()

            # Call the remote GPU interface
            _dict = {"inputs": norm_img_batch, "args": self.args, "md5": _global.get("md5")}
            result = from_gpu_interface_redis(_dict, model_type="ocr", predictor_type="rec")
            if judge_error_code(result):
                logging.error("from_gpu_interface failed! " + str(result))
                raise requests.exceptions.RequestException

            preds = result.get("preds")
            gpu_time = result.get("gpu_time")
            all_gpu_time += round(gpu_time, 2)

            # # Decompress the returned numpy array (disabled)
            # decompressed_array = io.BytesIO()
            # decompressed_array.write(preds)
            # decompressed_array.seek(0)
            # preds = np.load(decompressed_array, allow_pickle=True)['arr_0']
            # log("inputs.shape" + str(preds.shape))

            # Postprocess: CTC decode, then write results back in original order
            rec_result = self.postprocess_op(preds)
            for rno in range(len(rec_result)):
                rec_res[indices[beg_img_no + rno]] = rec_result[rno]
            elapse += time.time() - starttime

        log("ocr model predict time - rec - time " + str(all_gpu_time) + " - num " + str(img_num))
        return rec_res, elapse
+
+
 def main(args):
     image_file_list = get_image_file_list(args.image_dir)
     text_recognizer = TextRecognizer(args)

+ 1 - 3
ocr/tools/infer/utility.py

@@ -17,7 +17,7 @@ import logging
 import os
 import sys
 import time
-
+sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../../")
 import cv2
 import numpy as np
 import json
@@ -150,7 +150,6 @@ def create_predictor(args, mode, logger):
     # config.switch_use_feed_fetch_ops(False)
 
     # create predictor
-    start_time = time.time()
     predictor = inference.create_predictor(config)
 
     input_names = predictor.get_input_names()
@@ -161,7 +160,6 @@ def create_predictor(args, mode, logger):
     for output_name in output_names:
         output_tensor = predictor.get_output_handle(output_name)
         output_tensors.append(output_tensor)
-    logging.info("ocr model predict time " + str(time.time()-start_time))
 
     return predictor, input_tensor, output_tensors
 

+ 214 - 0
otr/otr_gpu_interface.py

@@ -0,0 +1,214 @@
+# encoding=utf8
+import base64
+import io
+import json
+import os
+import pickle
+import threading
+import traceback
+# os.environ['TF_XLA_FLAGS'] = '--tf_xla_cpu_global_jit'
+# os.environ['CUDA_VISIBLE_DEVICES'] = "0"
+import redis
+import tensorflow as tf
+try:
+    gpus = tf.config.list_physical_devices('GPU')
+    if len(gpus) > 0:
+        tf.config.experimental.set_virtual_device_configuration(
+            gpus[0],
+            [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=2000)])
+except:
+    traceback.print_exc()
+    pass
+import sys
+sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
+import time
+import logging
+# from table_line import *
+import cv2
+import numpy as np
+from flask import Flask, request
+from format_convert.utils import request_post, judge_error_code, get_intranet_ip, log, get_md5_from_bytes, get_platform
+from otr.table_line import table_net, table_line, table_preprocess, table_postprocess
+from format_convert import _global
+
+
# Interface configuration: Flask app exposing the OTR GPU worker
app = Flask(__name__)


# Serializes model.predict calls across threads (model is shared module-wide)
lock = threading.RLock()


# NOTE(review): host/port/credentials are hard-coded; consider moving to config
redis_db = redis.StrictRedis(host='127.0.0.1', port='6379',
                             db=1, password='bidi123456', health_check_interval=300)
+
+
# @app.route('/otr', methods=['POST'])
def _otr_flask():
    """HTTP variant of the OTR GPU worker (route decorator currently disabled).

    Expects form field "data" = base64(pickle(dict)) with a compressed numpy
    batch under "inputs"; runs the table-line model and returns
    base64(pickle({"preds", "gpu_time", "elapse"})). On any failure, returns
    the same envelope with preds=None instead of raising.
    """
    start_time = time.time()
    log("into _otr")
    _global._init()
    _global.update({"port": globals().get("port")})
    log("into _otr -> _global " + str(time.time()-start_time))
    start_time = time.time()

    try:
        if not request.form:
            log("otr no data!")
            return json.dumps({"list_line": str([-9])})
        log("judge request.form " + str(time.time()-start_time))

        start_time1 = time.time()
        # Deserialize the request payload
        # NOTE(review): pickle.loads on request data is unsafe for untrusted callers
        result = pickle.loads(base64.b64decode(request.form.get("data")))

        inputs = result.get("inputs")
        # Decompress the numpy array (written via np.savez_compressed)
        decompressed_array = io.BytesIO()
        decompressed_array.write(inputs)
        decompressed_array.seek(0)
        inputs = np.load(decompressed_array, allow_pickle=True)['arr_0']
        log("inputs.shape" + str(inputs.shape))

        predictor_type = result.get("predictor_type")
        model_type = result.get("model_type")
        _md5 = result.get("md5")
        _global.update({"md5": _md5})
        log("read data " + str(time.time()-start_time1))

        # Fetch the model, lazily created once and cached in module globals
        model = globals().get(model_type)
        if model is None:
            start_time1 = time.time()
            log("=== init " + model_type + " model ===")
            model = OtrModels().get_model()
            globals().update({model_type: model})
            log("=== init " + model_type + " model " + str(round(time.time()-start_time1, 2)) + " ===")

        # Run inference under the lock (model instance shared across threads)
        with lock:
            start_time1 = time.time()
            pred = model.predict(inputs)
            pred = pred[0]
            log("pred.shape " + str(pred.shape))

        # Compress the numpy prediction for transport
        compressed_array = io.BytesIO()
        np.savez_compressed(compressed_array, pred)
        compressed_array.seek(0)
        pred = compressed_array.read()

        gpu_time = round(float(time.time()-start_time1), 2)

        finish_time = round(float(time.time()-start_time), 2)
        log("otr model predict time " + str(gpu_time) + " " + str(finish_time))
        return base64.b64encode(pickle.dumps({"preds": pred, "gpu_time": gpu_time, "elapse": finish_time}))
    except Exception as e:
        finish_time = round(float(time.time()-start_time), 2)
        traceback.print_exc()
        return base64.b64encode(pickle.dumps({"preds": None, "gpu_time": 0., "elapse": finish_time}))
+
+
def _otr():
    """Redis-queue worker loop for the OTR table-line model.

    Pops pickled jobs from the "producer_otr" list, runs the (lazily
    initialized, module-global) model on the job's inputs, and writes the
    pickled prediction into the "consumer_otr" hash keyed by the job uuid.
    Runs forever; per-job failures are logged and the loop continues.

    Fix: the original empty-queue branch did `continue` immediately, so the
    loop busy-spun against redis at 100% CPU whenever there was no work.
    A short sleep is added as back-off; job handling is otherwise unchanged.
    """
    start_time = time.time()
    log("into _otr")
    _global._init()
    _global.update({"port": globals().get("port")})
    log("into _otr -> _global " + str(time.time()-start_time))

    while True:
        start_time = time.time()
        try:
            if redis_db.llen("producer_otr") == 0:
                # Queue empty: back off briefly instead of busy-polling redis.
                time.sleep(0.05)
                continue
            log("judge llen " + str(time.time()-start_time))

            _time = time.time()
            result = redis_db.lpop("producer_otr")
            # Another consumer may have raced us between llen and lpop.
            if result is None:
                continue
            result = pickle.loads(result)
            log("from producer_otr time " + str(time.time() - _time))

            _time = time.time()
            inputs = result.get("inputs")
            # # Decompress numpy (disabled — inputs now arrive uncompressed)
            # decompressed_array = io.BytesIO()
            # decompressed_array.write(inputs)
            # decompressed_array.seek(0)
            # inputs = np.load(decompressed_array, allow_pickle=True)['arr_0']
            # log("inputs.shape " + str(inputs.shape))
            # log("numpy decompress " + str(time.time()-_time))

            predictor_type = result.get("predictor_type")
            _uuid = result.get("uuid")
            model_type = result.get("model_type")
            _md5 = result.get("md5")
            _global.update({"md5": _md5})
            log("read data " + str(time.time()-_time))

            # Fetch the model, lazily created once and cached in module globals
            model = globals().get(model_type)
            if model is None:
                start_time1 = time.time()
                log("=== init " + model_type + " model ===")
                model = OtrModels().get_model()
                globals().update({model_type: model})
                log("=== init " + model_type + " model " + str(round(time.time()-start_time1, 2)) + " ===")

            # Run inference
            start_time1 = time.time()
            pred = model.predict(inputs)
            pred = pred[0]
            log("pred.shape " + str(pred.shape))

            # # Compress numpy (disabled — result is written uncompressed)
            # _time = time.time()
            # compressed_array = io.BytesIO()
            # np.savez_compressed(compressed_array, pred)
            # compressed_array.seek(0)
            # pred = compressed_array.read()
            # log("numpy compress " + str(time.time()-_time))

            # Write the result back to redis for the producer to pick up
            gpu_time = round(float(time.time()-start_time1), 2)
            finish_time = round(float(time.time()-start_time), 2)
            redis_db.hset("consumer_otr", _uuid, pickle.dumps({"preds": pred, "gpu_time": gpu_time, "elapse": finish_time}))
            log("to consumer_otr " + str(time.time()-_time))

            log("otr model predict time " + str(gpu_time) + " " + str(finish_time))
        except Exception as e:
            # Best-effort worker: log and keep consuming the queue.
            traceback.print_exc()
+
+
class OtrModels:
    """Loads and holds the OTR table-line detection model."""

    def __init__(self):
        # Directory containing this python file
        module_dir = os.path.abspath(os.path.dirname(os.path.abspath(__file__)))
        weights_file = module_dir + "/models/table-line.h5"
        model = table_net((None, None, 3), 2)
        model.load_weights(weights_file)
        self.otr_model = model

    def get_model(self):
        """Return the loaded table-line model instance."""
        return self.otr_model
+
+
if __name__ == '__main__':
    # Command line: [port] [gpu_index]; defaults are 18000 / GPU 0.
    # NOTE(review): using_gpu_index is parsed but not referenced below — confirm intent.
    if len(sys.argv) == 2:
        port = int(sys.argv[1])
        using_gpu_index = 0
    elif len(sys.argv) == 3:
        port = int(sys.argv[1])
        using_gpu_index = int(sys.argv[2])
    else:
        port = 18000
        using_gpu_index = 0

    # Flask entrypoint disabled; the redis-queue worker loop is used instead.
    # _global._init()
    # _global.update({"port": str(port)})
    # globals().update({"port": str(port)})

    # app.run(host='0.0.0.0', port=port, processes=1, threaded=False, debug=False)
    # app.run()
    # log("OTR running "+str(port))
    _otr()

+ 77 - 26
otr/otr_interface.py

@@ -2,25 +2,39 @@ import base64
 import json
 import multiprocessing as mp
 import os
+import traceback
 # os.environ['TF_XLA_FLAGS'] = '--tf_xla_cpu_global_jit'
+# os.environ['CUDA_VISIBLE_DEVICES'] = "0"
+import tensorflow as tf
+try:
+    gpus = tf.config.list_physical_devices('GPU')
+    if len(gpus) > 0:
+        tf.config.experimental.set_virtual_device_configuration(
+            gpus[0],
+            [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=2048)])
+except:
+    traceback.print_exc()
+    pass
+# for gpu in gpus:  # 如果使用多块GPU时
+#     tf.config.experimental.set_memory_growth(gpu, True)
+
+# os.environ['CUDA_CACHE_MAXSIZE'] = str(2147483648)
+# os.environ['CUDA_CACHE_DISABLE'] = str(0)
+# gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.6)
+# sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
 import sys
 sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
 import time
-import traceback
-from multiprocessing.context import Process
-import multiprocessing
 import logging
 # from table_line import *
 import cv2
 import numpy as np
-import tensorflow as tf
 from flask import Flask, request
-
-from format_convert.utils import request_post, judge_error_code, get_intranet_ip, log, get_md5_from_bytes
-from otr.table_line import get_best_predict_size, table_line, get_points, get_split_line, get_points_row, \
+from format_convert.utils import request_post, judge_error_code, get_intranet_ip, log, get_md5_from_bytes, get_platform
+from otr.table_line import get_points, get_split_line, get_points_row, \
     get_points_col, \
-    delete_close_points, fix_outline, get_bbox, get_outline_point, table_net, delete_contain_bbox, points_to_line, \
-    fix_inner, merge_line, fix_corner, add_continue_bbox, delete_outline
+    delete_close_points, fix_outline, get_bbox, get_outline_point, delete_contain_bbox, points_to_line, \
+    fix_inner, merge_line, fix_corner, add_continue_bbox, delete_outline, table_net, table_line
 from format_convert import _global
 
 
@@ -47,7 +61,8 @@ def _otr():
         data = request.form.get("data")
         is_from_pdf = request.form.get("is_from_pdf")
         img_data = base64.b64decode(data)
-        _md5 = get_md5_from_bytes(img_data)[0]
+        # _md5 = get_md5_from_bytes(img_data)[0]
+        _md5 = request.form.get("md5")
         _global.update({"md5": _md5})
         if is_from_pdf:
             list_lines = line_detect(img_data, otr_model, prob=0.2)
@@ -346,15 +361,35 @@ class OtrModels:
         return self.otr_model
 
 
-def test_otr_model():
-    file_path = "C:/Users/Administrator/Desktop/error2.png"
def test_otr_model(from_remote=True):
    """Manual smoke test for the OTR service.

    Resizes a local image to the model's preferred size, then either POSTs it
    to the /otr HTTP endpoint (from_remote=True) or runs the local
    image_process pipeline on it. Prints the result; returns nothing.
    NOTE(review): overwrites the input file on disk with the resized image.
    """
    _global._init()
    from format_convert.convert_image import get_best_predict_size, image_process
    if get_platform() == "Windows":
        file_path = "C:/Users/Administrator/Desktop/error2.png"
        # Second assignment wins; the first path is kept only for quick toggling.
        file_path = "C:/Users/Administrator/Downloads/1652672734044.jpg"
    else:
        file_path = "1.jpg"
    image_np = cv2.imread(file_path)
    best_h, best_w = get_best_predict_size(image_np)
    image_resize = cv2.resize(image_np, (best_w, best_h), interpolation=cv2.INTER_AREA)
    cv2.imwrite(file_path, image_resize)
    with open(file_path, "rb") as f:
        file_bytes = f.read()
    file_base64 = base64.b64encode(file_bytes)
    _md5 = get_md5_from_bytes(file_bytes)[0]

    _global.update({"port": 15010, "md5": _md5})

    if from_remote:
        file_json = {"data": file_base64, "is_from_pdf": False, "md5": _md5}

        # _url = "http://192.168.2.104:18000/otr"
        _url = "http://127.0.0.1:18000/otr"
        r = json.loads(request_post(_url, file_json))
    else:
        # otr_model = OtrModels().get_model()
        # r = otr(file_base64, otr_model, is_from_pdf=False)
        r = image_process(image_resize, file_path)
    print(r)
 
 
@@ -378,20 +413,36 @@ if __name__ == '__main__':
     # logging.basicConfig(level=logging.INFO,
     #                     format='%(asctime)s - %(name)s - %(levelname)s - '
     #                            + ip + ' - ' + str(port) + ' - %(message)s')
-
+    logging.info(get_platform())
     # 限制tensorflow显存
-    memory_limit_scale = 0.3
-    os.environ['CUDA_VISIBLE_DEVICES'] = str(using_gpu_index)
-    os.environ['CUDA_CACHE_MAXSIZE'] = str(2147483648)
-    os.environ['CUDA_CACHE_DISABLE'] = str(0)
-    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=memory_limit_scale)
-    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
-
-    app.run(host='0.0.0.0', port=port, processes=1, threaded=False, debug=False)
-    # app.run(port=port)
+    # os.environ['CUDA_VISIBLE_DEVICES'] = str(using_gpu_index)
+    # import tensorflow as tf
+    # if get_platform() != "Windows":
+    #     _version = tf.__version__
+    #     logging.info(str(_version))
+    #     memory_limit_scale = 0.3
+    #     # tensorflow 1.x
+    #     if str(_version)[0] == "1":
+    #         logging.info("1.x " + str(_version))
+    #         os.environ['CUDA_CACHE_MAXSIZE'] = str(2147483648)
+    #         os.environ['CUDA_CACHE_DISABLE'] = str(0)
+    #         gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=memory_limit_scale)
+    #         sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
+    #
+    #     # tensorflow 2.x
+    #     elif str(_version)[0] == "2":
+    #         logging.info("2.x " + str(_version))
+            # config = tf.compat.v1.ConfigProto()
+            # config.gpu_options.per_process_gpu_memory_fraction = memory_limit_scale
+            # config.gpu_options.allow_growth = True
+            # sess = tf.compat.v1.Session(config=config)
+
+
+    # app.run(host='0.0.0.0', port=port, processes=1, threaded=False, debug=False)
+    app.run()
     log("OTR running "+str(port))
 
-    # test_otr_model()
+    # test_otr_model(False)
 
     # print(json.dumps([-2]))
 

+ 196 - 26
otr/table_line.py

@@ -6,7 +6,11 @@ table line detect
 @author: chineseocr
 """
 import copy
+import io
 import logging
+import sys
+import traceback
+
 import tensorflow as tf
 import tensorflow.keras.backend as K
 from tensorflow.keras.models import Model
@@ -16,7 +20,8 @@ from otr.utils import letterbox_image, get_table_line, adjust_lines, line_to_lin
 import numpy as np
 import cv2
 import time
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+from format_convert import _global
+from format_convert.utils import log
 
 
 def dice_coef(y_true, y_pred, smooth=1e-5):
@@ -239,7 +244,7 @@ def points2lines(pred, sourceP_LB=True, prob=0.2, line_width=7, padding=3, min_l
     _last = False
     _current = False
     while 1:
-        h_index += 5
+        h_index += 2
         if h_index>=height:
             break
         w_index = -1
@@ -248,7 +253,7 @@ def points2lines(pred, sourceP_LB=True, prob=0.2, line_width=7, padding=3, min_l
         else:
             h_i = h_index
         while 1:
-            w_index += 5
+            w_index += 2
             if w_index>=width:
                 break
             if _sum[w_index]<min_len:
@@ -274,12 +279,12 @@ def points2lines(pred, sourceP_LB=True, prob=0.2, line_width=7, padding=3, min_l
     w_index = -1
     _sum = list(np.sum(np.array((pred[...,0]>prob)).astype(int),axis=1))
     while 1:
-        w_index += 5
+        w_index += 2
         if w_index>=width:
             break
         h_index = -1
         while 1:
-            h_index += 5
+            h_index += 2
             if h_index>=height:
                 break
             if _sum[h_index]<min_len:
@@ -355,7 +360,7 @@ def points2lines(pred, sourceP_LB=True, prob=0.2, line_width=7, padding=3, min_l
     for _line in final_horizontal:
         list_line.append(_line)
 
-    logging.info("points2lines cost %.2fs"%(time.time()-_time))
+    log("points2lines cost %.2fs"%(time.time()-_time))
 
     # import matplotlib.pyplot as plt
     # plt.figure()
@@ -443,14 +448,179 @@ def get_line_from_binary_image(image_np, point_value=1, axis=0):
     return axis_lines
 
 
def table_preprocess(img_data, prob=0.2):
    """Decode raw image bytes into a model-ready batch.

    Returns (image_np, inputs): the decoded channel-swapped image and the
    same image wrapped in a batch dimension. On any failure both return
    values are the error marker [-1].
    Note: prob is only logged here, not used for thresholding.
    """
    try:
        log("into table_preprocess, prob is " + str(prob))
        begin = time.time()

        # Binary stream -> np.ndarray of np.uint8 (8-bit pixels)
        decoded = cv2.imdecode(np.frombuffer(img_data, np.uint8), cv2.IMREAD_COLOR)

        # Swap BGR channel order to RGB
        rgb_image = cv2.cvtColor(decoded, cv2.COLOR_RGB2BGR)

        # Model input: single-image batch
        batch = np.array([rgb_image])

        log("otr preprocess time: " + str(round(float(time.time()-begin), 4)) + "s")
        return rgb_image, batch
    except Exception:
        log("table pre process failed!")
        traceback.print_exc()
        return [-1], [-1]
+
+
def table_postprocess(img_new, pred, prob=0.2, is_test=0):
    """Convert the model's pixel-probability map into cleaned table line
    segments [x1, y1, x2, y2].

    Pipeline: points2lines -> drop short lines -> drop lines with no
    intersections -> split rows/cols -> merge staggered lines -> compute
    cross points and table split lines -> repair outline/corners/inner
    lines -> final merge.

    :param img_new: resized input image (shape drives length thresholds)
    :param pred: model probability map; NOTE(review): presumably H x W x 2
                 (rows / cols channels) — confirm against the model output
    :param prob: pixel probability threshold passed to points2lines
    :param is_test: when truthy, mat_plot/cv_plot show intermediate results
    :return: list of segments, [] when no cross points, or [-1] on error
    """
    try:
        # Horizontal-line prediction result (debug dump, disabled)
        # row_pred = pred[..., 0] > hprob
        # row_pred = row_pred.astype(np.uint8)
        # # Vertical-line prediction result
        # col_pred = pred[..., 1] > vprob
        # col_pred = col_pred.astype(np.uint8)
        # # Show the raw model output
        # cv2.imshow("predict", (col_pred+row_pred)*255)
        # cv2.waitKey(0)

        start_time = time.time()
        list_line = points2lines(pred, False, prob=prob)
        mat_plot(list_line, "points2lines", is_test)
        log("points2lines " + str(time.time()-start_time))

        # Remove short lines
        # print(img_new.shape)
        start_time = time.time()
        list_line = delete_short_lines(list_line, img_new.shape)
        mat_plot(list_line, "delete_short_lines", is_test)
        log("delete_short_lines " + str(time.time()-start_time))

        # Remove lines that cross no other line
        start_time = time.time()
        list_line = delete_no_cross_lines(list_line)
        mat_plot(list_line, "delete_no_cross_lines", is_test)
        log("delete_no_cross_lines " + str(time.time()-start_time))

        # Split into horizontal (rows) and vertical (cols) lines
        start_time = time.time()
        list_rows = []
        list_cols = []
        for line in list_line:
            if line[0] == line[2]:
                list_cols.append(line)
            elif line[1] == line[3]:
                list_rows.append(line)
        log("divide rows and cols " + str(time.time()-start_time))

        # Merge staggered (misaligned) lines
        start_time = time.time()
        list_rows = merge_line(list_rows, axis=0)
        list_cols = merge_line(list_cols, axis=1)
        mat_plot(list_rows+list_cols, "merge_line", is_test)
        log("merge_line " + str(time.time()-start_time))

        # Compute intersections and table split lines
        start_time = time.time()
        cross_points = get_points(list_rows, list_cols, (img_new.shape[0], img_new.shape[1]))
        if not cross_points:
            return []
        log("get_points " + str(time.time()-start_time))

        # Remove useless outer lines (disabled)
        # list_rows, list_cols = delete_outline(list_rows, list_cols, cross_points)
        # mat_plot(list_rows+list_cols, "delete_outline", is_test)

        # Split lines separating multiple tables
        start_time = time.time()
        list_rows, list_cols = fix_in_split_lines(list_rows, list_cols, img_new)
        split_lines, split_y = get_split_line(cross_points, list_cols, img_new)
        log("get_split_line " + str(time.time()-start_time))

        # Repair the table border
        start_time = time.time()
        new_rows, new_cols, long_rows, long_cols = fix_outline(img_new, list_rows, list_cols, cross_points,
                                                               split_y)

        # If any patch lines were added
        if new_rows or new_cols:
            # Use the extended versions of existing lines
            if long_rows:
                list_rows = long_rows
            if long_cols:
                list_cols = long_cols
            # Append the newly added patch lines
            if new_rows:
                list_rows += new_rows
            if new_cols:
                list_cols += new_cols

            list_rows, list_cols = fix_in_split_lines(list_rows, list_cols, img_new)

            # Recompute intersections and split lines after the border repair
            cross_points = get_points(list_rows, list_cols, (img_new.shape[0], img_new.shape[1]))
            cv_plot(cross_points, img_new.shape, 0, is_test)

            split_lines, split_y = get_split_line(cross_points, list_cols, img_new)
            print("fix new split_y", split_y)
            print("fix new split_lines", split_lines)

            # Repair missing inner lines here (disabled; done once below)
            # cross_points = fix_inner(list_rows, list_cols, cross_points, split_y)
            # if not cross_points:
            #     return []
        mat_plot(list_rows+list_cols, "fix_outline", is_test)

        split_lines_show = []
        for _l in split_lines:
            split_lines_show.append([_l[0][0], _l[0][1], _l[1][0], _l[1][1]])
        mat_plot(split_lines_show+list_cols,
                 "split_lines", is_test)
        log("fix_outline " + str(time.time()-start_time))

        # Repair the 4 table corners
        start_time = time.time()
        list_rows, list_cols = fix_corner(list_rows, list_cols, split_y, threshold=0)
        mat_plot(list_rows+list_cols, "fix_corner", is_test)
        log("fix_corner " + str(time.time()-start_time))

        # Repair missing inner lines
        start_time = time.time()
        list_rows, list_cols = fix_inner(list_rows, list_cols, cross_points, split_y)
        mat_plot(list_rows+list_cols, "fix_inner", is_test)
        log("fix_inner " + str(time.time()-start_time))

        # Merge staggered lines once more after the repairs
        start_time = time.time()
        list_rows = merge_line(list_rows, axis=0)
        list_cols = merge_line(list_cols, axis=1)
        mat_plot(list_rows+list_cols, "merge_line", is_test)
        log("merge_line " + str(time.time()-start_time))

        list_line = list_rows + list_cols

        # Plot the final processed lines
        mat_plot(list_line, "all", is_test)
        log("otr postprocess table_line " + str(time.time()-start_time))
        return list_line
    except Exception as e:
        log("table post process failed!")
        traceback.print_exc()
        return [-1]
+
+
 def table_line(img, model, size=(512, 1024), prob=0.2, is_test=0):
-    logging.info("into table_line, prob is " + str(prob))
+    log("into table_line, prob is " + str(prob))
     sizew, sizeh = size
     img_new = cv2.resize(img, (sizew, sizeh), interpolation=cv2.INTER_AREA)
 
     start_time = time.time()
     pred = model.predict(np.array([img_new]))
-    logging.info("otr model predict time " + str(time.time()-start_time))
+    log("otr model predict time " + str(time.time()-start_time))
     pred = pred[0]
 
     draw_pixel(pred, prob, is_test)
@@ -468,20 +638,20 @@ def table_line(img, model, size=(512, 1024), prob=0.2, is_test=0):
     start_time = time.time()
     list_line = points2lines(pred, False, prob=prob)
     mat_plot(list_line, "points2lines", is_test)
-    logging.info("points2lines " + str(time.time()-start_time))
+    log("points2lines " + str(time.time()-start_time))
 
     # 清除短线
     # print(img_new.shape)
     start_time = time.time()
     list_line = delete_short_lines(list_line, img_new.shape)
     mat_plot(list_line, "delete_short_lines", is_test)
-    logging.info("delete_short_lines " + str(time.time()-start_time))
+    log("delete_short_lines " + str(time.time()-start_time))
 
     # 清除无交点线
     start_time = time.time()
     list_line = delete_no_cross_lines(list_line)
     mat_plot(list_line, "delete_no_cross_lines", is_test)
-    logging.info("delete_no_cross_lines " + str(time.time()-start_time))
+    log("delete_no_cross_lines " + str(time.time()-start_time))
 
     # 分成横竖线
     start_time = time.time()
@@ -492,21 +662,21 @@ def table_line(img, model, size=(512, 1024), prob=0.2, is_test=0):
             list_cols.append(line)
         elif line[1] == line[3]:
             list_rows.append(line)
-    logging.info("divide rows and cols " + str(time.time()-start_time))
+    log("divide rows and cols " + str(time.time()-start_time))
 
     # 合并错开线
     start_time = time.time()
     list_rows = merge_line(list_rows, axis=0)
     list_cols = merge_line(list_cols, axis=1)
     mat_plot(list_rows+list_cols, "merge_line", is_test)
-    logging.info("merge_line " + str(time.time()-start_time))
+    log("merge_line " + str(time.time()-start_time))
 
     # 计算交点、分割线
     start_time = time.time()
     cross_points = get_points(list_rows, list_cols, (img_new.shape[0], img_new.shape[1]))
     if not cross_points:
         return []
-    logging.info("get_points " + str(time.time()-start_time))
+    log("get_points " + str(time.time()-start_time))
 
     # 清掉外围的没用的线
     # list_rows, list_cols = delete_outline(list_rows, list_cols, cross_points)
@@ -516,7 +686,7 @@ def table_line(img, model, size=(512, 1024), prob=0.2, is_test=0):
     start_time = time.time()
     list_rows, list_cols = fix_in_split_lines(list_rows, list_cols, img_new)
     split_lines, split_y = get_split_line(cross_points, list_cols, img_new)
-    logging.info("get_split_line " + str(time.time()-start_time))
+    log("get_split_line " + str(time.time()-start_time))
 
     # 修复边框
     start_time = time.time()
@@ -557,32 +727,32 @@ def table_line(img, model, size=(512, 1024), prob=0.2, is_test=0):
         split_lines_show.append([_l[0][0], _l[0][1], _l[1][0], _l[1][1]])
     mat_plot(split_lines_show+list_cols,
              "split_lines", is_test)
-    logging.info("fix_outline " + str(time.time()-start_time))
+    log("fix_outline " + str(time.time()-start_time))
 
     # 修复表格4个角
     start_time = time.time()
     list_rows, list_cols = fix_corner(list_rows, list_cols, split_y, threshold=0)
     mat_plot(list_rows+list_cols, "fix_corner", is_test)
-    logging.info("fix_corner " + str(time.time()-start_time))
+    log("fix_corner " + str(time.time()-start_time))
 
     # 修复内部缺线
     start_time = time.time()
     list_rows, list_cols = fix_inner(list_rows, list_cols, cross_points, split_y)
     mat_plot(list_rows+list_cols, "fix_inner", is_test)
-    logging.info("fix_inner " + str(time.time()-start_time))
+    log("fix_inner " + str(time.time()-start_time))
 
     # 合并错开线
     start_time = time.time()
     list_rows = merge_line(list_rows, axis=0)
     list_cols = merge_line(list_cols, axis=1)
     mat_plot(list_rows+list_cols, "merge_line", is_test)
-    logging.info("merge_line " + str(time.time()-start_time))
+    log("merge_line " + str(time.time()-start_time))
 
     list_line = list_rows + list_cols
 
     # 打印处理后线
     mat_plot(list_line, "all", is_test)
-    logging.info("otr postprocess table_line " + str(time.time()-start_time))
+    log("otr postprocess table_line " + str(time.time()-start_time))
     return list_line
 
 
@@ -593,16 +763,16 @@ def table_line2(img, model, size=(512, 1024), hprob=0.5, vprob=0.5, row=50, col=
     # pred = model.predict(np.array([np.array(inputBlob)]))
     # pred = model.predict(np.array([np.array(inputBlob)/255.0]))
     img_new = cv2.resize(img, (sizew, sizeh), interpolation=cv2.INTER_AREA)
-    # logging.info("into table_line 1")
+    # log("into table_line 1")
 
     pred = model.predict(np.array([img_new]))
-    # logging.info("into table_line 2")
+    # log("into table_line 2")
     pred = pred[0]
 
     draw_pixel(pred)
     _time = time.time()
     points2lines(pred)
-    logging.info("points2lines takes %ds"%(time.time()-_time))
+    log("points2lines takes %ds"%(time.time()-_time))
 
     vpred = pred[..., 1] > vprob   # 横线
     hpred = pred[..., 0] > hprob   # 竖线
@@ -614,7 +784,7 @@ def table_line2(img, model, size=(512, 1024), hprob=0.5, vprob=0.5, row=50, col=
 
     colboxes = get_table_line(vpred, axis=1, lineW=col)
     rowboxes = get_table_line(hpred, axis=0, lineW=row)
-    # logging.info("into table_line 3")
+    # log("into table_line 3")
 
     # if len(rowboxes) > 0:
     #     rowboxes = np.array(rowboxes)
@@ -633,7 +803,7 @@ def table_line2(img, model, size=(512, 1024), hprob=0.5, vprob=0.5, row=50, col=
         for j in range(ncol):
             rowboxes[i] = line_to_line(rowboxes[i], colboxes[j], 10)
             colboxes[j] = line_to_line(colboxes[j], rowboxes[i], 10)
-    # logging.info("into table_line 4")
+    # log("into table_line 4")
 
     # 删掉贴着边框的line
     temp_list = []
@@ -1810,7 +1980,7 @@ def fix_outline2(image, row_lines, col_lines, points, split_y):
 
 
 def fix_outline(image, row_lines, col_lines, points, split_y, scale=25):
-    logging.info("into fix_outline")
+    log("into fix_outline")
     x_min_len = max(10, int(image.shape[0] / scale))
     y_min_len = max(10, int(image.shape[1] / scale))
     # print("x_min_len", x_min_len, "y_min_len", y_min_len)