isr_interface.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. import base64
  2. import json
  3. import os
  4. import time
  5. import sys
  6. import traceback
  7. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
  8. from format_convert import _global
  9. import cv2
  10. import numpy as np
  11. from PIL import Image
  12. from format_convert.utils import log, get_md5_from_bytes, request_post, np2pil, bytes2np, pil2np
  13. from isr.post_process import get_seal_part, replace_seal_part
  14. from isr.model import get_tiny_inference_model, seal_model, seal_model_se
  15. from isr.pre_process import count_red_pixel, get_anchors, get_classes, get_colors
  16. from isr.utils import get_best_predict_size, pil_resize, letterbox_image, draw_boxes, adjust_boxes
  17. import tensorflow as tf
  18. sess1 = tf.compat.v1.Session(graph=tf.Graph())
  19. sess2 = tf.compat.v1.Session(graph=tf.Graph())
  20. def remove_seal(image_np, model):
  21. # inference data
  22. image_seal = image_np
  23. h, w = image_seal.shape[:2]
  24. best_h, best_w = get_best_predict_size(image_seal)
  25. X = np.zeros((1, best_h, best_w, 3))
  26. # resize
  27. image_seal = pil_resize(image_seal, best_h, best_w)
  28. # cv2.imshow("resize", image_seal)
  29. X[0] = image_seal / 255
  30. # predict
  31. with sess2.as_default():
  32. with sess2.graph.as_default():
  33. pred = model.predict(X)
  34. pred = pred[0]*255.
  35. pred = pred.astype(np.uint8)
  36. pred = pil_resize(pred, h, w)
  37. # cv2.imshow("pred", pred)
  38. # cv2.waitKey(0)
  39. return pred
  40. def detect_seal(image_np, model):
  41. image_pil = np2pil(image_np)
  42. # 首先判断红色像素
  43. if not count_red_pixel(image_np):
  44. return image_np, [], []
  45. # create image input
  46. h, w = image_np.shape[:2]
  47. # best_h, best_w = get_best_predict_size(image_np, times=32, max_size=1280)
  48. best_h, best_w = 1024, 1024
  49. image_resize = letterbox_image(image_pil, tuple(reversed([best_h, best_w])))
  50. # cv2.imshow("letterbox_image", pil2np(image_resize))
  51. # cv2.waitKey(0)
  52. # image_resize = pil_resize(image_np, best_h, best_w)
  53. # image_resize = image_pil.resize((int(416), int(416)), Image.BICUBIC)
  54. image_resize = np.array(image_resize, dtype='float32')
  55. image_resize = image_resize.astype('float32') / 255.
  56. image_resize = np.expand_dims(image_resize, 0)
  57. # create image shape input
  58. image_shape = np.array([image_pil.size[1], image_pil.size[0]])
  59. image_shape = np.expand_dims(image_shape, 0)
  60. # inference data
  61. with sess1.as_default():
  62. with sess1.graph.as_default():
  63. out_boxes, out_scores, out_classes = model.predict([image_resize, image_shape])
  64. # print("image_size", image_shape)
  65. # print("out_boxes", out_boxes)
  66. # print("out_scores", out_scores)
  67. # print("out_classes", out_classes)
  68. out_boxes = out_boxes.astype(np.int32)
  69. out_classes = out_classes.astype(np.int32)
  70. boxes = adjust_boxes(image_pil, out_boxes)
  71. # # draw
  72. # class_names = get_classes(os.path.abspath(os.path.dirname(__file__))+"/yolo_data/my_classes.txt")
  73. # colors = get_colors(len(class_names))
  74. # image_draw = draw_boxes(image_pil, out_boxes, out_classes, out_scores, class_names, colors)
  75. # image_draw = cv2.cvtColor(np.array(image_draw), cv2.COLOR_RGB2BGR)
  76. # cv2.namedWindow('detect', cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)
  77. # cv2.imshow("detect", image_draw)
  78. # cv2.waitKey(0)
  79. return image_np, boxes, out_classes
  80. def isr(data, isr_yolo_model, isr_model):
  81. log("into isr_interface isr")
  82. try:
  83. img_data = base64.b64decode(data)
  84. img_np = bytes2np(img_data)
  85. _img, boxes, classes = detect_seal(img_np, isr_yolo_model)
  86. if not boxes and not classes:
  87. return {"image": img_np}
  88. part_list = get_seal_part(_img, boxes, classes)
  89. new_part_list = []
  90. for part in part_list:
  91. part_remove = remove_seal(part, isr_model)
  92. new_part_list.append(part_remove)
  93. img_replace = replace_seal_part(img_np, new_part_list, boxes)
  94. return {"image": img_replace}
  95. except TimeoutError:
  96. return {"image": [-5]}
  97. except:
  98. traceback.print_exc()
  99. return {"image": [-1]}
  100. class IsrModels:
  101. def __init__(self):
  102. # python文件所在目录
  103. _dir = os.path.abspath(os.path.dirname(__file__))
  104. # detect
  105. model_path = _dir + "/models/seal_detect_yolo.h5"
  106. anchors = get_anchors(_dir + "/yolo_data/my_anchors.txt")
  107. class_names = get_classes(_dir + "/yolo_data/my_classes.txt")
  108. colors = get_colors(len(class_names))
  109. with sess1.as_default():
  110. with sess1.graph.as_default():
  111. self.isr_yolo_model = get_tiny_inference_model(anchors, len(class_names), weights_path=model_path)
  112. self.isr_yolo_model.load_weights(model_path)
  113. # remove
  114. model_path = _dir + "/models/seal_remove_unet.h5"
  115. with sess2.as_default():
  116. with sess2.graph.as_default():
  117. self.isr_model = seal_model_se(input_shape=(None, None, 3),
  118. output_shape=(None, None, 3))
  119. self.isr_model.load_weights(model_path)
  120. def get_model(self):
  121. return [self.isr_yolo_model, self.isr_model]
  122. def test_isr_model(from_remote=False):
  123. file_path = "C:/Users/Administrator/Desktop/test_image/error10.jpg"
  124. with open(file_path, "rb") as f:
  125. file_bytes = f.read()
  126. file_base64 = base64.b64encode(file_bytes)
  127. _md5 = get_md5_from_bytes(file_bytes)[0]
  128. _global._init()
  129. _global.update({"port": 15010, "md5": _md5})
  130. if from_remote:
  131. file_json = {"data": file_base64, "md5": _md5}
  132. # _url = "http://192.168.2.102:17000/ocr"
  133. _url = "http://127.0.0.1:17000/ocr"
  134. print(json.loads(request_post(_url, file_json)))
  135. else:
  136. isr_yolo_model, isr_model = IsrModels().get_model()
  137. result = isr(file_base64, isr_yolo_model, isr_model)
  138. # print(result)
  139. if type(result.get("image")) == list:
  140. print(result)
  141. else:
  142. img = result.get("image")
  143. print(img.shape)
  144. cv2.namedWindow('img', cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)
  145. cv2.imshow("img", img)
  146. cv2.waitKey(0)
  147. # print(result)
  148. if __name__ == "__main__":
  149. test_isr_model()