isr_interface.py 11 KB


  1. import base64
  2. import json
  3. import os
  4. import time
  5. import sys
  6. import traceback
  7. os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
  8. import tensorflow as tf
  9. tf.compat.v1.enable_eager_execution()
  10. MAX_COMPUTE = False
  11. if not MAX_COMPUTE:
  12. # tensorflow 内存设置
  13. try:
  14. gpus = tf.config.list_physical_devices('GPU')
  15. if len(gpus) > 0:
  16. tf.config.experimental.set_virtual_device_configuration(
  17. gpus[0],
  18. [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=1024)])
  19. except:
  20. traceback.print_exc()
  21. # pass
  22. # gpus = tf.config.list_physical_devices('GPU')
  23. # for gpu in gpus: # 如果使用多块GPU时
  24. # tf.config.experimental.set_memory_growth(gpu, True)
  25. os.environ['CUDA_CACHE_MAXSIZE'] = str(2147483648)
  26. os.environ['CUDA_CACHE_DISABLE'] = str(0)
  27. gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.1)
  28. sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
  29. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
  30. from format_convert import _global
  31. import cv2
  32. import numpy as np
  33. from PIL import Image
  34. from format_convert.utils import log, get_md5_from_bytes, request_post, np2pil, bytes2np, pil2np, get_platform, \
  35. judge_error_code
  36. from isr.post_process import get_seal_part, replace_seal_part
  37. from isr.model import get_tiny_inference_model, seal_model, seal_model_se
  38. from isr.pre_process import count_red_pixel, get_anchors, get_classes, get_colors
  39. from isr.utils import get_best_predict_size, pil_resize, letterbox_image, draw_boxes, adjust_boxes
  40. from flask import Flask, request
