import base64
import json
import os
import time
import sys
import traceback

# Make the parent directory importable so the `format_convert` and `isr`
# packages below resolve when this file is run directly as a script.
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")

from format_convert import _global
import cv2
import numpy as np
from PIL import Image
from format_convert.utils import log, get_md5_from_bytes, request_post, np2pil, bytes2np, pil2np
from isr.post_process import get_seal_part, replace_seal_part
from isr.model import get_tiny_inference_model, seal_model, seal_model_se
from isr.pre_process import count_red_pixel, get_anchors, get_classes, get_colors
from isr.utils import get_best_predict_size, pil_resize, letterbox_image, draw_boxes, adjust_boxes
import tensorflow as tf

# Two independent TF1-style sessions, each owning its own graph:
# sess1 hosts the seal-detection model and sess2 the seal-removal model
# (see IsrModels). Every build/predict call for a model must run inside the
# matching session and graph context.
sess1 = tf.compat.v1.Session(graph=tf.Graph())
sess2 = tf.compat.v1.Session(graph=tf.Graph())


def remove_seal(image_np, model):
    """Run the seal-removal model on one cropped seal region.

    The crop is resized to the size chosen by ``get_best_predict_size``,
    scaled to [0, 1], predicted inside ``sess2``'s graph, and the prediction
    is resized back to the crop's original height and width.

    Args:
        image_np: crop as a numpy array. Assumed to be H x W x 3, since the
            input buffer below is allocated with 3 channels — TODO confirm.
        model: removal model built by ``IsrModels`` under ``sess2``.

    Returns:
        uint8 numpy array holding the predicted image, resized to the crop's
        original (height, width).
    """
    h, w = image_np.shape[:2]
    best_h, best_w = get_best_predict_size(image_np)

    # Batch of one, normalised to [0, 1].
    image_resize = pil_resize(image_np, best_h, best_w)
    X = np.zeros((1, best_h, best_w, 3))
    X[0] = image_resize / 255

    with sess2.as_default():
        with sess2.graph.as_default():
            pred = model.predict(X)

    # Clip before casting: astype(np.uint8) wraps out-of-range values
    # (e.g. 256 -> 0, -1 -> 255) instead of saturating them.
    pred = np.clip(pred[0] * 255., 0, 255).astype(np.uint8)
    pred = pil_resize(pred, h, w)
    return pred


def detect_seal(image_np, model):
    """Detect seal regions in an image with the detection model.

    Images rejected by the ``count_red_pixel`` pre-filter are returned
    immediately with no detections. Otherwise the image is letterboxed to
    1024 x 1024, scaled to [0, 1], and run through ``model`` inside
    ``sess1``'s graph together with the original image's (height, width).

    Args:
        image_np: input image as a numpy array.
        model: detection model built by ``IsrModels`` under ``sess1``.

    Returns:
        ``(image_np, boxes, classes)``:
        - ``image_np`` is the unchanged input.
        - ``boxes`` is whatever ``adjust_boxes`` returns.
        - ``classes`` is the predicted class ids as an int32 numpy array.

        ``boxes`` and ``classes`` are both empty lists when the red-pixel
        check fails.
    """
    # Cheap red-pixel pre-filter before running the detector.
    if not count_red_pixel(image_np):
        return image_np, [], []

    image_pil = np2pil(image_np)

    # Fixed network input size. letterbox_image is passed (width, height).
    best_h, best_w = 1024, 1024
    image_resize = letterbox_image(image_pil, (best_w, best_h))
    image_resize = np.array(image_resize, dtype='float32') / 255.
    image_resize = np.expand_dims(image_resize, 0)

    # Original image shape as (height, width); PIL's size is (width, height).
    image_shape = np.array([image_pil.size[1], image_pil.size[0]])
    image_shape = np.expand_dims(image_shape, 0)

    with sess1.as_default():
        with sess1.graph.as_default():
            out_boxes, out_scores, out_classes = model.predict([image_resize, image_shape])

    out_boxes = out_boxes.astype(np.int32)
    out_classes = out_classes.astype(np.int32)
    boxes = adjust_boxes(image_pil, out_boxes)
    return image_np, boxes, out_classes


