|
@@ -4,9 +4,9 @@ import os
|
|
|
import time
|
|
|
import sys
|
|
|
import traceback
|
|
|
-os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
|
|
+os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
|
|
|
import tensorflow as tf
|
|
|
-
|
|
|
+tf.compat.v1.enable_eager_execution()
|
|
|
MAX_COMPUTE = False
|
|
|
|
|
|
if not MAX_COMPUTE:
|
|
@@ -34,15 +34,14 @@ from format_convert import _global
|
|
|
import cv2
|
|
|
import numpy as np
|
|
|
from PIL import Image
|
|
|
-from format_convert.utils import log, get_md5_from_bytes, request_post, np2pil, bytes2np, pil2np, get_platform
|
|
|
+from format_convert.utils import log, get_md5_from_bytes, request_post, np2pil, bytes2np, pil2np, get_platform, \
|
|
|
+ judge_error_code
|
|
|
from isr.post_process import get_seal_part, replace_seal_part
|
|
|
from isr.model import get_tiny_inference_model, seal_model, seal_model_se
|
|
|
from isr.pre_process import count_red_pixel, get_anchors, get_classes, get_colors
|
|
|
from isr.utils import get_best_predict_size, pil_resize, letterbox_image, draw_boxes, adjust_boxes
|
|
|
-
|
|
|
from flask import Flask, request
|
|
|
|
|
|
-tf.compat.v1.disable_eager_execution()
|
|
|
sess1 = tf.compat.v1.Session(graph=tf.Graph())
|
|
|
sess2 = tf.compat.v1.Session(graph=tf.Graph())
|
|
|
|
|
@@ -62,7 +61,9 @@ def remove_seal(image_np, model):
|
|
|
# predict
|
|
|
with sess2.as_default():
|
|
|
with sess2.graph.as_default():
|
|
|
- pred = model.predict(X)
|
|
|
+ pred = model.predict(X, batch_size=1000)
|
|
|
+ # pred = model(X, training=False)
|
|
|
+ # pred = pred.eval()
|
|
|
pred = pred[0]*255.
|
|
|
pred = pred.astype(np.uint8)
|
|
|
pred = pil_resize(pred, h, w)
|
|
@@ -99,8 +100,7 @@ def detect_seal(image_np, model):
|
|
|
# inference data
|
|
|
with sess1.as_default():
|
|
|
with sess1.graph.as_default():
|
|
|
- out_boxes, out_scores, out_classes = model.predict([image_resize, image_shape], steps=1)
|
|
|
-
|
|
|
+ out_boxes, out_scores, out_classes = model.predict([image_resize, image_shape], batch_size=1000, steps=1)
|
|
|
if int(out_boxes.shape[0]) == 0:
|
|
|
log("there is no seal!")
|
|
|
return image_np, [], []
|
|
@@ -109,8 +109,6 @@ def detect_seal(image_np, model):
|
|
|
out_boxes = out_boxes.astype(np.int32)
|
|
|
out_classes = out_classes.astype(np.int32)
|
|
|
boxes = adjust_boxes(image_pil, out_boxes)
|
|
|
- # else:
|
|
|
- # boxes = out_boxes
|
|
|
|
|
|
# # draw
|
|
|
# class_names = get_classes(os.path.abspath(os.path.dirname(__file__))+"/yolo_data/my_classes.txt")
|
|
@@ -128,15 +126,34 @@ def isr(data, isr_yolo_model, isr_model):
|
|
|
try:
|
|
|
img_data = base64.b64decode(data)
|
|
|
img_np = bytes2np(img_data)
|
|
|
+
|
|
|
+ # 检测印章
|
|
|
+ start_time = time.time()
|
|
|
_img, boxes, classes = detect_seal(img_np, isr_yolo_model)
|
|
|
+ log("detect_seal cost " + str(time.time()-start_time))
|
|
|
+
|
|
|
+ # 检测不到,直接返回
|
|
|
if not boxes and not classes:
|
|
|
- return {"image": img_np}
|
|
|
+ log("no seal detected! return 1")
|
|
|
+ return {"image": [1]}
|
|
|
+
|
|
|
+ # 截取
|
|
|
+ start_time = time.time()
|
|
|
part_list = get_seal_part(_img, boxes, classes)
|
|
|
+ log("get_seal_part cost " + str(time.time()-start_time))
|
|
|
+
|
|
|
+ # 去除印章
|
|
|
+ start_time = time.time()
|
|
|
new_part_list = []
|
|
|
for part in part_list:
|
|
|
part_remove = remove_seal(part, isr_model)
|
|
|
new_part_list.append(part_remove)
|
|
|
+ log("remove_seal cost " + str(time.time()-start_time))
|
|
|
+
|
|
|
+ # 替换
|
|
|
+ start_time = time.time()
|
|
|
img_replace = replace_seal_part(img_np, new_part_list, boxes)
|
|
|
+ log("replace_seal_part cost " + str(time.time()-start_time))
|
|
|
return {"image": img_replace}
|
|
|
except TimeoutError:
|
|
|
return {"image": [-5]}
|
|
@@ -161,10 +178,8 @@ def _isr():
|
|
|
log("isr no data!")
|
|
|
return json.dumps({"text": str([-9]), "bbox": str([-9])})
|
|
|
data = request.form.get("data")
|
|
|
- log("isr_interface get data time" + str(time.time()-start_time))
|
|
|
+ log("isr_interface get data time " + str(time.time()-start_time))
|
|
|
|
|
|
- img_data = base64.b64decode(data)
|
|
|
- img_np = bytes2np(img_data)
|
|
|
_md5 = request.form.get("md5")
|
|
|
_global.update({"md5": _md5})
|
|
|
|
|
@@ -177,26 +192,12 @@ def _isr():
|
|
|
globals().update({"global_isr_yolo_model": isr_yolo_model})
|
|
|
globals().update({"global_isr_model": isr_model})
|
|
|
|
|
|
- # 检测印章
|
|
|
- _img, boxes, classes = detect_seal(img_np, isr_yolo_model)
|
|
|
- # 检测不到,直接返回
|
|
|
- if not boxes and not classes:
|
|
|
- log("no seal detected! return 1")
|
|
|
- return json.dumps({"image": [1]})
|
|
|
- else:
|
|
|
- log("there are " + str(len(boxes)) + " seals")
|
|
|
-
|
|
|
- # 截取
|
|
|
- part_list = get_seal_part(_img, boxes, classes)
|
|
|
-
|
|
|
- # 去除印章
|
|
|
- new_part_list = []
|
|
|
- for part in part_list:
|
|
|
- part_remove = remove_seal(part, isr_model)
|
|
|
- new_part_list.append(part_remove)
|
|
|
-
|
|
|
- # 替换
|
|
|
- img_replace = replace_seal_part(img_np, new_part_list, boxes)
|
|
|
+ # 检测+去除
|
|
|
+ result = isr(data, isr_yolo_model, isr_model)
|
|
|
+ result = result.get("image")
|
|
|
+ if judge_error_code(result):
|
|
|
+ return json.dumps({"image": result})
|
|
|
+ img_replace = result
|
|
|
|
|
|
# numpy转为可序列化的string
|
|
|
success, img_encode = cv2.imencode(".jpg", img_replace)
|
|
@@ -230,6 +231,7 @@ class IsrModels:
|
|
|
with sess1.graph.as_default():
|
|
|
self.isr_yolo_model = get_tiny_inference_model(anchors, len(class_names), weights_path=model_path)
|
|
|
self.isr_yolo_model.load_weights(model_path)
|
|
|
+ # self.isr_yolo_model.compile(run_eagerly=True)
|
|
|
|
|
|
# remove
|
|
|
model_path = _dir + "/models/seal_remove_unet.h5"
|
|
@@ -246,7 +248,7 @@ class IsrModels:
|
|
|
def test_isr_model(from_remote=False):
|
|
|
if get_platform() == "Windows":
|
|
|
file_path = "C:/Users/Administrator/Desktop/test_image/error10.jpg"
|
|
|
- file_path = "C:\\Users\\Administrator\\Downloads\\1647913696016.jpg"
|
|
|
+ # file_path = "C:\\Users\\Administrator\\Downloads\\1647913696016.jpg"
|
|
|
else:
|
|
|
file_path = "error10.jpg"
|
|
|
with open(file_path, "rb") as f:
|
|
@@ -272,22 +274,27 @@ def test_isr_model(from_remote=False):
|
|
|
image_np = cv2.imdecode(buffer, 1)
|
|
|
print(image_np.shape)
|
|
|
else:
|
|
|
- isr_yolo_model, isr_model = IsrModels().get_model()
|
|
|
- result = isr(file_base64, isr_yolo_model, isr_model)
|
|
|
+ if globals().get("global_isr_model") is None:
|
|
|
+ isr_yolo_model, isr_model = IsrModels().get_model()
|
|
|
+ globals().update({"global_isr_yolo_model": isr_yolo_model})
|
|
|
+ globals().update({"global_isr_model": isr_model})
|
|
|
+ result = isr(file_base64,
|
|
|
+ globals().get("global_isr_yolo_model"),
|
|
|
+ globals().get("global_isr_model"))
|
|
|
# print(result)
|
|
|
if type(result.get("image")) == list:
|
|
|
print("result", len(result))
|
|
|
else:
|
|
|
img = result.get("image")
|
|
|
print(img.shape)
|
|
|
- cv2.namedWindow('img', cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)
|
|
|
- cv2.imshow("img", img)
|
|
|
- cv2.waitKey(0)
|
|
|
+ # cv2.namedWindow('img', cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)
|
|
|
+ # cv2.imshow("img", img)
|
|
|
+ # cv2.waitKey(0)
|
|
|
# print(result)
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
for i in range(100):
|
|
|
s_t = time.time()
|
|
|
- test_isr_model(from_remote=False)
|
|
|
+ test_isr_model(from_remote=True)
|
|
|
print("finish test_isr_model", time.time()-s_t)
|