# Two separate TF1-style sessions, each created on its own new tf.Graph.
# In this file sess1 is entered around the seal detection model
# (IsrModels.__init__, detect_seal) and sess2 around the seal removal model
# (IsrModels.__init__, remove_seal).
sess1 = tf.compat.v1.Session(graph=tf.Graph())
sess2 = tf.compat.v1.Session(graph=tf.Graph())
  43. def remove_seal(image_np, model):
  44. # inference data
  45. image_seal = image_np
  46. h, w = image_seal.shape[:2]
  47. best_h, best_w = get_best_predict_size(image_seal)
  48. X = np.zeros((1, best_h, best_w, 3))
  49. # resize
  50. image_seal = pil_resize(image_seal, best_h, best_w)
  51. # cv2.imshow("resize", image_seal)
  52. X[0] = image_seal / 255
  53. # predict
  54. with sess2.as_default():
  55. with sess2.graph.as_default():
  56. pred = model.predict(X, batch_size=1000)
  57. # pred = model(X, training=False)
  58. # pred = pred.eval()
  59. pred = pred[0]*255.
  60. pred = pred.astype(np.uint8)
  61. pred = pil_resize(pred, h, w)
  62. # cv2.imshow("pred", pred)
  63. # cv2.waitKey(0)
  64. return pred
  65. def detect_seal(image_np, model):
  66. image_pil = np2pil(image_np)
  67. # 首先判断红色像素
  68. # if not count_red_pixel(image_np):
  69. # return image_np, [], []
  70. # create image input
  71. h, w = image_np.shape[:2]
  72. # best_h, best_w = get_best_predict_size(image_np, times=32, max_size=1280)
  73. best_h, best_w = 1024, 1024
  74. image_resize = letterbox_image(image_pil, tuple(reversed([best_h, best_w])))
  75. # cv2.imshow("letterbox_image", pil2np(image_resize))
  76. # cv2.waitKey(0)
  77. # image_resize = pil_resize(image_np, best_h, best_w)
  78. # image_resize = image_pil.resize((int(416), int(416)), Image.BICUBIC)
  79. image_resize = np.array(image_resize, dtype='float32')
  80. image_resize = image_resize.astype('float32') / 255.
  81. image_resize = np.expand_dims(image_resize, 0)
  82. # create image shape input
  83. image_shape = np.array([image_pil.size[1], image_pil.size[0]])
  84. image_shape = np.expand_dims(image_shape, 0)
  85. # inference data
  86. with sess1.as_default():
  87. with sess1.graph.as_default():
  88. out_boxes, out_scores, out_classes = model.predict([image_resize, image_shape], batch_size=1000, steps=1)
  89. if int(out_boxes.shape[0]) == 0:
  90. log("there is no seal!")
  91. return image_np, [], []
  92. else:
  93. log("there are " + str(out_boxes.shape[0]) + " seals!")
  94. out_boxes = out_boxes.astype(np.int32)
  95. out_classes = out_classes.astype(np.int32)
  96. boxes = adjust_boxes(image_pil, out_boxes)
  97. # # draw
  98. # class_names = get_classes(os.path.abspath(os.path.dirname(__file__))+"/yolo_data/my_classes.txt")
  99. # colors = get_colors(len(class_names))
  100. # image_draw = draw_boxes(image_pil, out_boxes, out_classes, out_scores, class_names, colors)
  101. # image_draw = cv2.cvtColor(np.array(image_draw), cv2.COLOR_RGB2BGR)
  102. # cv2.namedWindow('detect', cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)
  103. # cv2.imshow("detect", image_draw)
  104. # cv2.waitKey(0)
  105. return image_np, boxes, out_classes
  106. def isr(data, isr_yolo_model, isr_model):
  107. log("into isr_interface isr")
  108. try:
  109. img_data = base64.b64decode(data)
  110. img_np = bytes2np(img_data)
  111. # 检测印章
  112. start_time = time.time()
  113. _img, boxes, classes = detect_seal(img_np, isr_yolo_model)
  114. log("detect_seal cost " + str(time.time()-start_time))
  115. # 检测不到,直接返回
  116. if not boxes and not classes:
  117. log("no seal detected! return 1")
  118. return {"image": [1]}
  119. # 截取
  120. start_time = time.time()
  121. part_list = get_seal_part(_img, boxes, classes)
  122. log("get_seal_part cost " + str(time.time()-start_time))
  123. # 去除印章
  124. start_time = time.time()
  125. new_part_list = []
  126. for part in part_list:
  127. part_remove = remove_seal(part, isr_model)
  128. new_part_list.append(part_remove)
  129. log("remove_seal cost " + str(time.time()-start_time))
  130. # 替换
  131. start_time = time.time()
  132. img_replace = replace_seal_part(img_np, new_part_list, boxes)
  133. log("replace_seal_part cost " + str(time.time()-start_time))
  134. return {"image": img_replace}
  135. except TimeoutError:
  136. return {"image": [-5]}
  137. except:
  138. traceback.print_exc()
  139. return {"image": [-1]}
  140. # 接口配置
  141. app = Flask(__name__)
  142. @app.route('/isr', methods=['POST'])
  143. def _isr():
  144. _global._init()
  145. _global.update({"port": globals().get("port")})
  146. start_time = time.time()
  147. log("into isr_interface _isr")
  148. try:
  149. if not request.form:
  150. log("isr no data!")
  151. return json.dumps({"text": str([-9]), "bbox": str([-9])})
  152. data = request.form.get("data")
  153. log("isr_interface get data time " + str(time.time()-start_time))
  154. _md5 = request.form.get("md5")
  155. _global.update({"md5": _md5})
  156. # 初始化模型
  157. isr_yolo_model = globals().get("global_isr_yolo_model")
  158. isr_model = globals().get("global_isr_model")
  159. if isr_model is None or isr_yolo_model is None:
  160. print("=========== init isr model ===========")
  161. isr_yolo_model, isr_model = IsrModels().get_model()
  162. globals().update({"global_isr_yolo_model": isr_yolo_model})
  163. globals().update({"global_isr_model": isr_model})
  164. # 检测+去除
  165. result = isr(data, isr_yolo_model, isr_model)
  166. result = result.get("image")
  167. if judge_error_code(result):
  168. return json.dumps({"image": result})
  169. if isinstance(result, list) and result == [1]:
  170. return json.dumps({"image": result})
  171. img_replace = result
  172. # numpy转为可序列化的string
  173. success, img_encode = cv2.imencode(".jpg", img_replace)
  174. # numpy -> bytes
  175. img_bytes = img_encode.tobytes()
  176. # bytes -> base64 bytes
  177. img_base64 = base64.b64encode(img_bytes)
  178. # base64 bytes -> string (utf-8)
  179. base64_string = img_base64.decode('utf-8')
  180. return json.dumps({"image": base64_string})
  181. except TimeoutError:
  182. return json.dumps({"image": [-5]})
  183. except:
  184. traceback.print_exc()
  185. return json.dumps({"image": [-1]})
  186. finally:
  187. log("isr interface finish time " + str(time.time()-start_time))
  188. class IsrModels:
  189. def __init__(self):
  190. # python文件所在目录
  191. _dir = os.path.abspath(os.path.dirname(__file__))
  192. # detect
  193. model_path = _dir + "/models/seal_detect_yolo.h5"
  194. anchors = get_anchors(_dir + "/yolo_data/my_anchors.txt")
  195. class_names = get_classes(_dir + "/yolo_data/my_classes.txt")
  196. colors = get_colors(len(class_names))
  197. with sess1.as_default():
  198. with sess1.graph.as_default():
  199. self.isr_yolo_model = get_tiny_inference_model(anchors, len(class_names), weights_path=model_path)
  200. self.isr_yolo_model.load_weights(model_path)
  201. # self.isr_yolo_model.compile(run_eagerly=True)
  202. # remove
  203. model_path = _dir + "/models/seal_remove_unet.h5"
  204. with sess2.as_default():
  205. with sess2.graph.as_default():
  206. self.isr_model = seal_model_se(input_shape=(None, None, 3),
  207. output_shape=(None, None, 3))
  208. self.isr_model.load_weights(model_path)
  209. def get_model(self):
  210. return [self.isr_yolo_model, self.isr_model]
  211. def test_isr_model(from_remote=False):
  212. if get_platform() == "Windows":
  213. file_path = "C:/Users/Administrator/Desktop/test_image/error10.jpg"
  214. # file_path = "C:\\Users\\Administrator\\Downloads\\1647913696016.jpg"
  215. else:
  216. file_path = "error10.jpg"
  217. with open(file_path, "rb") as f:
  218. file_bytes = f.read()
  219. file_base64 = base64.b64encode(file_bytes)
  220. _md5 = get_md5_from_bytes(file_bytes)[0]
  221. _global._init()
  222. _global.update({"port": 15010, "md5": _md5})
  223. if from_remote:
  224. file_json = {"data": file_base64, "md5": _md5}
  225. # _url = "http://192.168.2.102:18040/isr"
  226. _url = "http://127.0.0.1:18040/isr"
  227. result = json.loads(request_post(_url, file_json))
  228. if type(result.get("image")) == list:
  229. print("result", result)
  230. else:
  231. img = result.get("image")
  232. image_base64 = img.encode("utf-8")
  233. image_bytes = base64.b64decode(image_base64)
  234. buffer = np.frombuffer(image_bytes, dtype=np.uint8)
  235. image_np = cv2.imdecode(buffer, 1)
  236. print(image_np.shape)
  237. else:
  238. if globals().get("global_isr_model") is None:
  239. isr_yolo_model, isr_model = IsrModels().get_model()
  240. globals().update({"global_isr_yolo_model": isr_yolo_model})
  241. globals().update({"global_isr_model": isr_model})
  242. result = isr(file_base64,
  243. globals().get("global_isr_yolo_model"),
  244. globals().get("global_isr_model"))
  245. # print(result)
  246. if type(result.get("image")) == list:
  247. print("result", len(result))
  248. else:
  249. img = result.get("image")
  250. print(img.shape)
  251. # cv2.namedWindow('img', cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)
  252. # cv2.imshow("img", img)
  253. # cv2.waitKey(0)
  254. # print(result)
  255. if __name__ == "__main__":
  256. for i in range(100):
  257. s_t = time.time()
  258. test_isr_model(from_remote=True)
  259. print("finish test_isr_model", time.time()-s_t)