def isr(data, isr_yolo_model, isr_model):
    """Detect and remove seals from a base64-encoded image.

    Args:
        data: base64-encoded image bytes.
        isr_yolo_model: detection model (see ``IsrModels``).
        isr_model: removal model (see ``IsrModels``).

    Returns:
        One of:
        - ``{"image": ndarray}`` with detected seal regions replaced by the
          removal model's output, or the decoded input unchanged when
          nothing is detected.
        - ``{"image": [-5]}`` on ``TimeoutError``.
        - ``{"image": [-1]}`` on any other exception (the traceback is
          printed).
    """
    log("into isr_interface isr")
    try:
        img_data = base64.b64decode(data)
        img_np = bytes2np(img_data)

        _img, boxes, classes = detect_seal(img_np, isr_yolo_model)
        # Use len() rather than truthiness: `classes` is a numpy array once
        # the detector has run. `not ndarray` raises ValueError for more than
        # one element and is deprecated (an error in newer NumPy) when empty.
        if len(boxes) == 0 and len(classes) == 0:
            return {"image": img_np}

        part_list = get_seal_part(_img, boxes, classes)
        new_part_list = [remove_seal(part, isr_model) for part in part_list]

        img_replace = replace_seal_part(img_np, new_part_list, boxes)
        return {"image": img_replace}
    except TimeoutError:
        return {"image": [-5]}
    except Exception:
        # Exception (not a bare except) so KeyboardInterrupt/SystemExit still
        # propagate.
        traceback.print_exc()
        return {"image": [-1]}


class IsrModels:
    """Build and load the seal-detection and seal-removal models.

    Each model is built inside its own module-level session/graph: ``sess1``
    for detection and ``sess2`` for removal. These are the same contexts
    ``detect_seal`` and ``remove_seal`` later predict in. Weights and YOLO
    config files are read from ``models/`` and ``yolo_data/`` next to this
    file.
    """

    def __init__(self):
        # Directory containing this python file.
        _dir = os.path.abspath(os.path.dirname(__file__))

        # Detection model.
        model_path = _dir + "/models/seal_detect_yolo.h5"
        anchors = get_anchors(_dir + "/yolo_data/my_anchors.txt")
        class_names = get_classes(_dir + "/yolo_data/my_classes.txt")
        with sess1.as_default():
            with sess1.graph.as_default():
                self.isr_yolo_model = get_tiny_inference_model(anchors, len(class_names), weights_path=model_path)
                self.isr_yolo_model.load_weights(model_path)

        # Removal model.
        model_path = _dir + "/models/seal_remove_unet.h5"
        with sess2.as_default():
            with sess2.graph.as_default():
                self.isr_model = seal_model_se(input_shape=(None, None, 3), output_shape=(None, None, 3))
                self.isr_model.load_weights(model_path)

    def get_model(self):
        """Return ``[detection_model, removal_model]``."""
        return [self.isr_yolo_model, self.isr_model]


def test_isr_model(from_remote=False, file_path="C:/Users/Administrator/Desktop/test_image/error10.jpg"):
    """Manual smoke test: run seal removal on one image file.

    Args:
        from_remote: if True, POST the image to ``http://127.0.0.1:17000/ocr``
            and print the decoded JSON response. Otherwise load the models
            locally, run ``isr`` and show the result in an OpenCV window.
        file_path: image to process. Defaults to the original hard-coded
            developer path.
    """
    with open(file_path, "rb") as f:
        file_bytes = f.read()
    file_base64 = base64.b64encode(file_bytes)
    _md5 = get_md5_from_bytes(file_bytes)[0]
    _global._init()
    _global.update({"port": 15010, "md5": _md5})

    if from_remote:
        file_json = {"data": file_base64, "md5": _md5}
        _url = "http://127.0.0.1:17000/ocr"
        print(json.loads(request_post(_url, file_json)))
        return

    isr_yolo_model, isr_model = IsrModels().get_model()
    result = isr(file_base64, isr_yolo_model, isr_model)
    img = result.get("image")
    # A list payload is an error code from isr(), not an image.
    if isinstance(img, list):
        print(result)
    else:
        print(img.shape)
        cv2.namedWindow('img', cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)
        cv2.imshow("img", img)
        cv2.waitKey(0)


if __name__ == "__main__":
    test_isr_model()