# paddleocr.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys

__dir__ = os.path.dirname(__file__)
sys.path.append(os.path.join(__dir__, ''))
project_path = os.path.abspath(__dir__)
# project_path = ""

import cv2
import numpy as np
from pathlib import Path
import tarfile
import requests
from tqdm import tqdm

os.environ['FLAGS_eager_delete_tensor_gb'] = '0'

from ocr.tools.infer import predict_system
from ocr.ppocr.utils.logging import get_logger

logger = get_logger()
from ocr.ppocr.utils.utility import check_and_read_gif, get_image_file_list

__all__ = ['PaddleOCR']
model_urls = {
    'det':
    'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
    'rec': {
        'ch': {
            'url':
            'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar',
            'dict_path': './ppocr/utils/ppocr_keys_v1.txt'
        },
        'en': {
            'url':
            'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar',
            'dict_path': './ppocr/utils/dict/en_dict.txt'
        },
        'french': {
            'url':
            'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar',
            'dict_path': './ppocr/utils/dict/french_dict.txt'
        },
        'german': {
            'url':
            'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar',
            'dict_path': './ppocr/utils/dict/german_dict.txt'
        },
        'korean': {
            'url':
            'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar',
            'dict_path': './ppocr/utils/dict/korean_dict.txt'
        },
        'japan': {
            'url':
            'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar',
            'dict_path': './ppocr/utils/dict/japan_dict.txt'
        }
    },
    'cls':
    'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar'
}

SUPPORT_DET_MODEL = ['DB']
VERSION = 2.0
SUPPORT_REC_MODEL = ['CRNN']
# BASE_DIR = os.path.expanduser("~/.paddleocr/")
BASE_DIR = project_path + "/model/"


def download_with_progressbar(url, save_path):
    response = requests.get(url, stream=True)
    total_size_in_bytes = int(response.headers.get('content-length', 0))
    block_size = 1024  # 1 Kibibyte
    progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
    with open(save_path, 'wb') as file:
        for data in response.iter_content(block_size):
            progress_bar.update(len(data))
            file.write(data)
    progress_bar.close()
    if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes:
        logger.error("Something went wrong while downloading models")
        sys.exit(0)


def maybe_download(model_storage_directory, url):
    # using custom model
    tar_file_name_list = [
        'inference.pdiparams', 'inference.pdiparams.info', 'inference.pdmodel'
    ]
    if not os.path.exists(
            os.path.join(model_storage_directory, 'inference.pdiparams')
    ) or not os.path.exists(
            os.path.join(model_storage_directory, 'inference.pdmodel')):
        tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
        print('download {} to {}'.format(url, tmp_path))
        os.makedirs(model_storage_directory, exist_ok=True)
        download_with_progressbar(url, tmp_path)
        with tarfile.open(tmp_path, 'r') as tarObj:
            for member in tarObj.getmembers():
                filename = None
                for tar_file_name in tar_file_name_list:
                    if tar_file_name in member.name:
                        filename = tar_file_name
                if filename is None:
                    continue
                file = tarObj.extractfile(member)
                with open(
                        os.path.join(model_storage_directory, filename),
                        'wb') as f:
                    f.write(file.read())
        os.remove(tmp_path)


def parse_args(mMain=True, add_help=True):
    import argparse

    def str2bool(v):
        return v.lower() in ("true", "t", "1")

    if mMain:
        parser = argparse.ArgumentParser(add_help=add_help)
        # params for prediction engine
        parser.add_argument("--use_gpu", type=str2bool, default=True)
        parser.add_argument("--ir_optim", type=str2bool, default=True)
        parser.add_argument("--use_tensorrt", type=str2bool, default=False)
        parser.add_argument("--gpu_mem", type=int, default=8000)
        # params for text detector
        parser.add_argument("--image_dir", type=str)
        parser.add_argument("--det_algorithm", type=str, default='DB')
        parser.add_argument("--det_model_dir", type=str, default=None)
        parser.add_argument("--det_limit_side_len", type=float, default=960)
        parser.add_argument("--det_limit_type", type=str, default='max')
        # DB params
        parser.add_argument("--det_db_thresh", type=float, default=0.1)
        parser.add_argument("--det_db_box_thresh", type=float, default=0.1)
        parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6)
        parser.add_argument("--use_dilation", type=bool, default=False)
        # EAST params
        parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
        parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
        parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)
        # params for text recognizer
        parser.add_argument("--rec_algorithm", type=str, default='CRNN')
        parser.add_argument("--rec_model_dir", type=str, default=None)
        parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
        parser.add_argument("--rec_char_type", type=str, default='ch')
        parser.add_argument("--rec_batch_num", type=int, default=30)
        parser.add_argument("--max_text_length", type=int, default=25)
        parser.add_argument("--rec_char_dict_path", type=str, default=None)
        parser.add_argument("--use_space_char", type=bool, default=True)
        parser.add_argument("--drop_score", type=float, default=0.5)
        # params for text classifier
        parser.add_argument("--cls_model_dir", type=str, default=None)
        parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
        parser.add_argument("--label_list", type=list, default=['0', '180'])
        parser.add_argument("--cls_batch_num", type=int, default=30)
        parser.add_argument("--cls_thresh", type=float, default=0.9)
        parser.add_argument("--enable_mkldnn", type=bool, default=False)
        parser.add_argument("--use_zero_copy_run", type=bool, default=False)
        parser.add_argument("--use_pdserving", type=str2bool, default=False)
        parser.add_argument("--lang", type=str, default='ch')
        parser.add_argument("--det", type=str2bool, default=True)
        parser.add_argument("--rec", type=str2bool, default=True)
        parser.add_argument("--use_angle_cls", type=str2bool, default=False)
        return parser.parse_args()
    else:
        return argparse.Namespace(
            use_gpu=False,
            ir_optim=True,
            use_tensorrt=False,
            gpu_mem=8000,
            image_dir='',
            det_algorithm='DB',
            det_model_dir=None,
            det_limit_side_len=1280,
            det_limit_type='max',
            det_db_thresh=0.1,
            # det_db_box_thresh: lower this if text lines are being missed
            det_db_box_thresh=0.1,
            # det_db_unclip_ratio: how tightly the detection box fits the text
            det_db_unclip_ratio=2.5,
            # whether to apply dilation to the text region
            use_dilation=False,
            det_east_score_thresh=0.8,
            det_east_cover_thresh=0.1,
            det_east_nms_thresh=0.2,
            rec_algorithm='CRNN',
            rec_model_dir=None,
            rec_image_shape="3, 32, 1000",
            rec_char_type='ch',
            rec_batch_num=30,
            max_text_length=128,
            rec_char_dict_path='ocr/ppocr/utils/ppocr_keys_v1.txt',
            use_space_char=True,
            drop_score=0.5,
            cls_model_dir=None,
            cls_image_shape="3, 32, 1000",
            label_list=['0', '180'],
            cls_batch_num=30,
            cls_thresh=0.9,
            enable_mkldnn=False,
            use_zero_copy_run=True,
            use_pdserving=False,
            lang='ch',
            det=True,
            rec=True,
            use_angle_cls=False)
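

# Note: the argparse.Namespace returned by the mMain=False branch above supplies the
# library defaults used by PaddleOCR(); any of these fields can be overridden through
# keyword arguments, since __init__ applies postprocess_params.__dict__.update(**kwargs).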


class PaddleOCR(predict_system.TextSystem):
    def __init__(self, **kwargs):
        """
        paddleocr package
        args:
            **kwargs: other params shown in `paddleocr --help`
        """
        postprocess_params = parse_args(mMain=False, add_help=False)
        postprocess_params.__dict__.update(**kwargs)
        self.use_angle_cls = postprocess_params.use_angle_cls
        lang = postprocess_params.lang
        assert lang in model_urls[
            'rec'], 'param lang must be in {}, but got {}'.format(
                model_urls['rec'].keys(), lang)
        if postprocess_params.rec_char_dict_path is None:
            postprocess_params.rec_char_dict_path = model_urls['rec'][lang][
                'dict_path']
        # init model dir
        if postprocess_params.det_model_dir is None:
            postprocess_params.det_model_dir = os.path.join(
                BASE_DIR, '{}/det'.format(VERSION))
        if postprocess_params.rec_model_dir is None:
            postprocess_params.rec_model_dir = os.path.join(
                BASE_DIR, '{}/rec/{}'.format(VERSION, lang))
        if postprocess_params.cls_model_dir is None:
            postprocess_params.cls_model_dir = os.path.join(
                BASE_DIR, '{}/cls'.format(VERSION))
        print(postprocess_params)
        # download model
        maybe_download(postprocess_params.det_model_dir, model_urls['det'])
        maybe_download(postprocess_params.rec_model_dir,
                       model_urls['rec'][lang]['url'])
        maybe_download(postprocess_params.cls_model_dir, model_urls['cls'])
        if postprocess_params.det_algorithm not in SUPPORT_DET_MODEL:
            logger.error('det_algorithm must be in {}'.format(SUPPORT_DET_MODEL))
            sys.exit(0)
        if postprocess_params.rec_algorithm not in SUPPORT_REC_MODEL:
            logger.error('rec_algorithm must be in {}'.format(SUPPORT_REC_MODEL))
            sys.exit(0)
        postprocess_params.rec_char_dict_path = str(
            Path(__file__).parent.parent / postprocess_params.rec_char_dict_path)
        # init det_model and rec_model
        super().__init__(postprocess_params)
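
    # Expected model directory layout with the defaults above (BASE_DIR points at
    # <this package>/model/ and VERSION is 2.0); maybe_download() fills each folder
    # with the inference.pdmodel / inference.pdiparams files extracted from the tar:
    #     model/2.0/det/
    #     model/2.0/rec/<lang>/
    #     model/2.0/cls/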

    def ocr(self, img, det=True, rec=True, cls=False):
        """
        ocr with paddleocr
        args:
            img: image for ocr; supports an ndarray, an image path/URL, or a list of ndarrays
            det: use text detection or not; if False, only rec will be executed. default is True
            rec: use text recognition or not; if False, only det will be executed. default is True
            cls: use angle classification or not. default is False
        """
        print(det, rec, cls)
        assert isinstance(img, (np.ndarray, list, str))
        if isinstance(img, list) and det == True:
            logger.error('When the input is a list of images, det must be False')
            exit(0)
        self.use_angle_cls = cls
        if isinstance(img, str):
            # download net image
            if img.startswith('http'):
                download_with_progressbar(img, 'tmp.jpg')
                img = 'tmp.jpg'
            image_file = img
            img, flag = check_and_read_gif(image_file)
            if not flag:
                with open(image_file, 'rb') as f:
                    np_arr = np.frombuffer(f.read(), dtype=np.uint8)
                    img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
            if img is None:
                logger.error("error in loading image: {}".format(image_file))
                return None
        if isinstance(img, np.ndarray) and len(img.shape) == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        if det and rec:
            dt_boxes, rec_res = self.__call__(img)
            print("paddleocr.py dt_boxes", len(dt_boxes))
            print("paddleocr.py rec_res", len(rec_res))
            return [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
        elif det and not rec:
            dt_boxes, elapse = self.text_detector(img)
            if dt_boxes is None:
                return None
            return [box.tolist() for box in dt_boxes]
        else:
            if not isinstance(img, list):
                img = [img]
            if self.use_angle_cls:
                img, cls_res, elapse = self.text_classifier(img)
                if not rec:
                    return cls_res
            rec_res, elapse = self.text_recognizer(img)
            return rec_res
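

# Return value of PaddleOCR.ocr() by mode (a summary of the branches above; each
# recognition result is assumed to be a (text, score) pair, following the recognizer output):
#     det=True,  rec=True   -> [[box_points, (text, score)], ...]
#     det=True,  rec=False  -> [box_points, ...]
#     det=False, rec=True   -> [(text, score), ...]
#     det=False, rec=False, cls=True -> angle classification results only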


def main(mMain=True):
    # for cmd
    args = parse_args(mMain)
    # args = parse_args(mMain=True)
    # is the image a URL or a local path?
    image_dir = args.image_dir
    if image_dir.startswith('http'):
        download_with_progressbar(image_dir, 'tmp.jpg')
        image_file_list = ['tmp.jpg']
    else:
        image_file_list = get_image_file_list(args.image_dir)
    if len(image_file_list) == 0:
        logger.error('no images found in {}'.format(args.image_dir))
        return
    ocr_engine = PaddleOCR(**(args.__dict__))
    for img_path in image_file_list:
        logger.info('{}{}{}'.format('*' * 10, img_path, '*' * 10))
        result = ocr_engine.ocr(img_path,
                                det=args.det,
                                rec=args.rec,
                                cls=args.use_angle_cls)
        if result is not None:
            for line in result:
                logger.info(line)


if __name__ == '__main__':
    main(False)
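

# Usage sketch (illustrative only; assumes this module is importable as `paddleocr`
# and that './test.jpg' exists -- both are placeholders, not part of this repo):
#
#     from paddleocr import PaddleOCR
#     engine = PaddleOCR(lang='ch', use_angle_cls=False)  # downloads models on first use
#     result = engine.ocr('./test.jpg', det=True, rec=True)
#     for box, (text, score) in result:
#         print(box, text, score)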