isr_interface.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
  1. import base64
  2. import json
  3. import os
  4. import time
  5. import sys
  6. import traceback
  7. # os.environ["CUDA_VISIBLE_DEVICES"] = "0"
  8. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
  9. from format_convert.max_compute_config import max_compute
  10. import tensorflow as tf
  # Module start-up: eager execution is disabled and, unless MAX_COMPUTE is truthy,
  # GPU memory limits are configured for TensorFlow.
  # NOTE(review): this listing has lost its indentation, so the nesting of the
  # statements below cannot be recovered from here — confirm against the real file.
  11. tf.compat.v1.disable_eager_execution()
  12. MAX_COMPUTE = max_compute
  13. if not MAX_COMPUTE:
  14. # TensorFlow memory settings
  15. try:
  16. gpus = tf.config.list_physical_devices('GPU')
  17. if len(gpus) > 0:
  # Only the first listed GPU is configured, with one virtual device and memory_limit=1024
  # (presumably megabytes — verify against the TensorFlow documentation).
  18. tf.config.experimental.set_virtual_device_configuration(
  19. gpus[0],
  20. [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=1024)])
  # Any failure above is printed, not re-raised, so the import continues.
  21. except:
  22. traceback.print_exc()
  23. # pass
  24. # gpus = tf.config.list_physical_devices('GPU')
  25. # for gpu in gpus: # when using multiple GPUs
  26. # tf.config.experimental.set_memory_growth(gpu, True)
  # NOTE(review): unclear whether the next four statements belong to the `except`
  # branch above or run after the try/except — TODO confirm.
  # They use the top-level names tf.GPUOptions / tf.Session / tf.ConfigProto while the
  # rest of this file goes through tf.compat.v1; presumably these are TF1-style names —
  # verify that they exist in the installed TensorFlow version.
  # 2147483648 == 2**31
  27. os.environ['CUDA_CACHE_MAXSIZE'] = str(2147483648)
  28. os.environ['CUDA_CACHE_DISABLE'] = str(0)
  29. gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.05)
  # `sess` is not referenced again in the visible part of this file.
  30. sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
  31. from format_convert import _global
  32. import cv2
  33. import numpy as np
  34. from PIL import Image
  35. from format_convert.utils import log, get_md5_from_bytes, request_post, np2pil, bytes2np, pil2np, get_platform, \
  36. judge_error_code
  37. from isr.post_process import get_seal_part, replace_seal_part
  38. from isr.model import get_tiny_inference_model, seal_model, seal_model_se
  39. from isr.pre_process import count_red_pixel, get_anchors, get_classes, get_colors
  40. from isr.utils import get_best_predict_size, pil_resize, letterbox_image, draw_boxes, adjust_boxes
  41. from flask import Flask, request
  42. tf.compat.v1.disable_eager_execution()
  43. sess1 = tf.compat.v1.Session(graph=tf.Graph())
  44. sess2 = tf.compat.v1.Session(graph=tf.Graph())
  45. def remove_seal(image_np, model):
  46. # inference data
  47. image_seal = image_np
  48. h, w = image_seal.shape[:2]
  49. best_h, best_w = get_best_predict_size(image_seal)
  50. X = np.zeros((1, best_h, best_w, 3))
  51. # resize
  52. image_seal = pil_resize(image_seal, best_h, best_w)
  53. # cv2.imshow("resize", image_seal)
  54. X[0] = image_seal / 255
  55. # predict
  56. with sess2.as_default():
  57. with sess2.graph.as_default():
  58. pred = model.predict(X, batch_size=1000)
  59. # pred = model(X, training=False)
  60. # pred = pred.eval()
  61. pred = pred[0]*255.
  62. pred = pred.astype(np.uint8)
  63. pred = pil_resize(pred, h, w)
  64. # cv2.imshow("pred", pred)
  65. # cv2.waitKey(0)
  66. return pred
  67. def detect_seal(image_np, model):
  68. image_pil = np2pil(image_np)
  69. # 首先判断红色像素
  70. # if not count_red_pixel(image_np):
  71. # return image_np, [], []
  72. # create image input
  73. h, w = image_np.shape[:2]
  74. # best_h, best_w = get_best_predict_size(image_np, times=32, max_size=1280)
  75. best_h, best_w = 1024, 1024
  76. image_resize = letterbox_image(image_pil, tuple(reversed([best_h, best_w])))
  77. # cv2.imshow("letterbox_image", pil2np(image_resize))
  78. # cv2.waitKey(0)
  79. # image_resize = pil_resize(image_np, best_h, best_w)
  80. # image_resize = image_pil.resize((int(416), int(416)), Image.BICUBIC)
  81. image_resize = np.array(image_resize, dtype='float32')
  82. image_resize = image_resize.astype('float32') / 255.
  83. image_resize = np.expand_dims(image_resize, 0)
  84. # create image shape input
  85. image_shape = np.array([image_pil.size[1], image_pil.size[0]])
  86. image_shape = np.expand_dims(image_shape, 0)
  87. # inference data
  88. with sess1.as_default():
  89. with sess1.graph.as_default():
  90. out_boxes, out_scores, out_classes = model.predict([image_resize, image_shape], steps=1)
  91. if int(out_boxes.shape[0]) == 0:
  92. log("there is no seal!")
  93. return image_np, [], []
  94. else:
  95. log("there are " + str(out_boxes.shape[0]) + " seals!")
  96. out_boxes = out_boxes.astype(np.int32)
  97. out_classes = out_classes.astype(np.int32)
  98. boxes = adjust_boxes(image_pil, out_boxes)
  99. # # draw
  100. # class_names = get_classes(os.path.abspath(os.path.dirname(__file__))+"/yolo_data/my_classes.txt")
  101. # colors = get_colors(len(class_names))
  102. # image_draw = draw_boxes(image_pil, out_boxes, out_classes, out_scores, class_names, colors)
  103. # image_draw = cv2.cvtColor(np.array(image_draw), cv2.COLOR_RGB2BGR)
  104. # cv2.namedWindow('detect', cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)
  105. # cv2.imshow("detect", image_draw)
  106. # cv2.waitKey(0)
  107. return image_np, boxes, out_classes
  108. def isr(data, isr_yolo_model, isr_model):
  109. log("into isr_interface isr")
  110. try:
  111. img_data = base64.b64decode(data)
  112. img_np = bytes2np(img_data)
  113. # 检测印章
  114. start_time = time.time()
  115. _img, boxes, classes = detect_seal(img_np, isr_yolo_model)
  116. log("detect_seal cost " + str(time.time()-start_time))
  117. # 检测不到,直接返回
  118. if not boxes and not classes:
  119. log("no seal detected! return 1")
  120. return {"image": [1]}
  121. # 截取
  122. start_time = time.time()
  123. part_list = get_seal_part(_img, boxes, classes)
  124. log("get_seal_part cost " + str(time.time()-start_time))
  125. # 去除印章
  126. start_time = time.time()
  127. new_part_list = []
  128. for part in part_list:
  129. part_remove = remove_seal(part, isr_model)
  130. new_part_list.append(part_remove)
  131. log("remove_seal cost " + str(time.time()-start_time))
  132. # 替换
  133. start_time = time.time()
  134. img_replace = replace_seal_part(img_np, new_part_list, boxes)
  135. log("replace_seal_part cost " + str(time.time()-start_time))
  136. return {"image": img_replace}
  137. except TimeoutError:
  138. return {"image": [-5]}
  139. except:
  140. traceback.print_exc()
  141. return {"image": [-1]}
  142. # 接口配置
  143. app = Flask(__name__)
  144. @app.route('/isr', methods=['POST'])
  145. def _isr():
  146. _global._init()
  147. _global.update({"port": globals().get("port")})
  148. start_time = time.time()
  149. log("into isr_interface _isr")
  150. try:
  151. if not request.form:
  152. log("isr no data!")
  153. return json.dumps({"image": str([-9])})
  154. data = request.form.get("data")
  155. log("isr_interface get data time " + str(time.time()-start_time))
  156. _md5 = request.form.get("md5")
  157. _global.update({"md5": _md5})
  158. # 初始化模型
  159. isr_yolo_model = globals().get("global_isr_yolo_model")
  160. isr_model = globals().get("global_isr_model")
  161. if isr_model is None or isr_yolo_model is None:
  162. print("=========== init isr model ===========")
  163. isr_yolo_model, isr_model = IsrModels().get_model()
  164. globals().update({"global_isr_yolo_model": isr_yolo_model})
  165. globals().update({"global_isr_model": isr_model})
  166. # 检测+去除
  167. # result = isr(data, isr_yolo_model, isr_model)
  168. # result = result.get("image")
  169. # if judge_error_code(result):
  170. # return json.dumps({"image": result})
  171. # if isinstance(result, list) and result == [1]:
  172. # return json.dumps({"image": result})
  173. # img_replace = result
  174. img_replace = isr(data, isr_yolo_model, isr_model).get("image")
  175. if isinstance(img_replace, list):
  176. return json.dumps({"image": img_replace})
  177. # numpy转为可序列化的string
  178. success, img_encode = cv2.imencode(".jpg", img_replace)
  179. # numpy -> bytes
  180. img_bytes = img_encode.tobytes()
  181. # bytes -> base64 bytes
  182. img_base64 = base64.b64encode(img_bytes)
  183. # base64 bytes -> string (utf-8)
  184. base64_string = img_base64.decode('utf-8')
  185. return json.dumps({"image": base64_string})
  186. except TimeoutError:
  187. return json.dumps({"image": [-5]})
  188. except:
  189. traceback.print_exc()
  190. return json.dumps({"image": [-1]})
  191. finally:
  192. log("isr interface finish time " + str(time.time()-start_time))
  193. class IsrModels:
  194. def __init__(self):
  195. # python文件所在目录
  196. _dir = os.path.abspath(os.path.dirname(__file__))
  197. # detect
  198. model_path = _dir + "/models/seal_detect_yolo.h5"
  199. anchors = get_anchors(_dir + "/yolo_data/my_anchors.txt")
  200. class_names = get_classes(_dir + "/yolo_data/my_classes.txt")
  201. colors = get_colors(len(class_names))
  202. with sess1.as_default():
  203. with sess1.graph.as_default():
  204. self.isr_yolo_model = get_tiny_inference_model(anchors, len(class_names), weights_path=model_path)
  205. self.isr_yolo_model.load_weights(model_path)
  206. # self.isr_yolo_model.compile(run_eagerly=True)
  207. # remove
  208. model_path = _dir + "/models/seal_remove_unet.h5"
  209. with sess2.as_default():
  210. with sess2.graph.as_default():
  211. self.isr_model = seal_model_se(input_shape=(None, None, 3),
  212. output_shape=(None, None, 3))
  213. self.isr_model.load_weights(model_path)
  214. def get_model(self):
  215. return [self.isr_yolo_model, self.isr_model]
  216. def test_isr_model(from_remote=False):
  217. if get_platform() == "Windows":
  218. file_path = "C:/Users/Administrator/Desktop/test_image/error10.jpg"
  219. # file_path = "C:\\Users\\Administrator\\Downloads\\1647913696016.jpg"
  220. else:
  221. file_path = "error10.jpg"
  222. with open(file_path, "rb") as f:
  223. file_bytes = f.read()
  224. file_base64 = base64.b64encode(file_bytes)
  225. _md5 = get_md5_from_bytes(file_bytes)[0]
  226. _global._init()
  227. _global.update({"port": 15010, "md5": _md5})
  228. if from_remote:
  229. file_json = {"data": file_base64, "md5": _md5}
  230. # _url = "http://192.168.2.102:18040/isr"
  231. _url = "http://127.0.0.1:18040/isr"
  232. result = json.loads(request_post(_url, file_json))
  233. if type(result.get("image")) == list:
  234. print("result", result)
  235. else:
  236. img = result.get("image")
  237. image_base64 = img.encode("utf-8")
  238. image_bytes = base64.b64decode(image_base64)
  239. buffer = np.frombuffer(image_bytes, dtype=np.uint8)
  240. image_np = cv2.imdecode(buffer, 1)
  241. print(image_np.shape)
  242. else:
  243. if globals().get("global_isr_model") is None:
  244. isr_yolo_model, isr_model = IsrModels().get_model()
  245. globals().update({"global_isr_yolo_model": isr_yolo_model})
  246. globals().update({"global_isr_model": isr_model})
  247. result = isr(file_base64,
  248. globals().get("global_isr_yolo_model"),
  249. globals().get("global_isr_model"))
  250. # print(result)
  251. if type(result.get("image")) == list:
  252. print("result", len(result))
  253. else:
  254. img = result.get("image")
  255. print(img.shape)
  256. # cv2.namedWindow('img', cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)
  257. # cv2.imshow("img", img)
  258. # cv2.waitKey(0)
  259. # print(result)
  260. if __name__ == "__main__":
  261. for i in range(100):
  262. s_t = time.time()
  263. test_isr_model(from_remote=True)
  264. print("finish test_isr_model", time.time()-s_t)