Ver Fonte

印章去除作为外部接口

fangjiasheng há 2 anos atrás
pai
commit
a45a053359

+ 34 - 23
format_convert/convert_image.py

@@ -7,19 +7,17 @@ import sys
 import time
 import requests
 import numpy as np
-from ocr.paddleocr import PaddleOCR
-
+from PIL import Image
 sys.path.append(os.path.dirname(__file__) + "/../")
 from pdfminer.layout import LTLine
 import traceback
 import cv2
-from format_convert import get_memory_info, _global
+from isr.pre_process import count_red_pixel
 from format_convert.utils import judge_error_code, add_div, LineTable, get_table_html, get_logger, log, \
     memory_decorator, pil_resize
-from format_convert.table_correct import get_rotated_image
 from format_convert.convert_need_interface import from_otr_interface, from_ocr_interface, from_gpu_interface_redis, \
     from_idc_interface, from_isr_interface
-from otr.table_line import table_preprocess, table_postprocess
+from format_convert.table_correct import get_rotated_image
 
 
 def image_process(image_np, image_path, is_from_pdf=False, is_from_docx=False, use_ocr=True):
@@ -102,30 +100,43 @@ def image_process(image_np, image_path, is_from_pdf=False, is_from_docx=False, u
         # if image_np is None:
         #     return []
         #
-        # with open(image_path, "rb") as f:
-        #     image_bytes = f.read()
-        #
         # # idc模型实现图片倾斜校正
-        # image_np = from_idc_interface(image_bytes)
-        # if judge_error_code(image_np):
+        # image_resize = pil_resize(image_np, 640, 640)
+        # image_resize_path = image_path.split(".")[0] + "_resize_idc." + image_path.split(".")[-1]
+        # cv2.imwrite(image_resize_path, image_resize)
+        #
+        # with open(image_resize_path, "rb") as f:
+        #     image_bytes = f.read()
+        # angle = from_idc_interface(image_bytes)
+        # if judge_error_code(angle):
         #     if is_from_docx:
         #         return []
         #     else:
-        #         return image_np
+        #         return angle
+        # # 根据角度旋转
+        # image_pil = Image.fromarray(image_np)
+        # image_np = np.array(image_pil.rotate(angle, expand=1))
+        # # 写入
         # idc_path = image_path.split(".")[0] + "_idc." + image_path.split(".")[-1]
         # cv2.imwrite(idc_path, image_np)
-        # with open(idc_path, "rb") as f:
-        #     image_bytes = f.read()
-        #
-        # # isr模型去除印章
-        # image_np = from_isr_interface(image_bytes)
-        # if judge_error_code(image_np):
-        #     if is_from_docx:
-        #         return []
-        #     else:
-        #         return image_np
-        # isr_path = image_path.split(".")[0] + "_isr." + image_path.split(".")[-1]
-        # cv2.imwrite(isr_path, image_np)
+
+        # isr模型去除印章
+        if count_red_pixel(image_np):
+            # 红色像素达到一定值才过模型
+            with open(image_path, "rb") as f:
+                image_bytes = f.read()
+            image_np = from_isr_interface(image_bytes)
+            if judge_error_code(image_np):
+                if is_from_docx:
+                    return []
+                else:
+                    return image_np
+            # [1]代表检测不到印章,直接返回
+            if isinstance(image_np, list) and image_np == [1]:
+                log("no seals detected!")
+            else:
+                isr_path = image_path.split(".")[0] + "_isr." + image_path.split(".")[-1]
+                cv2.imwrite(isr_path, image_np)
 
         # otr模型识别表格,需要图片resize成模型所需大小, 写入另一个路径
         best_h, best_w = get_best_predict_size(image_np)

+ 86 - 10
format_convert/convert_need_interface.py

@@ -12,6 +12,8 @@ import time
 import uuid
 import zlib
 from queue import Queue
+
+import cv2
 import redis
 from werkzeug.exceptions import NotFound
 
@@ -489,7 +491,7 @@ def from_otr_interface(image_stream, is_from_pdf=False, from_remote=FROM_REMOTE)
         return [-1]
 
 
-def from_isr_interface(image_stream, from_remote=False):
+def from_isr_interface(image_stream, from_remote=FROM_REMOTE):
     log("into from_isr_interface")
     start_time = time.time()
     try:
@@ -535,7 +537,20 @@ def from_isr_interface(image_stream, from_remote=False):
             return [-2]
 
         _dict = r
-        image_np = _dict.get("image")
+        if from_remote:
+            image_string = _dict.get("image")
+            if judge_error_code(image_string):
+                return image_string
+            # [1]代表检测不到印章,直接返回
+            if isinstance(image_string, list) and image_string == [1]:
+                return image_string
+
+            image_base64 = image_string.encode("utf-8")
+            image_bytes = base64.b64decode(image_base64)
+            buffer = np.frombuffer(image_bytes, dtype=np.uint8)
+            image_np = cv2.imdecode(buffer, 1)
+        else:
+            image_np = _dict.get("image")
         log("from_isr_interface cost time " + str(time.time()-start_time))
         return image_np
     except Exception as e:
@@ -544,7 +559,7 @@ def from_isr_interface(image_stream, from_remote=False):
         return [-11]
 
 
-def from_idc_interface(image_stream, from_remote=False):
+def from_idc_interface(image_stream, from_remote=FROM_REMOTE):
     log("into from_idc_interface")
     start_time = time.time()
     try:
@@ -556,10 +571,10 @@ def from_idc_interface(image_stream, from_remote=False):
                 retry_times_1 = 3
                 # 重试
                 while retry_times_1:
-                    ip_port = interface_pool_gunicorn("isr")
+                    ip_port = interface_pool_gunicorn("idc")
                     if judge_error_code(ip_port):
                         return ip_port
-                    _url = ip_port + "/isr"
+                    _url = ip_port + "/idc"
                     r = json.loads(request_post(_url, {"data": base64_stream,
                                                        "md5": _global.get("md5")},
                                                 time_out=60))
@@ -570,7 +585,7 @@ def from_idc_interface(image_stream, from_remote=False):
                             return r
                         else:
                             retry_times_1 -= 1
-                            log("retry post isr_interface... left times " + str(retry_times_1))
+                            log("retry post idc_interface... left times " + str(retry_times_1))
                             continue
                     if judge_error_code(r):
                         return r
@@ -588,9 +603,9 @@ def from_idc_interface(image_stream, from_remote=False):
             return [-2]
 
         _dict = r
-        image_np = _dict.get("image")
+        angle = _dict.get("angle")
         log("from_idc_interface cost time " + str(time.time()-start_time))
-        return image_np
+        return angle
     except Exception as e:
         log("from_idc_interface error!")
         traceback.print_exc()
@@ -670,6 +685,59 @@ def interface_pool(interface_type, use_gunicorn=True):
 
 
 def interface_pool_gunicorn(interface_type):
+    ip_port_flag_dict = _global.get("ip_port_flag")
+    ip_port_dict = _global.get("ip_port")
+    try:
+        if ip_port_dict is None or ip_port_flag_dict is None:
+            raise NotFound
+
+        # 负载均衡, 选取有该接口的ip
+        min_cnt = 10000.
+        interface_cnt = 0
+        _ip = None
+        port_list = []
+        for key in ip_port_flag_dict.keys():
+            temp_port_list = get_args_from_config(ip_port_dict, key, interface_type)
+            if not temp_port_list:
+                continue
+            interface_cnt = ip_port_flag_dict.get(key).get(interface_type)
+            if interface_cnt is not None and interface_cnt / len(temp_port_list[0]) < min_cnt:
+                _ip = key
+                min_cnt = interface_cnt / len(temp_port_list[0])
+                port_list = temp_port_list[0]
+
+        # 选取端口
+        if interface_type == "office":
+            # 刚开始随机,后续求余
+            if min_cnt == 0:
+                _port = port_list[random.randint(0, len(port_list)-1)]
+                ip_port_flag_dict[_ip][interface_type] = int(_port[-2:])
+            else:
+                _port = port_list[interface_cnt % len(port_list)]
+        else:
+            # 使用gunicorn则直接选第一个
+            _port = port_list[0]
+
+        # 更新flag
+        if ip_port_flag_dict.get(_ip).get(interface_type) >= 10000:
+            ip_port_flag_dict[_ip][interface_type] = 0
+        else:
+            ip_port_flag_dict[_ip][interface_type] += 1
+        _global.update({"ip_port_flag": ip_port_flag_dict})
+
+        ip_port = _ip + ":" + str(_port)
+        log(interface_type)
+        log(ip_port)
+        return ip_port
+    except NotFound:
+        log("ip_flag or ip_port_dict is None! checkout config")
+        return [-2]
+    except:
+        traceback.print_exc()
+        return [-1]
+
+
+def interface_pool_gunicorn_old(interface_type):
     ip_flag_list = _global.get("ip_flag")
     ip_port_flag_dict = _global.get("ip_port_flag")
     ip_port_dict = _global.get("ip_port")
@@ -715,7 +783,15 @@ def interface_pool_gunicorn(interface_type):
             log("port_list" + str(port_list))
             if port_index >= len(port_list):
                 port_index = 0
-            _port = get_args_from_config(ip_port_dict, _ip, interface_type)[port_index][0]
+            _port = port_list[port_index][0]
+
+            # # 选取端口, 使用gunicorn则直接选第一个
+            # _ip = _ip.split("_")[0]
+            # port_list = get_args_from_config(ip_port_dict, _ip, interface_type)
+            # if
+            # print(port_list)
+            # _port = port_list[0][0]
+
             # 更新flag
             if ip_flag_list[ip_index][1] >= 10000:
                 ip_flag_list[ip_index][1] = 0
@@ -904,4 +980,4 @@ if __name__ == "__main__":
     set_flask_global()
 
     for i in range(10):
-        print("result", interface_pool_gunicorn("ocr"))
+        print("result", interface_pool_gunicorn("otr"))

+ 2 - 2
format_convert/convert_test.py

@@ -57,8 +57,8 @@ if __name__ == '__main__':
         # file_path = "C:/Users/Administrator/Desktop/test_xls/merge_cell.xlsx"
         # file_path = "D:/BIDI_DOC/比地_文档/2022/Test_Interface/20210609202634853485.xlsx"
         # file_path = "D:/BIDI_DOC/比地_文档/2022/Test_ODPS/1624325845476.pdf"
-        file_path = "C:/Users/Administrator/Downloads/1660296734009.pdf"
-        # file_path = "C:/Users/Administrator/Desktop/test_image/error10.png"
+        # file_path = "C:/Users/Administrator/Downloads/1660296734009.pdf"
+        file_path = "C:/Users/Administrator/Desktop/test_image/error3.png"
     else:
         file_path = "1660296734009.pdf"
     test_one(file_path, from_remote=True)

+ 0 - 5
format_convert/kill_all.py

@@ -4,15 +4,10 @@ import sys
 import psutil
 sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
 from format_convert.utils import get_ip_port, get_intranet_ip, get_using_ip, get_args_from_config
-
 import time
 
 ip_port_dict = get_ip_port()
-# ip = "http://127.0.0.1"
 ip = get_using_ip()
-# python_path = ip_port_dict.get(ip).get("MASTER").get("python_path")
-# gunicorn_path = ip_port_dict.get(ip).get("MASTER").get("gunicorn_path")
-# project_path = ip_port_dict.get(ip).get("MASTER").get("project_path")
 python_path = get_args_from_config(ip_port_dict, ip, "python_path")[0]
 project_path = get_args_from_config(ip_port_dict, ip, "project_path")[0]
 gunicorn_path = get_args_from_config(ip_port_dict, ip, "gunicorn_path")[0]

+ 32 - 44
format_convert/monitor_process_config.py

@@ -8,56 +8,19 @@ import psutil
 sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
 from format_convert.utils import get_ip_port, get_intranet_ip, get_args_from_config, get_all_ip, get_using_ip
 
-# ip_port_dict = get_ip_port()
-# ip = "http://" + get_intranet_ip()
-# ip = "http://127.0.0.1"
-# convert_port_list = ip_port_dict.get(ip).get("convert")
-# ocr_port_list = ip_port_dict.get(ip).get("ocr")
-# otr_port_list = ip_port_dict.get(ip).get("otr")
-# soffice_port_list = ip_port_dict.get(ip).get("office")
-# if not convert_port_list:
-#     convert_port_list = []
-# if not ocr_port_list:
-#     ocr_port_list = []
-# if not otr_port_list:
-#     otr_port_list = []
-# if not soffice_port_list:
-#     soffice_port_list = []
-#
-# # schedule_port_list = ip_port_dict.get(ip).get("schedule")
-# # python_path = ip_port_dict.get(ip).get("python_path")
-# # project_path = ip_port_dict.get(ip).get("project_path")
-#
-# interface_path = project_path[:-1]
-# std_out = " >>/convert.out 2>&1 &"
-# std_out_gpu = " >>/gpu.out 2>&1 &"
-# std_out_schedule = " >>/schedule.out 2>&1 &"
-# # convert_comm = "nohup " + python_path + " " + interface_path + "/format_convert/convert.py #" + std_out
-# # ocr_comm = "nohup " + python_path + " " + interface_path + "/ocr/ocr_gpu_interface.py # 0" + std_out_gpu
-# # otr_comm = "nohup " + python_path + " " + interface_path + "/otr/otr_gpu_interface.py # 0" + std_out_gpu
-# schedule_comm = "nohup " + python_path + " " + interface_path + "/format_convert/schedule_interface.py #" + std_out_schedule
-# soffice_comm = "docker run --init -itd --log-opt max-size=10m --log-opt max-file=3 -p #:16000 soffice:v2 bash"
-#
-#
-# gunicorn_path = python_path
-# # print("convert_port_list", len(convert_port_list))
-# convert_comm = "nohup " + gunicorn_path + " -w " + str(len(convert_port_list)) + " -t 300 -b 0.0.0.0:# --chdir " \
-#                + interface_path + "/format_convert convert:app" + std_out
-# ocr_comm = "nohup " + gunicorn_path + " -w " + str(len(ocr_port_list)) + " -t 300 --keep-alive 600 -b 0.0.0.0:# --chdir " \
-#            + interface_path + "/ocr ocr_gpu_interface:app" + std_out_gpu
-# otr_comm = "nohup " + gunicorn_path + " -w " + str(len(otr_port_list)) + " -t 300 --keep-alive 600 -b 0.0.0.0:# --chdir " \
-#            + interface_path + "/otr otr_gpu_interface:app" + std_out_gpu
-
 
 ip_port_dict = get_ip_port()
 ip = get_using_ip()
 print("local ip:", ip)
 
+# 获取各个参数
 convert_port_list = get_args_from_config(ip_port_dict, ip, "convert", "MASTER")
 if convert_port_list:
     convert_port_list = convert_port_list[0]
 ocr_port_list = get_args_from_config(ip_port_dict, ip, "ocr")
 otr_port_list = get_args_from_config(ip_port_dict, ip, "otr")
+idc_port_list = get_args_from_config(ip_port_dict, ip, "idc")
+isr_port_list = get_args_from_config(ip_port_dict, ip, "isr")
 soffice_port_list = get_args_from_config(ip_port_dict, ip, "office", "MASTER")
 if soffice_port_list:
     soffice_port_list = soffice_port_list[0]
@@ -71,11 +34,15 @@ std_out_schedule = " >>/schedule.out 2>&1 &"
 print("convert_port_list", convert_port_list)
 print("ocr_port_list", ocr_port_list)
 print("otr_port_list", otr_port_list)
+print("idc_port_list", idc_port_list)
+print("isr_port_list", isr_port_list)
 print("soffice_port_list", soffice_port_list)
 
-
+# 根据port生成gunicorn语句
 ocr_comm_list = []
 otr_comm_list = []
+isr_comm_list = []
+idc_comm_list = []
 for i in range(len(ocr_port_list)):
     ocr_comm_list.append("nohup " + gunicorn_path_list[i] + " -w " + str(len(ocr_port_list[i]))
                          + " -t 300 --keep-alive 600 -b 0.0.0.0:# --chdir "
@@ -84,9 +51,16 @@ for i in range(len(otr_port_list)):
     otr_comm_list.append("nohup " + gunicorn_path_list[i] + " -w " + str(len(otr_port_list[i]))
                          + " -t 300 --keep-alive 600 -b 0.0.0.0:# --chdir "
                          + project_path_list[i] + "/otr otr_interface:app" + std_out_gpu)
+for i in range(len(idc_port_list)):
+    idc_comm_list.append("nohup " + gunicorn_path_list[i] + " -w " + str(len(idc_port_list[i]))
+                         + " -t 300 --keep-alive 600 -b 0.0.0.0:# --chdir "
+                         + project_path_list[i] + "/idc idc_interface:app" + std_out_gpu)
+for i in range(len(isr_port_list)):
+    isr_comm_list.append("nohup " + gunicorn_path_list[i] + " -w " + str(len(isr_port_list[i]))
+                         + " -t 300 --keep-alive 600 -b 0.0.0.0:# --chdir "
+                         + project_path_list[i] + "/isr isr_interface:app" + std_out_gpu)
 convert_comm = "nohup " + gunicorn_path_list[0] + " -w " + str(len(convert_port_list)) + " -t 300 -b 0.0.0.0:# --chdir " \
                + project_path_list[0] + "/format_convert convert:app" + std_out
-schedule_comm = "nohup " + python_path_list[0] + " " + project_path_list[0] + "/format_convert/schedule_interface.py #" + std_out_schedule
 soffice_comm = "docker run --init -itd --log-opt max-size=10m --log-opt max-file=3 -p #:16000 soffice:v2 bash"
 
 
@@ -109,8 +83,10 @@ def restart(process_type, port, index=0):
         _comm = re.sub("#", port, otr_comm_list[index])
     elif process_type == "soffice":
         _comm = re.sub("#", port, soffice_comm)
-    elif process_type == "schedule":
-        _comm = re.sub("#", port, schedule_comm)
+    elif process_type == "idc":
+        _comm = re.sub("#", port, idc_comm_list[index])
+    elif process_type == "isr":
+        _comm = re.sub("#", port, isr_comm_list[index])
     else:
         _comm = "netstat -nltp"
         print("no process_type", process_type)
@@ -195,6 +171,18 @@ def monitor():
                 if p not in current_port_list:
                     restart("otr", p, index=j)
 
+    if idc_port_list:
+        for j in range(len(idc_port_list)):
+            for p in idc_port_list[j][:1]:
+                if p not in current_port_list:
+                    restart("idc", p, index=j)
+
+    if isr_port_list:
+        for j in range(len(isr_port_list)):
+            for p in isr_port_list[j][:1]:
+                if p not in current_port_list:
+                    restart("isr", p, index=j)
+
     if soffice_port_list:
         for p in soffice_port_list:
             if p not in current_port_list:

+ 46 - 25
format_convert/utils.py

@@ -51,7 +51,7 @@ def judge_error_code(_list, code=[0, -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -1
     [-9]: 接口接收数据为空
     """
     for c in code:
-        if _list == [c]:
+        if isinstance(_list, list) and _list == [c]:
             return True
     return False
 
@@ -1625,7 +1625,7 @@ def get_ip_port(node_type=None, interface_type=None):
         node_type_list = [node_type]
 
     if interface_type is None:
-        interface_type_list = ["convert", "ocr", "otr", "office", "path"]
+        interface_type_list = ["convert", "ocr", "otr", "office", "path", "isr", "idc"]
     else:
         interface_type_list = [interface_type]
 
@@ -1652,16 +1652,16 @@ def get_ip_port(node_type=None, interface_type=None):
                 processes = 0
                 port_list = []
                 interface_type = type2.upper()
-                if interface_type in ["convert".upper()]:
-                    _port = params.get(node_type).get(interface_type).get("port")
-                    if _port is None:
-                        port_list = []
-                    else:
-                        if interface_type == "convert".upper():
-                            processes = params.get(node_type).get(interface_type).get("processes")[j]
-                        port_list = [str(_port[j])]*int(processes)
-                        # port_list = [str(_port)]
-                elif interface_type == "path".upper():
+                # if interface_type in ["convert".upper()]:
+                #     _port = params.get(node_type).get(interface_type).get("port")
+                #     if _port is None:
+                #         port_list = []
+                #     else:
+                #         if interface_type == "convert".upper():
+                #             processes = params.get(node_type).get(interface_type).get("processes")[j]
+                #         port_list = [str(_port[j])]*int(processes)
+                #         # port_list = [str(_port)]
+                if interface_type == "path".upper():
                     python_path = params.get(node_type).get(interface_type).get("python")[j]
                     project_path = params.get(node_type).get(interface_type).get("project")[j]
                     gunicorn_path = params.get(node_type).get(interface_type).get("gunicorn")[j]
@@ -1671,7 +1671,10 @@ def get_ip_port(node_type=None, interface_type=None):
                     if port_start is None or port_no is None:
                         port_list = []
                     else:
-                        port_list = [str(x) for x in range(port_start[j], port_start[j]+port_no[j], 1)]
+                        if interface_type in ["office".upper()]:
+                            port_list = [str(x) for x in range(port_start[j], port_start[j]+port_no[j], 1)]
+                        else:
+                            port_list = [str(port_start[j])] * port_no[j]
                 # if ip_list:
                 #     for i in range(len(ip_list)):
 
@@ -1852,21 +1855,31 @@ def get_logger(_name, _dict):
 def set_flask_global():
     # 接口轮询所需锁、参数
     ip_port_flag = {}
-    ip_flag = []
+    # ip_flag = []
     ip_port_dict = get_ip_port()
     for _k in ip_port_dict.keys():
-        ip_port_flag.update({_k: {"ocr": 0,
-                                  "otr": 0,
-                                  "convert": 0,
-                                  "office": 0
-                                  }})
-        if ip_port_dict.get(_k).get("MASTER"):
-            ip_flag.append([_k+"_master", 0])
-        if ip_port_dict.get(_k).get("SLAVE"):
-            ip_flag.append([_k+"_slave", 0])
+        ip_port_flag.update({_k: {}})
+        for interface in ["ocr", "otr", "convert", "idc", "isr", "office"]:
+            if ip_port_dict.get(_k).get("MASTER"):
+                if ip_port_dict.get(_k).get("MASTER").get(interface):
+                    ip_port_flag[_k][interface] = 0
+            else:
+                if ip_port_dict.get(_k).get("SLAVE").get(interface):
+                    ip_port_flag[_k][interface] = 0
+        # ip_port_flag.update({_k: {"ocr": 0,
+        #                           "otr": 0,
+        #                           "convert": 0,
+        #                           "idc": 0,
+        #                           "isr": 0,
+        #                           "office": 0
+        #                           }})
+        # if ip_port_dict.get(_k).get("MASTER"):
+        #     ip_flag.append([_k+"_master", 0])
+        # if ip_port_dict.get(_k).get("SLAVE"):
+        #     ip_flag.append([_k+"_slave", 0])
     _global.update({"ip_port_flag": ip_port_flag})
     _global.update({"ip_port": ip_port_dict})
-    _global.update({"ip_flag": ip_flag})
+    # _global.update({"ip_flag": ip_flag})
     # print(globals().get("ip_port"))
 
 
@@ -2101,6 +2114,14 @@ def bytes2np(_b):
         return None
 
 
+def np2bytes(image_np):
+    # Encode the numpy image as JPEG with cv2.imencode (the success flag is not checked)
+    success, img_encode = cv2.imencode(".jpg", image_np)
+    # encoded numpy buffer -> raw bytes
+    img_bytes = img_encode.tobytes()
+    return img_bytes
+
+
 if __name__ == "__main__":
     # strs = r"D:\Project\temp\04384fcc9e8911ecbd2844f971944973\043876ca9e8911eca5e144f971944973_rar\1624114035529.jpeg"
     # print(slash_replace(strs))
@@ -2132,7 +2153,7 @@ if __name__ == "__main__":
     print(get_ip_port())
     # set_flask_global()
     # print(get_all_ip())
-    print(get_args_from_config(get_ip_port(), get_all_ip()[0], "gunicorn_path"))
+    print(get_args_from_config(get_ip_port(), get_all_ip()[0], "convert"))
     # print(get_args_from_config(get_ip_port(), "http://127.0.0.1", "gunicorn_path"))
     # print(get_intranet_ip())
     # _path = "C:/Users/Administrator/Downloads/3.png"

+ 97 - 43
idc/idc_interface.py

@@ -5,58 +5,107 @@ import os
 import time
 import sys
 import traceback
+from glob import glob
+os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
 sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
 from format_convert import _global
 import cv2
 import numpy as np
 from PIL import Image
 from idc.model import direction_model
-from format_convert.utils import log, get_md5_from_bytes, request_post, np2pil, bytes2np, pil2np, pil_resize
+from format_convert.utils import log, get_md5_from_bytes, request_post, np2pil, bytes2np, pil2np, pil_resize, np2bytes
 import tensorflow as tf
+from flask import Flask, request
+
+tf.compat.v1.disable_eager_execution()
 sess = tf.compat.v1.Session(graph=tf.Graph())
 
+
 image_shape = (640, 640)
 
 
-def adjust_direction(image_np, model):
+def adjust_direction(image_np, model, if_return_angle=False):
     # 4个方向
     cls_num = 4
 
     # 构建数据
     origin_image = copy.deepcopy(image_np)
-    image_np = pil_resize(image_np, image_shape[0], image_shape[1])
+    # image_np = pil_resize(image_np, image_shape[0], image_shape[1])
     X = np.expand_dims(np.array(image_np), 0)
 
     # 预测
     with sess.as_default():
         with sess.graph.as_default():
-            pred = model.predict(X)
+            pred = model.predict(X, batch_size=1)
             pred = pred.astype(np.float64)
             pred = np.argmax(pred[0])
 
     # 根据分类计算角度
-    angle = 360 - pred*int((360/cls_num))
+    angle = int(360 - pred*int((360/cls_num)))
 
-    # 根据角度旋转
-    image_pil = Image.fromarray(origin_image)
-    image_rotate = np.array(image_pil.rotate(angle, expand=1))
-    return image_rotate
+    if if_return_angle:
+        return angle
+    else:
+        # 根据角度旋转
+        image_pil = Image.fromarray(origin_image)
+        image_rotate = np.array(image_pil.rotate(angle, expand=1))
+        return image_rotate
 
 
 def idc(data, model):
-    log("into idc_interface isr")
+    log("into idc_interface idc")
     try:
         # start_time = time.time()
         img_data = base64.b64decode(data)
         img_np = bytes2np(img_data)
-        image_rotate = adjust_direction(img_np, model)
+        angle = adjust_direction(img_np, model, if_return_angle=True)
         # print(time.time()-start_time)
-        return {"image": image_rotate}
+        return {"angle": angle}
     except TimeoutError:
-        return {"image": [-5]}
+        return {"angle": [-5]}
     except:
         traceback.print_exc()
-        return {"image": [-1]}
+        return {"angle": [-1]}
+
+
+# 接口配置
+app = Flask(__name__)
+
+
+@app.route('/idc', methods=['POST'])
+def _idc():
+    _global._init()
+    _global.update({"port": globals().get("port")})
+    start_time = time.time()
+
+    log("into idc_interface _idc")
+    try:
+        if not request.form:
+            log("idc no data!")
+            return json.dumps({"text": str([-9]), "bbox": str([-9])})
+        data = request.form.get("data")
+        log("idc_interface get data time" + str(time.time()-start_time))
+
+        img_data = base64.b64decode(data)
+        img_np = bytes2np(img_data)
+        _md5 = request.form.get("md5")
+        _global.update({"md5": _md5})
+
+        idc_model = globals().get("global_idc_model")
+        if idc_model is None:
+            print("=========== init idc model ===========")
+            idc_model = IdcModels().get_model()
+            globals().update({"global_idc_model": idc_model})
+
+        angle = adjust_direction(img_np, idc_model, if_return_angle=True)
+        return json.dumps({"angle": angle})
+    except TimeoutError:
+        return json.dumps({"angle": str([-5])})
+    except:
+        traceback.print_exc()
+        return json.dumps({"angle": str([-1])})
+    finally:
+        log("idc interface finish time " + str(time.time()-start_time))
 
 
 class IdcModels:
@@ -77,37 +126,42 @@ class IdcModels:
 
 
 def test_idc_model(from_remote=False):
-    file_path = "C:/Users/Administrator/Desktop/test_image/error10.jpg"
-    with open(file_path, "rb") as f:
-        file_bytes = f.read()
-    img_np = cv2.imread(file_path)
-    print(img_np.shape)
-    file_base64 = base64.b64encode(file_bytes)
-    _md5 = get_md5_from_bytes(file_bytes)[0]
-
-    _global._init()
-    _global.update({"port": 15010, "md5": _md5})
-
-    if from_remote:
-        file_json = {"data": file_base64, "md5": _md5}
-        # _url = "http://192.168.2.102:17000/ocr"
-        _url = "http://127.0.0.1:17000/ocr"
-        print(json.loads(request_post(_url, file_json)))
-    else:
-
-        idc_model = IdcModels().get_model()
-        result = idc(file_base64, idc_model)
-        # print(result)
-        if type(result.get("image")) == list:
-            print(result)
+    idc_model = IdcModels().get_model()
+    paths = glob("C:/Users/Administrator/Desktop/test_image/*")
+    # file_path = "C:/Users/Administrator/Desktop/test_image/error10.jpg"
+    for file_path in paths:
+        img_np = cv2.imread(file_path)
+        img_np = pil_resize(img_np, 640, 640)
+        print(img_np.shape)
+        file_bytes = np2bytes(img_np)
+        file_base64 = base64.b64encode(file_bytes)
+        _md5 = get_md5_from_bytes(file_bytes)[0]
+
+        _global._init()
+        _global.update({"port": 15010, "md5": _md5})
+
+        if from_remote:
+            file_json = {"data": file_base64, "md5": _md5}
+            # _url = "http://192.168.2.102:17000/ocr"
+            _url = "http://127.0.0.1:17000/ocr"
+            print(json.loads(request_post(_url, file_json)))
         else:
-            img = result.get("image")
-            print(img.shape)
-            cv2.namedWindow('img', cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)
-            cv2.imshow("img", img)
-            cv2.waitKey(0)
-        # print(result)
+
+            result = idc(file_base64, idc_model)
+            # print(result)
+            if type(result.get("angle")) == list:
+                print(result)
+            else:
+                angle = result.get("angle")
+                img = Image.fromarray(img_np)
+                img = np.array(img.rotate(angle, expand=1))
+                print(img.shape)
+                cv2.namedWindow('img', cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)
+                cv2.imshow("img", img)
+                cv2.waitKey(0)
+            # print(result)
 
 
 if __name__ == "__main__":
     test_idc_model()
+

+ 93 - 10
isr/isr_interface.py

@@ -4,6 +4,7 @@ import os
 import time
 import sys
 import traceback
+os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
 sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
 from format_convert import _global
 import cv2
@@ -15,6 +16,9 @@ from isr.model import get_tiny_inference_model, seal_model, seal_model_se
 from isr.pre_process import count_red_pixel, get_anchors, get_classes, get_colors
 from isr.utils import get_best_predict_size, pil_resize, letterbox_image, draw_boxes, adjust_boxes
 import tensorflow as tf
+from flask import Flask, request
+
+tf.compat.v1.disable_eager_execution()
 sess1 = tf.compat.v1.Session(graph=tf.Graph())
 sess2 = tf.compat.v1.Session(graph=tf.Graph())
 
@@ -47,8 +51,8 @@ def detect_seal(image_np, model):
     image_pil = np2pil(image_np)
 
     # 首先判断红色像素
-    if not count_red_pixel(image_np):
-        return image_np, [], []
+    # if not count_red_pixel(image_np):
+    #     return image_np, [], []
 
     # create image input
     h, w = image_np.shape[:2]
@@ -71,15 +75,23 @@ def detect_seal(image_np, model):
     # inference data
     with sess1.as_default():
         with sess1.graph.as_default():
-            out_boxes, out_scores, out_classes = model.predict([image_resize, image_shape])
-    # print("image_size", image_shape)
-    # print("out_boxes", out_boxes)
-    # print("out_scores", out_scores)
-    # print("out_classes", out_classes)
-    out_boxes = out_boxes.astype(np.int32)
-    out_classes = out_classes.astype(np.int32)
+            try:
+                out_boxes, out_scores, out_classes = model.predict([image_resize, image_shape])
+            except ValueError as e:
+                if '(0, 4)' in traceback.format_exc():
+                    log("there is no seal!")
+                    out_boxes, out_scores, out_classes = [], [], []
+                else:
+                    traceback.print_exc()
+                    raise e
+
+    if not isinstance(out_boxes, list) and not isinstance(out_classes, list):
+        out_boxes = out_boxes.astype(np.int32)
+        out_classes = out_classes.astype(np.int32)
 
-    boxes = adjust_boxes(image_pil, out_boxes)
+        boxes = adjust_boxes(image_pil, out_boxes)
+    else:
+        boxes = out_boxes
 
     # # draw
     # class_names = get_classes(os.path.abspath(os.path.dirname(__file__))+"/yolo_data/my_classes.txt")
@@ -114,6 +126,77 @@ def isr(data, isr_yolo_model, isr_model):
         return {"image": [-1]}
 
 
+# 接口配置
+app = Flask(__name__)
+
+
+@app.route('/isr', methods=['POST'])
+def _isr():
+    _global._init()
+    _global.update({"port": globals().get("port")})
+    start_time = time.time()
+
+    log("into isr_interface _isr")
+    try:
+        if not request.form:
+            log("isr no data!")
+            return json.dumps({"text": str([-9]), "bbox": str([-9])})
+        data = request.form.get("data")
+        log("isr_interface get data time" + str(time.time()-start_time))
+
+        img_data = base64.b64decode(data)
+        img_np = bytes2np(img_data)
+        _md5 = request.form.get("md5")
+        _global.update({"md5": _md5})
+
+        # 初始化模型
+        isr_yolo_model = globals().get("global_isr_yolo_model")
+        isr_model = globals().get("global_isr_model")
+        if isr_model is None or isr_yolo_model is None:
+            print("=========== init isr model ===========")
+            isr_yolo_model, isr_model = IsrModels().get_model()
+            globals().update({"global_isr_yolo_model": isr_yolo_model})
+            globals().update({"global_isr_model": isr_model})
+
+        # 检测印章
+        _img, boxes, classes = detect_seal(img_np, isr_yolo_model)
+        # 检测不到,直接返回
+        if not boxes and not classes:
+            log("no seal detected!")
+            return json.dumps({"image": [1]})
+        else:
+            log("there is " + str(len(boxes)) + " seals")
+
+        # 截取
+        part_list = get_seal_part(_img, boxes, classes)
+
+        # 去除印章
+        new_part_list = []
+        for part in part_list:
+            part_remove = remove_seal(part, isr_model)
+            new_part_list.append(part_remove)
+
+        # 替换
+        img_replace = replace_seal_part(img_np, new_part_list, boxes)
+
+        # numpy转为可序列化的string
+        success, img_encode = cv2.imencode(".jpg", img_replace)
+        # numpy -> bytes
+        img_bytes = img_encode.tobytes()
+        # bytes -> base64 bytes
+        img_base64 = base64.b64encode(img_bytes)
+        # base64 bytes -> string (utf-8)
+        base64_string = img_base64.decode('utf-8')
+        return json.dumps({"image": base64_string})
+    except TimeoutError:
+        return json.dumps({"image": [-5]})
+    except:
+        traceback.print_exc()
+        return json.dumps({"image": [-1]})
+    finally:
+        log("isr interface finish time " + str(time.time()-start_time))
+
+
 class IsrModels:
     def __init__(self):
         # python文件所在目录