fangjiasheng před 2 roky
rodič
revize
b23bcbec45
2 změnil soubory, kde provedl 63 přidání a 31 odebrání
  1. 1 2
      format_convert/convert_image.py
  2. 62 29
      isr/isr_interface.py

+ 1 - 2
format_convert/convert_image.py

@@ -123,7 +123,6 @@ def image_process(image_np, image_path, is_from_pdf=False, is_from_docx=False, u
         # cv2.imwrite(idc_path, image_np)
 
         # isr模型去除印章
-        image_np_source = image_np
         _isr_time = time.time()
         if count_red_pixel(image_np):
             # 红色像素达到一定值才过模型
@@ -142,7 +141,7 @@ def image_process(image_np, image_path, is_from_pdf=False, is_from_docx=False, u
             else:
                 isr_path = image_path.split(".")[0] + "_isr." + image_path.split(".")[-1]
                 cv2.imwrite(isr_path, image_np)
-        log("isr total time"+str(time.time()-_isr_time))
+        log("isr total time "+str(time.time()-_isr_time))
 
         # otr模型识别表格,需要图片resize成模型所需大小, 写入另一个路径
         best_h, best_w = get_best_predict_size(image_np)

+ 62 - 29
isr/isr_interface.py

@@ -4,18 +4,42 @@ import os
 import time
 import sys
 import traceback
-os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+import tensorflow as tf
+
+MAX_COMPUTE = False
+
+if not MAX_COMPUTE:
+    # tensorflow 内存设置
+    try:
+        gpus = tf.config.list_physical_devices('GPU')
+        if len(gpus) > 0:
+            tf.config.experimental.set_virtual_device_configuration(
+                gpus[0],
+                [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=1024)])
+    except:
+        traceback.print_exc()
+        # pass
+        # gpus = tf.config.list_physical_devices('GPU')
+        # for gpu in gpus:  # 如果使用多块GPU时
+        #     tf.config.experimental.set_memory_growth(gpu, True)
+
+        os.environ['CUDA_CACHE_MAXSIZE'] = str(2147483648)
+        os.environ['CUDA_CACHE_DISABLE'] = str(0)
+        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.1)
+        sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
+
 sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
 from format_convert import _global
 import cv2
 import numpy as np
 from PIL import Image
-from format_convert.utils import log, get_md5_from_bytes, request_post, np2pil, bytes2np, pil2np
+from format_convert.utils import log, get_md5_from_bytes, request_post, np2pil, bytes2np, pil2np, get_platform
 from isr.post_process import get_seal_part, replace_seal_part
 from isr.model import get_tiny_inference_model, seal_model, seal_model_se
 from isr.pre_process import count_red_pixel, get_anchors, get_classes, get_colors
 from isr.utils import get_best_predict_size, pil_resize, letterbox_image, draw_boxes, adjust_boxes
-import tensorflow as tf
+
 from flask import Flask, request
 
 tf.compat.v1.disable_eager_execution()
@@ -75,25 +99,18 @@ def detect_seal(image_np, model):
     # inference data
     with sess1.as_default():
         with sess1.graph.as_default():
-            try:
-                out_boxes, out_scores, out_classes = model.predict([image_resize, image_shape],steps=1)
-            except ValueError as e:
-                # if '(0, 4)' in traceback.format_exc():
-                #     log("there is no seal!")
-                #     out_boxes, out_scores, out_classes = [], [], []
-                # else:
-                #     traceback.print_exc()
-                #     raise e
-                traceback.print_exc()
-                out_boxes, out_scores, out_classes = [], [], []
-
-    if not isinstance(out_boxes, list) and not isinstance(out_classes, list):
-        out_boxes = out_boxes.astype(np.int32)
-        out_classes = out_classes.astype(np.int32)
-
-        boxes = adjust_boxes(image_pil, out_boxes)
+            out_boxes, out_scores, out_classes = model.predict([image_resize, image_shape], steps=1)
+
+    if int(out_boxes.shape[0]) == 0:
+        log("there is no seal!")
+        return image_np, [], []
     else:
-        boxes = out_boxes
+        log("there are " + str(out_boxes.shape[0]) + " seals!")
+    out_boxes = out_boxes.astype(np.int32)
+    out_classes = out_classes.astype(np.int32)
+    boxes = adjust_boxes(image_pil, out_boxes)
+    # else:
+    #     boxes = out_boxes
 
     # # draw
     # class_names = get_classes(os.path.abspath(os.path.dirname(__file__))+"/yolo_data/my_classes.txt")
@@ -164,10 +181,10 @@ def _isr():
         _img, boxes, classes = detect_seal(img_np, isr_yolo_model)
         # 检测不到,直接返回
         if not boxes and not classes:
-            log("no seal detected!")
+            log("no seal detected! return 1")
             return json.dumps({"image": [1]})
         else:
-            log("there is " + str(len(boxes)) + " seals")
+            log("there are " + str(len(boxes)) + " seals")
 
         # 截取
         part_list = get_seal_part(_img, boxes, classes)
@@ -227,7 +244,11 @@ class IsrModels:
 
 
 def test_isr_model(from_remote=False):
-    file_path = "C:/Users/Administrator/Desktop/test_image/error10.jpg"
+    if get_platform() == "Windows":
+        file_path = "C:/Users/Administrator/Desktop/test_image/error10.jpg"
+        file_path = "C:\\Users\\Administrator\\Downloads\\1647913696016.jpg"
+    else:
+        file_path = "error10.jpg"
     with open(file_path, "rb") as f:
         file_bytes = f.read()
     file_base64 = base64.b64encode(file_bytes)
@@ -238,15 +259,24 @@ def test_isr_model(from_remote=False):
 
     if from_remote:
         file_json = {"data": file_base64, "md5": _md5}
-        # _url = "http://192.168.2.102:17000/ocr"
-        _url = "http://127.0.0.1:17000/ocr"
-        print(json.loads(request_post(_url, file_json)))
+        # _url = "http://192.168.2.102:18040/isr"
+        _url = "http://127.0.0.1:18040/isr"
+        result = json.loads(request_post(_url, file_json))
+        if type(result.get("image")) == list:
+            print("result", result)
+        else:
+            img = result.get("image")
+            image_base64 = img.encode("utf-8")
+            image_bytes = base64.b64decode(image_base64)
+            buffer = np.frombuffer(image_bytes, dtype=np.uint8)
+            image_np = cv2.imdecode(buffer, 1)
+            print(image_np.shape)
     else:
         isr_yolo_model, isr_model = IsrModels().get_model()
         result = isr(file_base64, isr_yolo_model, isr_model)
         # print(result)
         if type(result.get("image")) == list:
-            print(result)
+            print("result", len(result))
         else:
             img = result.get("image")
             print(img.shape)
@@ -257,4 +287,7 @@ def test_isr_model(from_remote=False):
 
 
 if __name__ == "__main__":
-    test_isr_model()
+    for i in range(100):
+        s_t = time.time()
+        test_isr_model(from_remote=False)
+        print("finish test_isr_model", time.time()-s_t)