Browse Source

优化isr速度

fangjiasheng 2 years ago
parent
commit
ed432afd9a
2 changed files with 50 additions and 43 deletions
  1. 48 41
      isr/isr_interface.py
  2. 2 2
      isr/post_process.py

+ 48 - 41
isr/isr_interface.py

@@ -4,9 +4,9 @@ import os
 import time
 import sys
 import traceback
-os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
 import tensorflow as tf
-
+tf.compat.v1.enable_eager_execution()
 MAX_COMPUTE = False
 
 if not MAX_COMPUTE:
@@ -34,15 +34,14 @@ from format_convert import _global
 import cv2
 import numpy as np
 from PIL import Image
-from format_convert.utils import log, get_md5_from_bytes, request_post, np2pil, bytes2np, pil2np, get_platform
+from format_convert.utils import log, get_md5_from_bytes, request_post, np2pil, bytes2np, pil2np, get_platform, \
+    judge_error_code
 from isr.post_process import get_seal_part, replace_seal_part
 from isr.model import get_tiny_inference_model, seal_model, seal_model_se
 from isr.pre_process import count_red_pixel, get_anchors, get_classes, get_colors
 from isr.utils import get_best_predict_size, pil_resize, letterbox_image, draw_boxes, adjust_boxes
-
 from flask import Flask, request
 
-tf.compat.v1.disable_eager_execution()
 sess1 = tf.compat.v1.Session(graph=tf.Graph())
 sess2 = tf.compat.v1.Session(graph=tf.Graph())
 
@@ -62,7 +61,9 @@ def remove_seal(image_np, model):
     # predict
     with sess2.as_default():
         with sess2.graph.as_default():
-            pred = model.predict(X)
+            pred = model.predict(X, batch_size=1000)
+            # pred = model(X, training=False)
+            # pred = pred.eval()
             pred = pred[0]*255.
             pred = pred.astype(np.uint8)
             pred = pil_resize(pred, h, w)
@@ -99,8 +100,7 @@ def detect_seal(image_np, model):
     # inference data
     with sess1.as_default():
         with sess1.graph.as_default():
-            out_boxes, out_scores, out_classes = model.predict([image_resize, image_shape], steps=1)
-
+            out_boxes, out_scores, out_classes = model.predict([image_resize, image_shape], batch_size=1000, steps=1)
     if int(out_boxes.shape[0]) == 0:
         log("there is no seal!")
         return image_np, [], []
@@ -109,8 +109,6 @@ def detect_seal(image_np, model):
     out_boxes = out_boxes.astype(np.int32)
     out_classes = out_classes.astype(np.int32)
     boxes = adjust_boxes(image_pil, out_boxes)
-    # else:
-    #     boxes = out_boxes
 
     # # draw
     # class_names = get_classes(os.path.abspath(os.path.dirname(__file__))+"/yolo_data/my_classes.txt")
@@ -128,15 +126,34 @@ def isr(data, isr_yolo_model, isr_model):
     try:
         img_data = base64.b64decode(data)
         img_np = bytes2np(img_data)
+
+        # 检测印章
+        start_time = time.time()
         _img, boxes, classes = detect_seal(img_np, isr_yolo_model)
+        log("detect_seal cost " + str(time.time()-start_time))
+
+        # 检测不到,直接返回
         if not boxes and not classes:
-            return {"image": img_np}
+            log("no seal detected! return 1")
+            return {"image": [1]}
+
+        # 截取
+        start_time = time.time()
         part_list = get_seal_part(_img, boxes, classes)
+        log("get_seal_part cost " + str(time.time()-start_time))
+
+        # 去除印章
+        start_time = time.time()
         new_part_list = []
         for part in part_list:
             part_remove = remove_seal(part, isr_model)
             new_part_list.append(part_remove)
+        log("remove_seal cost " + str(time.time()-start_time))
+
+        # 替换
+        start_time = time.time()
         img_replace = replace_seal_part(img_np, new_part_list, boxes)
+        log("replace_seal_part cost " + str(time.time()-start_time))
         return {"image": img_replace}
     except TimeoutError:
         return {"image": [-5]}
@@ -161,10 +178,8 @@ def _isr():
             log("isr no data!")
             return json.dumps({"text": str([-9]), "bbox": str([-9])})
         data = request.form.get("data")
-        log("isr_interface get data time" + str(time.time()-start_time))
+        log("isr_interface get data time " + str(time.time()-start_time))
 
