import base64
import json
import os
import time
import sys
import traceback

# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
from format_convert.max_compute_config import max_compute
import tensorflow as tf

# The Keras models below are built and run inside explicit tf.compat.v1
# sessions, so graph (non-eager) mode is used.
tf.compat.v1.disable_eager_execution()
MAX_COMPUTE = max_compute

if not MAX_COMPUTE:
    # TensorFlow GPU memory settings: cap the first visible GPU at 1024 MB.
    try:
        gpus = tf.config.list_physical_devices('GPU')
        if len(gpus) > 0:
            tf.config.experimental.set_virtual_device_configuration(
                gpus[0],
                [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=1024)])
    except Exception:
        # Best effort: log the failure and continue with TF's default allocation.
        traceback.print_exc()
    # Alternative when using several GPUs:
    # for gpu in tf.config.list_physical_devices('GPU'):
    #     tf.config.experimental.set_memory_growth(gpu, True)

    os.environ['CUDA_CACHE_MAXSIZE'] = str(2147483648)
    os.environ['CUDA_CACHE_DISABLE'] = str(0)
    # Use the tf.compat.v1 spellings: the bare tf.GPUOptions / tf.Session /
    # tf.ConfigProto names are not available in TensorFlow 2.x.
    gpu_options = tf.compat.v1.GPUOptions(per_process_gpu_memory_fraction=0.05)
    sess = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(gpu_options=gpu_options))

from format_convert import _global
import cv2
import numpy as np
from PIL import Image
from format_convert.utils import log, get_md5_from_bytes, request_post, np2pil, bytes2np, pil2np, get_platform, \
    judge_error_code
from isr.post_process import get_seal_part, replace_seal_part
from isr.model import get_tiny_inference_model, seal_model, seal_model_se
from isr.pre_process import count_red_pixel, get_anchors, get_classes, get_colors
from isr.utils import get_best_predict_size, pil_resize, letterbox_image, draw_boxes, adjust_boxes
from flask import Flask, request

tf.compat.v1.disable_eager_execution()
# Each model is built (in IsrModels) and run (in detect_seal / remove_seal)
# inside its own graph + session.
sess1 = tf.compat.v1.Session(graph=tf.Graph())  # seal detection (YOLO)
sess2 = tf.compat.v1.Session(graph=tf.Graph())  # seal removal (U-Net)


def remove_seal(image_np, model):
    """Run the seal-removal model on one cropped seal region.

    The crop is resized to the size chosen by ``get_best_predict_size``,
    scaled to [0, 1] and predicted inside ``sess2``'s graph; the prediction is
    mapped back to 0-255 uint8 and resized back to the crop's height/width.

    :param image_np: cropped image as a numpy array; assumed to have 3
        channels, since the batch buffer is allocated as (1, h, w, 3)
    :param model: seal-removal Keras model built under ``sess2``
    :return: uint8 numpy image resized back to the crop's (h, w) via pil_resize
    """
    h, w = image_np.shape[:2]
    best_h, best_w = get_best_predict_size(image_np)

    # Single-image batch in [0, 1]
    X = np.zeros((1, best_h, best_w, 3))
    X[0] = pil_resize(image_np, best_h, best_w) / 255

    with sess2.as_default():
        with sess2.graph.as_default():
            pred = model.predict(X, batch_size=1000)

    # Clip before casting: values outside [0, 255] would otherwise wrap
    # around (or be undefined) when converted to uint8.
    pred = np.clip(pred[0] * 255., 0, 255).astype(np.uint8)
    return pil_resize(pred, h, w)


