ocr_infer.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. from det_infer import DetInfer
  2. from rec_infer import RecInfer
  3. import argparse
  4. from line_profiler import LineProfiler
  5. from memory_profiler import profile
  6. from torchocr.utils.vis import draw_ocr_box_txt
  7. import numpy as np
  def get_rotate_crop_image(img, points):
      """Cut the quadrilateral ``points`` out of ``img`` and warp it to an upright rectangle.

      Args:
          img: image array; only ``cv2.warpPerspective`` reads it.
          points: four (x, y) corner points (anything ``np.array`` accepts that
              reshapes to (4, 2)). The edges points[0]->points[1] and
              points[3]->points[2] set the crop width, and points[0]->points[3]
              and points[1]->points[2] set the crop height.
              NOTE(review): this presumably means top-left, top-right,
              bottom-right, bottom-left order -- confirm against DetInfer output.

      Returns:
          The perspective-corrected crop. If the crop is at least 1.5 times
          taller than wide, it is rotated with ``np.rot90``, presumably so that
          vertical text reaches the recognizer lying horizontally.

      Raises:
          ValueError: if ``points`` does not describe exactly four 2-D points.
      """
      # Local import: at module level this script only imports cv2 under
      # ``__main__``, so callers importing OCRInfer would otherwise hit NameError.
      import cv2

      points = np.array(points, dtype=np.float32).reshape(-1, 2)
      if len(points) != 4:
          raise ValueError(
              "expected 4 corner points, got {}".format(len(points)))

      # The crop size is taken from the longer of each pair of opposite edges.
      # Clamp to >= 1 px: a degenerate box (duplicate or collinear corners)
      # would otherwise request a 0-sized output from cv2.warpPerspective and
      # make the aspect-ratio check below divide by zero.
      width = max(1, int(max(np.linalg.norm(points[0] - points[1]),
                             np.linalg.norm(points[2] - points[3]))))
      height = max(1, int(max(np.linalg.norm(points[0] - points[3]),
                              np.linalg.norm(points[1] - points[2]))))

      target = np.float32([[0, 0], [width, 0], [width, height], [0, height]])
      matrix = cv2.getPerspectiveTransform(points, target)
      crop = cv2.warpPerspective(
          img,
          matrix,
          (width, height),
          borderMode=cv2.BORDER_REPLICATE,
          flags=cv2.INTER_CUBIC,
      )

      crop_height, crop_width = crop.shape[0:2]
      if crop_height >= 1.5 * crop_width:
          crop = np.rot90(crop)
      return crop
  class OCRInfer(object):
      """Two-stage OCR: text detection with DetInfer, then recognition with RecInfer.

      ``predict(img)`` returns ``(box_list, score_list, debug_img)``. The
      recognised strings are only rendered into ``debug_img``; they are not
      returned separately.
      """

      def __init__(self, det_path, rec_path, rec_batch_size=16, time_profile=False,
                   mem_profile=False, **kwargs):
          """Load both models and choose which ``predict`` variant to expose.

          Args:
              det_path: path passed to ``DetInfer``.
              rec_path: path passed to ``RecInfer``.
              rec_batch_size: batch size passed to ``RecInfer``.
              time_profile: make ``predict`` run under line_profiler.
              mem_profile: make ``predict`` run under memory_profiler.
              **kwargs: ignored. This lets callers pass the full CLI dict
                  (including ``img_path``).

          Raises:
              ValueError: if both ``time_profile`` and ``mem_profile`` are set.
          """
          super().__init__()
          # Validate before constructing the models, so a bad flag combination
          # fails immediately. Raise instead of assert so it survives ``python -O``.
          if time_profile and mem_profile:
              raise ValueError("can not profile memory and time at the same time")
          self.det_model = DetInfer(det_path)
          self.rec_model = RecInfer(rec_path, rec_batch_size)
          # Kept for backward compatibility. predict_time_profile creates its own
          # LineProfiler per run and does not use this one.
          self.line_profiler = LineProfiler() if time_profile else None
          # Rebind ``predict`` on the instance, so callers always call
          # ``predict(img)`` whatever profiling mode was chosen.
          if time_profile:
              self.predict = self.predict_time_profile
          elif mem_profile:
              self.predict = self.predict_mem_profile

      def do_predict(self, img):
          """Detect text boxes in ``img``, recognise each crop, and draw the results.

          Returns:
              ``(box_list, score_list, debug_img)``. ``box_list`` and ``score_list``
              come from ``self.det_model.predict``; ``debug_img`` is rendered by
              ``draw_ocr_box_txt``. When no box is detected, returns
              ``([], [], img)`` with the input image object itself.
          """
          box_list, score_list = self.det_model.predict(img)
          if len(box_list) == 0:
              return [], [], img
          # draw_ocr_box_txt is given each box as a tuple of point tuples.
          draw_box_list = [tuple(map(tuple, box)) for box in box_list]
          crops = [get_rotate_crop_image(img, box) for box in box_list]
          rec_results = self.rec_model.predict(crops)
          # NOTE(review): assumes result[0][0] is the recognised string for each
          # crop (e.g. a list of (text, score) pairs) -- confirm against RecInfer.
          texts = [result[0][0] for result in rec_results]
          debug_img = draw_ocr_box_txt(img, draw_box_list, texts)
          return box_list, score_list, debug_img

      def predict(self, img):
          """Run the pipeline once without profiling (same result as ``do_predict``)."""
          return self.do_predict(img)

      def predict_mem_profile(self, img):
          """Run ``do_predict`` once wrapped by memory_profiler's ``profile``."""
          wrapper = profile(self.do_predict)
          return wrapper(img)

      def predict_time_profile(self, img, runs=8):
          """Run ``do_predict`` ``runs`` times and print line_profiler stats after each run.

          A fresh LineProfiler is used for every run, so each report covers a
          single call. Repeated runs let warm-up cost be told apart from
          steady-state cost.

          Args:
              img: image forwarded to ``do_predict``.
              runs: number of profiled runs (default 8, as before).

          Returns:
              The result of the last run.

          Raises:
              ValueError: if ``runs`` < 1, since there would be no result to return.
          """
          if runs < 1:
              raise ValueError("runs must be >= 1, got {}".format(runs))
          ret = None
          for i in range(runs):
              print("*********** {} profile time *************".format(i))
              lp = LineProfiler()
              lp_wrapper = lp(self.do_predict)
              ret = lp_wrapper(img)
              lp.print_stats()
          return ret
  def init_args(argv=None):
      """Parse the command line for the OCR demo.

      Args:
          argv: list of argument strings to parse. ``None`` (the default) makes
              argparse read ``sys.argv[1:]``, exactly as before.

      Returns:
          dict mapping option names to parsed values, suitable for
          ``OCRInfer(**args)``.
      """
      parser = argparse.ArgumentParser(description='OCR infer')
      parser.add_argument('--det_path', required=True, type=str, help='det model path')
      parser.add_argument('--rec_path', required=True, type=str, help='rec model path')
      parser.add_argument('--img_path', required=True, type=str, help='img path for predict')
      parser.add_argument('--rec_batch_size', type=int, help='rec batch_size', default=16)
      # OCRInfer refuses to run both profilers at once. Reject the combination
      # here as an ordinary usage error instead. The single-dash spellings are
      # kept for backward compatibility; the double-dash forms are conventional.
      profile_group = parser.add_mutually_exclusive_group()
      profile_group.add_argument('-time_profile', '--time_profile', dest='time_profile',
                                 action='store_true', help='enable time profile mode')
      profile_group.add_argument('-mem_profile', '--mem_profile', dest='mem_profile',
                                 action='store_true', help='enable memory profile mode')
      return vars(parser.parse_args(argv))
  if __name__ == '__main__':
      import cv2

      args = init_args()
      # cv2.imread returns None instead of raising when it cannot read the file.
      # Stop here with a clear message rather than failing inside detection.
      img = cv2.imread(args['img_path'])
      if img is None:
          raise SystemExit("could not read image: {}".format(args['img_path']))
      model = OCRInfer(**args)
      # predict returns (box_list, score_list, debug_img); only the image is shown.
      _box_list, _score_list, debug_img = model.predict(img)
      h, w = debug_img.shape[:2]
      # Images wider than 1200 px are shrunk to 600 px wide for display.
      ratio = 600.0 / w if w > 1200 else 1
      debug_img = cv2.resize(debug_img, (int(w * ratio), int(h * ratio)))
      # Profiling runs skip the interactive window.
      if not (args['mem_profile'] or args['time_profile']):
          cv2.imshow("debug", debug_img)
          cv2.waitKey()