-        img_data = base64.b64decode(data)
-        img_np = bytes2np(img_data)
         _md5 = request.form.get("md5")
         _global.update({"md5": _md5})
 
@@ -177,26 +192,12 @@ def _isr():
             globals().update({"global_isr_yolo_model": isr_yolo_model})
             globals().update({"global_isr_model": isr_model})
 
-        # 检测印章
-        _img, boxes, classes = detect_seal(img_np, isr_yolo_model)
-        # 检测不到,直接返回
-        if not boxes and not classes:
-            log("no seal detected! return 1")
-            return json.dumps({"image": [1]})
-        else:
-            log("there are " + str(len(boxes)) + " seals")
-
-        # 截取
-        part_list = get_seal_part(_img, boxes, classes)
-
-        # 去除印章
-        new_part_list = []
-        for part in part_list:
-            part_remove = remove_seal(part, isr_model)
-            new_part_list.append(part_remove)
-
-        # 替换
-        img_replace = replace_seal_part(img_np, new_part_list, boxes)
+        # 检测+去除
+        result = isr(data, isr_yolo_model, isr_model)
+        result = result.get("image")
+        if judge_error_code(result):
+            return json.dumps({"image": result})
+        img_replace = result
 
         # numpy转为可序列化的string
         success, img_encode = cv2.imencode(".jpg", img_replace)
@@ -230,6 +231,7 @@ class IsrModels:
             with sess1.graph.as_default():
                 self.isr_yolo_model = get_tiny_inference_model(anchors, len(class_names), weights_path=model_path)
                 self.isr_yolo_model.load_weights(model_path)
+                # self.isr_yolo_model.compile(run_eagerly=True)
 
         # remove
         model_path = _dir + "/models/seal_remove_unet.h5"
@@ -246,7 +248,7 @@ class IsrModels:
 def test_isr_model(from_remote=False):
     if get_platform() == "Windows":
         file_path = "C:/Users/Administrator/Desktop/test_image/error10.jpg"
-        file_path = "C:\\Users\\Administrator\\Downloads\\1647913696016.jpg"
+        # file_path = "C:\\Users\\Administrator\\Downloads\\1647913696016.jpg"
     else:
         file_path = "error10.jpg"
     with open(file_path, "rb") as f:
@@ -272,22 +274,27 @@ def test_isr_model(from_remote=False):
             image_np = cv2.imdecode(buffer, 1)
             print(image_np.shape)
     else:
-        isr_yolo_model, isr_model = IsrModels().get_model()
-        result = isr(file_base64, isr_yolo_model, isr_model)
+        if globals().get("global_isr_model") is None:
+            isr_yolo_model, isr_model = IsrModels().get_model()
+            globals().update({"global_isr_yolo_model": isr_yolo_model})
+            globals().update({"global_isr_model": isr_model})
+        result = isr(file_base64,
+                     globals().get("global_isr_yolo_model"),
+                     globals().get("global_isr_model"))
         # print(result)
         if type(result.get("image")) == list:
             print("result", len(result))
         else:
             img = result.get("image")
             print(img.shape)
-            cv2.namedWindow('img', cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)
-            cv2.imshow("img", img)
-            cv2.waitKey(0)
+            # cv2.namedWindow('img', cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)
+            # cv2.imshow("img", img)
+            # cv2.waitKey(0)
         # print(result)
 
 
 if __name__ == "__main__":
     for i in range(100):
         s_t = time.time()
-        test_isr_model(from_remote=False)
+        test_isr_model(from_remote=True)
         print("finish test_isr_model", time.time()-s_t)

+ 2 - 2
isr/post_process.py

@@ -89,8 +89,8 @@ def yolo_head(feats, anchors, num_classes, input_shape, calc_loss=False):
         feats, [-1, grid_shape[0], grid_shape[1], num_anchors, num_classes + 5])
 
     # Adjust predictions to each spatial grid point and anchor size.
-    box_xy = (K.sigmoid(feats[..., :2]) + grid) / K.cast(grid_shape[::-1], K.dtype(feats))
-    box_wh = K.exp(feats[..., 2:4]) * anchors_tensor / K.cast(input_shape[::-1], K.dtype(feats))
+    box_xy = (K.sigmoid(feats[..., :2]) + grid) / K.cast(grid_shape[..., ::-1], K.dtype(feats))
+    box_wh = K.exp(feats[..., 2:4]) * anchors_tensor / K.cast(input_shape[..., ::-1], K.dtype(feats))
     box_confidence = K.sigmoid(feats[..., 4:5])
     box_class_probs = K.sigmoid(feats[..., 5:])