def detect_seal(image_np, model):
    """Detect seals in an image with the YOLO detection model.

    :param image_np: input image as a numpy array
    :param model: YOLO inference model built under ``sess1``
    :return: ``(image_np, boxes, out_classes)``; ``boxes`` and ``out_classes``
        are empty lists when nothing is detected, otherwise ``boxes`` comes from
        ``adjust_boxes`` and ``out_classes`` is an int32 numpy array.
    """
    image_pil = np2pil(image_np)

    # Letterbox into a fixed 1024x1024 network input scaled to [0, 1].
    best_h, best_w = 1024, 1024
    image_resize = letterbox_image(image_pil, (best_w, best_h))
    image_resize = np.array(image_resize, dtype='float32') / 255.
    image_resize = np.expand_dims(image_resize, 0)

    # Original image size as (height, width); PIL's ``size`` is (width, height).
    image_shape = np.expand_dims(np.array([image_pil.size[1], image_pil.size[0]]), 0)

    with sess1.as_default():
        with sess1.graph.as_default():
            out_boxes, out_scores, out_classes = model.predict([image_resize, image_shape], steps=1)

    if int(out_boxes.shape[0]) == 0:
        log("there is no seal!")
        return image_np, [], []

    log("there are " + str(out_boxes.shape[0]) + " seals!")
    out_boxes = out_boxes.astype(np.int32)
    out_classes = out_classes.astype(np.int32)
    boxes = adjust_boxes(image_pil, out_boxes)

    # Debug drawing:
    # class_names = get_classes(os.path.abspath(os.path.dirname(__file__))+"/yolo_data/my_classes.txt")
    # colors = get_colors(len(class_names))
    # image_draw = draw_boxes(image_pil, out_boxes, out_classes, out_scores, class_names, colors)
    # image_draw = cv2.cvtColor(np.array(image_draw), cv2.COLOR_RGB2BGR)
    # cv2.namedWindow('detect', cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)
    # cv2.imshow("detect", image_draw)
    # cv2.waitKey(0)
    return image_np, boxes, out_classes


def isr(data, isr_yolo_model, isr_model):
    """Detect seals in a base64-encoded image and replace them with model-cleaned crops.

    :param data: base64-encoded image bytes
    :param isr_yolo_model: seal detection model
    :param isr_model: seal removal model
    :return: ``{"image": img}`` with the processed numpy image, or
        ``{"image": [code]}`` where code is 1 (no seal detected),
        -5 (TimeoutError) or -1 (any other error).
    """
    log("into isr_interface isr")
    try:
        img_np = bytes2np(base64.b64decode(data))

        # Detect seals
        start_time = time.time()
        _img, boxes, classes = detect_seal(img_np, isr_yolo_model)
        log("detect_seal cost " + str(time.time()-start_time))

        # Nothing detected: return directly.
        # Use len() rather than truthiness: ``classes`` (and possibly ``boxes``)
        # is a numpy array, whose truth value is ambiguous for >1 element.
        if len(boxes) == 0 and len(classes) == 0:
            log("no seal detected! return 1")
            return {"image": [1]}

        # Crop the seal regions
        start_time = time.time()
        part_list = get_seal_part(_img, boxes, classes)
        log("get_seal_part cost " + str(time.time()-start_time))

        # Remove the seal from each crop
        start_time = time.time()
        new_part_list = [remove_seal(part, isr_model) for part in part_list]
        log("remove_seal cost " + str(time.time()-start_time))

        # Paste the cleaned crops back into the image
        start_time = time.time()
        img_replace = replace_seal_part(img_np, new_part_list, boxes)
        log("replace_seal_part cost " + str(time.time()-start_time))
        return {"image": img_replace}
    except TimeoutError:
        return {"image": [-5]}
    except Exception:
        traceback.print_exc()
        return {"image": [-1]}


def _get_isr_models():
    """Return ``(isr_yolo_model, isr_model)``, building and caching them in module globals on first use."""
    isr_yolo_model = globals().get("global_isr_yolo_model")
    isr_model = globals().get("global_isr_model")
    if isr_model is None or isr_yolo_model is None:
        print("=========== init isr model ===========")
        isr_yolo_model, isr_model = IsrModels().get_model()
        globals().update({"global_isr_yolo_model": isr_yolo_model,
                          "global_isr_model": isr_model})
    return isr_yolo_model, isr_model


# HTTP API configuration
app = Flask(__name__)


