predict_det_pytorch.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434
  1. # encoding=utf8
  2. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import io
  16. import logging
  17. import os
  18. import sys
  19. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../../")
  20. import requests
  21. from format_convert import _global
  22. from format_convert.utils import judge_error_code, log, namespace_to_dict, get_platform, file_lock
  23. os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
  24. import cv2
  25. import numpy as np
  26. import time
  27. import sys
  28. os.environ['FLAGS_eager_delete_tensor_gb'] = '0'
  29. import ocr.tools.infer.utility as utility
  30. from ocr.ppocr.utils.logging import get_logger
  31. from ocr.ppocr.utils.utility import get_image_file_list, check_and_read_gif
  32. from ocr.ppocr.data import create_operators, transform
  33. from ocr.ppocr.postprocess import build_post_process
  34. from config.max_compute_config import MAX_COMPUTE
  35. import torch
  36. from torch import nn
  37. from ocr.tools.infer.torch_det_model import DB_ResNet_18
  38. import gc
  39. logger = get_logger()
  40. class TextDetector(object):
  41. shrink_memory_count = 0
  42. def __init__(self, args):
  43. self.args = args
  44. self.det_algorithm = args.det_algorithm
  45. pre_process_list = [{
  46. 'DetResizeForTest': None
  47. }, {
  48. 'NormalizeImage': {
  49. 'std': [0.229, 0.224, 0.225],
  50. 'mean': [0.485, 0.456, 0.406],
  51. 'scale': '1./255.',
  52. 'order': 'hwc'
  53. }
  54. }, {
  55. 'ToCHWImage': None
  56. }, {
  57. 'KeepKeys': {
  58. 'keep_keys': ['image', 'shape']
  59. }
  60. }]
  61. postprocess_params = {}
  62. if self.det_algorithm == "DB":
  63. postprocess_params['name'] = 'DBPostProcess'
  64. postprocess_params["thresh"] = args.det_db_thresh
  65. postprocess_params["box_thresh"] = args.det_db_box_thresh
  66. postprocess_params["max_candidates"] = 1000
  67. postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
  68. postprocess_params["use_dilation"] = args.use_dilation
  69. elif self.det_algorithm == "EAST":
  70. postprocess_params['name'] = 'EASTPostProcess'
  71. postprocess_params["score_thresh"] = args.det_east_score_thresh
  72. postprocess_params["cover_thresh"] = args.det_east_cover_thresh
  73. postprocess_params["nms_thresh"] = args.det_east_nms_thresh
  74. elif self.det_algorithm == "SAST":
  75. pre_process_list[0] = {
  76. 'DetResizeForTest': {
  77. 'resize_long': args.det_limit_side_len
  78. }
  79. }
  80. postprocess_params['name'] = 'SASTPostProcess'
  81. postprocess_params["score_thresh"] = args.det_sast_score_thresh
  82. postprocess_params["nms_thresh"] = args.det_sast_nms_thresh
  83. self.det_sast_polygon = args.det_sast_polygon
  84. if self.det_sast_polygon:
  85. postprocess_params["sample_pts_num"] = 6
  86. postprocess_params["expand_scale"] = 1.2
  87. postprocess_params["shrink_ratio_of_width"] = 0.2
  88. else:
  89. postprocess_params["sample_pts_num"] = 2
  90. postprocess_params["expand_scale"] = 1.0
  91. postprocess_params["shrink_ratio_of_width"] = 0.3
  92. else:
  93. logger.info("unknown det_algorithm:{}".format(self.det_algorithm))
  94. sys.exit(0)
  95. self.preprocess_op = create_operators(pre_process_list)
  96. self.postprocess_op = build_post_process(postprocess_params)
  97. det_model_path = args.det_model_dir
  98. self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  99. model = DB_ResNet_18()
  100. mode_state_dict = torch.load(det_model_path, self.device)['state_dict']
  101. if str(self.device) == 'cpu': # cpu处理时精度调整,加速推理
  102. for name, value in mode_state_dict.items():
  103. if get_platform() != "Windows":
  104. value = value.double()
  105. value = torch.where((value < 1.0e-23) & (value > 0.0), 1.0e-23, value)
  106. value = torch.where((value > -1.0e-23) & (value < 0.0), -1.0e-23, value)
  107. mode_state_dict[name] = value
  108. model.load_state_dict(mode_state_dict)
  109. self.predictor = model
  110. self.predictor.to(self.device)
  111. self.predictor.eval()
  112. # self.predictor, self.input_tensor, self.output_tensors = utility.create_predictor(
  113. # args, 'det', logger) # paddle.jit.load(args.det_model_dir)
  114. # self.predictor.eval()
  115. def order_points_clockwise(self, pts):
  116. """
  117. reference from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py
  118. # sort the points based on their x-coordinates
  119. """
  120. xSorted = pts[np.argsort(pts[:, 0]), :]
  121. # grab the left-most and right-most points from the sorted
  122. # x-roodinate points
  123. leftMost = xSorted[:2, :]
  124. rightMost = xSorted[2:, :]
  125. # now, sort the left-most coordinates according to their
  126. # y-coordinates so we can grab the top-left and bottom-left
  127. # points, respectively
  128. leftMost = leftMost[np.argsort(leftMost[:, 1]), :]
  129. (tl, bl) = leftMost
  130. rightMost = rightMost[np.argsort(rightMost[:, 1]), :]
  131. (tr, br) = rightMost
  132. rect = np.array([tl, tr, br, bl], dtype="float32")
  133. return rect
  134. def clip_det_res(self, points, img_height, img_width):
  135. for pno in range(points.shape[0]):
  136. points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
  137. points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
  138. return points
  139. def filter_tag_det_res(self, dt_boxes, image_shape):
  140. img_height, img_width = image_shape[0:2]
  141. dt_boxes_new = []
  142. for box in dt_boxes:
  143. box = self.order_points_clockwise(box)
  144. box = self.clip_det_res(box, img_height, img_width)
  145. rect_width = int(np.linalg.norm(box[0] - box[1]))
  146. rect_height = int(np.linalg.norm(box[0] - box[3]))
  147. if rect_width <= 3 or rect_height <= 3:
  148. continue
  149. dt_boxes_new.append(box)
  150. dt_boxes = np.array(dt_boxes_new)
  151. return dt_boxes
  152. def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
  153. img_height, img_width = image_shape[0:2]
  154. dt_boxes_new = []
  155. for box in dt_boxes:
  156. box = self.clip_det_res(box, img_height, img_width)
  157. dt_boxes_new.append(box)
  158. dt_boxes = np.array(dt_boxes_new)
  159. return dt_boxes
  160. def __call__(self, img):
  161. # cv2.imwrite("/data2/znj/format_conversion_maxcompute/ocr/temp_image/temp.jpg",img)
  162. ori_im = img.copy()
  163. data = {'image': img}
  164. data = transform(data, self.preprocess_op)
  165. img, shape_list = data
  166. if img is None:
  167. return None, 0
  168. img = np.expand_dims(img, axis=0)
  169. shape_list = np.expand_dims(shape_list, axis=0)
  170. img = img.copy()
  171. starttime = time.time()
  172. # self.input_tensor.copy_from_cpu(img)
  173. img = torch.from_numpy(img).float()
  174. img = img.to(self.device)
  175. try:
  176. # 加锁,防止太多大图片同时预测,爆显存
  177. if ori_im.shape[0] > 1024 and ori_im.shape[1] > 1024 and get_platform() != "Windows" and not MAX_COMPUTE:
  178. time2 = time.time()
  179. lock_file_sub = 'ocr'
  180. lock_file = os.path.abspath(os.path.dirname(__file__)) + "/" + lock_file_sub + ".lock"
  181. f = file_lock(lock_file)
  182. log("get file_lock " + lock_file_sub + " time " + str(time.time()-time2))
  183. with torch.no_grad():
  184. out = self.predictor(img)
  185. f.close()
  186. else:
  187. with torch.no_grad():
  188. out = self.predictor(img)
  189. except RuntimeError:
  190. log("ocr/tools/infer/predict_det.py predict.run error! maybe no gpu memory!")
  191. log("predictor shrink memory!")
  192. # self.predictor.clear_intermediate_tensor()
  193. # self.predictor.try_shrink_memory()
  194. if str(self.device)!='cpu':
  195. torch.cuda.empty_cache()
  196. gc.collect()
  197. raise RuntimeError
  198. # outputs = []
  199. # for output_tensor in self.output_tensors:
  200. # output = output_tensor.copy_to_cpu()
  201. # outputs.append(output)
  202. out = out.cpu().numpy()
  203. preds = {}
  204. preds['maps'] = out
  205. # if self.det_algorithm == "EAST":
  206. # preds['f_geo'] = outputs[0]
  207. # preds['f_score'] = outputs[1]
  208. # elif self.det_algorithm == 'SAST':
  209. # preds['f_border'] = outputs[0]
  210. # preds['f_score'] = outputs[1]
  211. # preds['f_tco'] = outputs[2]
  212. # preds['f_tvo'] = outputs[3]
  213. # elif self.det_algorithm == 'DB':
  214. # preds['maps'] = outputs[0]
  215. # else:
  216. # raise NotImplementedError
  217. post_result = self.postprocess_op(preds, shape_list)
  218. dt_boxes = post_result[0]['points']
  219. if self.det_algorithm == "SAST" and self.det_sast_polygon:
  220. dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
  221. else:
  222. dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
  223. elapse = time.time() - starttime
  224. # 释放内存
  225. # print("TextDetector", self.predictor)
  226. # if TextDetector.shrink_memory_count % 100 == 0:
  227. # print("TextDetector shrink memory")
  228. # self.predictor.clear_intermediate_tensor()
  229. # self.predictor.try_shrink_memory()
  230. # TextDetector.shrink_memory_count += 1
  231. if str(self.device) != 'cpu':
  232. torch.cuda.empty_cache()
  233. # gc.collect()
  234. return dt_boxes, elapse
  235. class TextDetector2(object):
  236. shrink_memory_count = 0
  237. def __init__(self, args):
  238. self.args = args
  239. self.det_algorithm = args.det_algorithm
  240. pre_process_list = [{
  241. 'DetResizeForTest': None
  242. }, {
  243. 'NormalizeImage': {
  244. 'std': [0.229, 0.224, 0.225],
  245. 'mean': [0.485, 0.456, 0.406],
  246. 'scale': '1./255.',
  247. 'order': 'hwc'
  248. }
  249. }, {
  250. 'ToCHWImage': None
  251. }, {
  252. 'KeepKeys': {
  253. 'keep_keys': ['image', 'shape']
  254. }
  255. }]
  256. postprocess_params = {}
  257. if self.det_algorithm == "DB":
  258. postprocess_params['name'] = 'DBPostProcess'
  259. postprocess_params["thresh"] = args.det_db_thresh
  260. postprocess_params["box_thresh"] = args.det_db_box_thresh
  261. postprocess_params["max_candidates"] = 1000
  262. postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
  263. postprocess_params["use_dilation"] = args.use_dilation
  264. else:
  265. logger.info("unknown det_algorithm:{}".format(self.det_algorithm))
  266. sys.exit(0)
  267. self.preprocess_op = create_operators(pre_process_list)
  268. self.postprocess_op = build_post_process(postprocess_params)
  269. def order_points_clockwise(self, pts):
  270. """
  271. reference from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py
  272. # sort the points based on their x-coordinates
  273. """
  274. xSorted = pts[np.argsort(pts[:, 0]), :]
  275. # grab the left-most and right-most points from the sorted
  276. # x-roodinate points
  277. leftMost = xSorted[:2, :]
  278. rightMost = xSorted[2:, :]
  279. # now, sort the left-most coordinates according to their
  280. # y-coordinates so we can grab the top-left and bottom-left
  281. # points, respectively
  282. leftMost = leftMost[np.argsort(leftMost[:, 1]), :]
  283. (tl, bl) = leftMost
  284. rightMost = rightMost[np.argsort(rightMost[:, 1]), :]
  285. (tr, br) = rightMost
  286. rect = np.array([tl, tr, br, bl], dtype="float32")
  287. return rect
  288. def clip_det_res(self, points, img_height, img_width):
  289. for pno in range(points.shape[0]):
  290. points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
  291. points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
  292. return points
  293. def filter_tag_det_res(self, dt_boxes, image_shape):
  294. img_height, img_width = image_shape[0:2]
  295. dt_boxes_new = []
  296. for box in dt_boxes:
  297. box = self.order_points_clockwise(box)
  298. box = self.clip_det_res(box, img_height, img_width)
  299. rect_width = int(np.linalg.norm(box[0] - box[1]))
  300. rect_height = int(np.linalg.norm(box[0] - box[3]))
  301. if rect_width <= 3 or rect_height <= 3:
  302. continue
  303. dt_boxes_new.append(box)
  304. dt_boxes = np.array(dt_boxes_new)
  305. return dt_boxes
  306. def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
  307. img_height, img_width = image_shape[0:2]
  308. dt_boxes_new = []
  309. for box in dt_boxes:
  310. box = self.clip_det_res(box, img_height, img_width)
  311. dt_boxes_new.append(box)
  312. dt_boxes = np.array(dt_boxes_new)
  313. return dt_boxes
  314. def __call__(self, img):
  315. from format_convert.convert_need_interface import from_gpu_interface_redis
  316. # 预处理
  317. ori_im = img.copy()
  318. data = {'image': img}
  319. data = transform(data, self.preprocess_op)
  320. img, shape_list = data
  321. if img is None:
  322. return None, 0
  323. img = np.expand_dims(img, axis=0)
  324. shape_list = np.expand_dims(shape_list, axis=0)
  325. img = img.copy()
  326. starttime = time.time()
  327. # # 压缩numpy
  328. # compressed_array = io.BytesIO()
  329. # np.savez_compressed(compressed_array, img)
  330. # compressed_array.seek(0)
  331. # img = compressed_array.read()
  332. # 调用GPU接口
  333. _dict = {"inputs": img, "args": str(namespace_to_dict(self.args)), "md5": _global.get("md5")}
  334. result = from_gpu_interface_redis(_dict, model_type="ocr", predictor_type="det")
  335. if judge_error_code(result):
  336. logging.error("from_gpu_interface failed! " + str(result))
  337. raise requests.exceptions.RequestException
  338. _preds = result.get("preds")
  339. gpu_time = result.get("gpu_time")
  340. # # 解压numpy
  341. # decompressed_array = io.BytesIO()
  342. # decompressed_array.write(_preds)
  343. # decompressed_array.seek(0)
  344. # _preds = np.load(decompressed_array, allow_pickle=True)['arr_0']
  345. # log("inputs.shape" + str(_preds.shape))
  346. # 后处理
  347. preds = {}
  348. if self.det_algorithm == 'DB':
  349. preds['maps'] = _preds
  350. else:
  351. raise NotImplementedError
  352. post_result = self.postprocess_op(preds, shape_list)
  353. dt_boxes = post_result[0]['points']
  354. dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
  355. elapse = time.time() - starttime
  356. log("ocr model predict time - det - time " + str(gpu_time))
  357. return dt_boxes, elapse
  358. if __name__ == "__main__":
  359. args = utility.parse_args()
  360. image_file_list = get_image_file_list(args.image_dir)
  361. text_detector = TextDetector(args)
  362. count = 0
  363. total_time = 0
  364. draw_img_save = "./inference_results"
  365. if not os.path.exists(draw_img_save):
  366. os.makedirs(draw_img_save)
  367. for image_file in image_file_list:
  368. img, flag = check_and_read_gif(image_file)
  369. if not flag:
  370. img = cv2.imread(image_file)
  371. if img is None:
  372. logger.info("error in loading image:{}".format(image_file))
  373. continue
  374. dt_boxes, elapse = text_detector(img)
  375. if count > 0:
  376. total_time += elapse
  377. count += 1
  378. logger.info("Predict time of {}: {}".format(image_file, elapse))
  379. src_im = utility.draw_text_det_res(dt_boxes, image_file)
  380. img_name_pure = os.path.split(image_file)[-1]
  381. img_path = os.path.join(draw_img_save,
  382. "det_res_{}".format(img_name_pure))
  383. cv2.imwrite(img_path, src_im)
  384. logger.info("The visualized image saved in {}".format(img_path))
  385. if count > 1:
  386. logger.info("Avg Time: {}".format(total_time / (count - 1)))