@app.route('/isr', methods=['POST'])
def _isr():
    """POST /isr: form fields ``data`` (base64 image) and ``md5``.

    Responds with JSON ``{"image": <base64 JPEG string>}`` on success, or
    ``{"image": [code]}`` with an error/status code otherwise.
    """
    _global._init()
    _global.update({"port": globals().get("port")})
    start_time = time.time()
    log("into isr_interface _isr")
    try:
        if not request.form:
            log("isr no data!")
            # NOTE(review): unlike every other path, this sends the code as the
            # string "[-9]" instead of a list; kept in case clients depend on it.
            return json.dumps({"image": str([-9])})
        data = request.form.get("data")
        log("isr_interface get data time " + str(time.time()-start_time))
        _md5 = request.form.get("md5")
        _global.update({"md5": _md5})

        # Initialize the models (once per process)
        isr_yolo_model, isr_model = _get_isr_models()

        # Detection + removal
        img_replace = isr(data, isr_yolo_model, isr_model).get("image")
        if isinstance(img_replace, list):
            return json.dumps({"image": img_replace})

        # numpy image -> JPEG bytes -> base64 -> utf-8 string (JSON-serializable)
        success, img_encode = cv2.imencode(".jpg", img_replace)
        if not success:
            log("isr imencode failed!")
            return json.dumps({"image": [-1]})
        base64_string = base64.b64encode(img_encode.tobytes()).decode('utf-8')
        return json.dumps({"image": base64_string})
    except TimeoutError:
        return json.dumps({"image": [-5]})
    except Exception:
        traceback.print_exc()
        return json.dumps({"image": [-1]})
    finally:
        log("isr interface finish time " + str(time.time()-start_time))


class IsrModels:
    """Load the seal detection (YOLO) and seal removal (U-Net) models into their own sessions."""

    def __init__(self):
        # Directory containing this Python file
        _dir = os.path.abspath(os.path.dirname(__file__))

        # Detection model
        model_path = _dir + "/models/seal_detect_yolo.h5"
        anchors = get_anchors(_dir + "/yolo_data/my_anchors.txt")
        class_names = get_classes(_dir + "/yolo_data/my_classes.txt")
        with sess1.as_default():
            with sess1.graph.as_default():
                self.isr_yolo_model = get_tiny_inference_model(anchors, len(class_names),
                                                               weights_path=model_path)
                self.isr_yolo_model.load_weights(model_path)

        # Removal model
        model_path = _dir + "/models/seal_remove_unet.h5"
        with sess2.as_default():
            with sess2.graph.as_default():
                self.isr_model = seal_model_se(input_shape=(None, None, 3),
                                               output_shape=(None, None, 3))
                self.isr_model.load_weights(model_path)

    def get_model(self):
        """Return ``[isr_yolo_model, isr_model]``."""
        return [self.isr_yolo_model, self.isr_model]


def test_isr_model(from_remote=False):
    """Manual smoke test: run one local image through the service or the in-process pipeline."""
    if get_platform() == "Windows":
        file_path = "C:/Users/Administrator/Desktop/test_image/114.jpg"
    else:
        file_path = "error10.jpg"
    with open(file_path, "rb") as f:
        file_bytes = f.read()
    file_base64 = base64.b64encode(file_bytes)
    _md5 = get_md5_from_bytes(file_bytes)[0]

    _global._init()
    _global.update({"port": 15010, "md5": _md5})

    if from_remote:
        file_json = {"data": file_base64, "md5": _md5}
        # _url = "http://192.168.2.102:18040/isr"
        _url = "http://127.0.0.1:18040/isr"
        result = json.loads(request_post(_url, file_json))
        if isinstance(result.get("image"), list):
            print("result", result)
        else:
            image_bytes = base64.b64decode(result.get("image").encode("utf-8"))
            buffer = np.frombuffer(image_bytes, dtype=np.uint8)
            image_np = cv2.imdecode(buffer, 1)
            print(image_np.shape)
    else:
        isr_yolo_model, isr_model = _get_isr_models()
        result = isr(file_base64, isr_yolo_model, isr_model)
        if isinstance(result.get("image"), list):
            print("result", result)
        else:
            img = result.get("image")
            print(img.shape)
            cv2.namedWindow('img', cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)
            cv2.imshow("img", img)
            cv2.waitKey(0)


if __name__ == "__main__":
    for _ in range(1):
        s_t = time.time()
        test_isr_model(from_remote=False)
        print("finish test_isr_model", time.time()-s_t)