paddleocr.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import sys
  16. __dir__ = os.path.dirname(__file__)
  17. sys.path.append(os.path.join(__dir__, ''))
  18. import cv2
  19. import numpy as np
  20. from pathlib import Path
  21. import tarfile
  22. import requests
  23. from tqdm import tqdm
  24. from tools.infer import predict_system
  25. from ppocr.utils.logging import get_logger
  26. logger = get_logger()
  27. from ppocr.utils.utility import check_and_read_gif, get_image_file_list
  28. __all__ = ['PaddleOCR']
  29. model_urls = {
  30. 'det':
  31. 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
  32. 'rec': {
  33. 'ch': {
  34. 'url':
  35. 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar',
  36. 'dict_path': './ppocr/utils/ppocr_keys_v1.txt'
  37. },
  38. 'en': {
  39. 'url':
  40. 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar',
  41. 'dict_path': './ppocr/utils/dict/en_dict.txt'
  42. },
  43. 'french': {
  44. 'url':
  45. 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar',
  46. 'dict_path': './ppocr/utils/dict/french_dict.txt'
  47. },
  48. 'german': {
  49. 'url':
  50. 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar',
  51. 'dict_path': './ppocr/utils/dict/german_dict.txt'
  52. },
  53. 'korean': {
  54. 'url':
  55. 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar',
  56. 'dict_path': './ppocr/utils/dict/korean_dict.txt'
  57. },
  58. 'japan': {
  59. 'url':
  60. 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar',
  61. 'dict_path': './ppocr/utils/dict/japan_dict.txt'
  62. }
  63. },
  64. 'cls':
  65. 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar'
  66. }
  67. SUPPORT_DET_MODEL = ['DB']
  68. VERSION = 2.0
  69. SUPPORT_REC_MODEL = ['CRNN']
  70. BASE_DIR = os.path.expanduser("~/.paddleocr/")
  71. def download_with_progressbar(url, save_path):
  72. response = requests.get(url, stream=True)
  73. total_size_in_bytes = int(response.headers.get('content-length', 0))
  74. block_size = 1024 # 1 Kibibyte
  75. progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
  76. with open(save_path, 'wb') as file:
  77. for data in response.iter_content(block_size):
  78. progress_bar.update(len(data))
  79. file.write(data)
  80. progress_bar.close()
  81. if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes:
  82. logger.error("Something went wrong while downloading models")
  83. sys.exit(0)
  84. def maybe_download(model_storage_directory, url):
  85. # using custom model
  86. tar_file_name_list = [
  87. 'inference.pdiparams', 'inference.pdiparams.info', 'inference.pdmodel'
  88. ]
  89. if not os.path.exists(
  90. os.path.join(model_storage_directory, 'inference.pdiparams')
  91. ) or not os.path.exists(
  92. os.path.join(model_storage_directory, 'inference.pdmodel')):
  93. tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
  94. print('download {} to {}'.format(url, tmp_path))
  95. os.makedirs(model_storage_directory, exist_ok=True)
  96. download_with_progressbar(url, tmp_path)
  97. with tarfile.open(tmp_path, 'r') as tarObj:
  98. for member in tarObj.getmembers():
  99. filename = None
  100. for tar_file_name in tar_file_name_list:
  101. if tar_file_name in member.name:
  102. filename = tar_file_name
  103. if filename is None:
  104. continue
  105. file = tarObj.extractfile(member)
  106. with open(
  107. os.path.join(model_storage_directory, filename),
  108. 'wb') as f:
  109. f.write(file.read())
  110. os.remove(tmp_path)
  111. def parse_args(mMain=True, add_help=True):
  112. import argparse
  113. def str2bool(v):
  114. return v.lower() in ("true", "t", "1")
  115. if mMain:
  116. parser = argparse.ArgumentParser(add_help=add_help)
  117. # params for prediction engine
  118. parser.add_argument("--use_gpu", type=str2bool, default=True)
  119. parser.add_argument("--ir_optim", type=str2bool, default=True)
  120. parser.add_argument("--use_tensorrt", type=str2bool, default=False)
  121. parser.add_argument("--gpu_mem", type=int, default=8000)
  122. # params for text detector
  123. parser.add_argument("--image_dir", type=str)
  124. parser.add_argument("--det_algorithm", type=str, default='DB')
  125. parser.add_argument("--det_model_dir", type=str, default=None)
  126. parser.add_argument("--det_limit_side_len", type=float, default=960)
  127. parser.add_argument("--det_limit_type", type=str, default='max')
  128. # DB parmas
  129. parser.add_argument("--det_db_thresh", type=float, default=0.3)
  130. parser.add_argument("--det_db_box_thresh", type=float, default=0.5)
  131. parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6)
  132. parser.add_argument("--use_dilation", type=bool, default=False)
  133. # EAST parmas
  134. parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
  135. parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
  136. parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)
  137. # params for text recognizer
  138. parser.add_argument("--rec_algorithm", type=str, default='CRNN')
  139. parser.add_argument("--rec_model_dir", type=str, default=None)
  140. parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
  141. parser.add_argument("--rec_char_type", type=str, default='ch')
  142. parser.add_argument("--rec_batch_num", type=int, default=30)
  143. parser.add_argument("--max_text_length", type=int, default=25)
  144. parser.add_argument("--rec_char_dict_path", type=str, default=None)
  145. parser.add_argument("--use_space_char", type=bool, default=True)
  146. parser.add_argument("--drop_score", type=float, default=0.5)
  147. # params for text classifier
  148. parser.add_argument("--cls_model_dir", type=str, default=None)
  149. parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
  150. parser.add_argument("--label_list", type=list, default=['0', '180'])
  151. parser.add_argument("--cls_batch_num", type=int, default=30)
  152. parser.add_argument("--cls_thresh", type=float, default=0.9)
  153. parser.add_argument("--enable_mkldnn", type=bool, default=False)
  154. parser.add_argument("--use_zero_copy_run", type=bool, default=False)
  155. parser.add_argument("--use_pdserving", type=str2bool, default=False)
  156. parser.add_argument("--lang", type=str, default='ch')
  157. parser.add_argument("--det", type=str2bool, default=True)
  158. parser.add_argument("--rec", type=str2bool, default=True)
  159. parser.add_argument("--use_angle_cls", type=str2bool, default=False)
  160. return parser.parse_args()
  161. else:
  162. return argparse.Namespace(
  163. use_gpu=False,
  164. ir_optim=True,
  165. use_tensorrt=False,
  166. gpu_mem=8000,
  167. image_dir='D:\\Project\\PaddleOCR-release-2.0\\train_data\\bidi_data\\orgs_data\\train\\text_0.jpg',
  168. det_algorithm='DB',
  169. det_model_dir='D:\\Project\\PaddleOCR-release-2.0\\pretrained_model\\ch_ppocr_mobile_v2.0_det_infer',
  170. det_limit_side_len=960,
  171. det_limit_type='max',
  172. det_db_thresh=0.3,
  173. det_db_box_thresh=0.5,
  174. det_db_unclip_ratio=1.6,
  175. use_dilation=False,
  176. det_east_score_thresh=0.8,
  177. det_east_cover_thresh=0.1,
  178. det_east_nms_thresh=0.2,
  179. rec_algorithm='CRNN',
  180. rec_model_dir='D:\\Project\\PaddleOCR-release-2.0\\pretrained_model\\ch_ppocr_mobile_v2.0_rec_infer',
  181. rec_image_shape="3, 32, 320",
  182. rec_char_type='ch',
  183. rec_batch_num=30,
  184. max_text_length=100,
  185. rec_char_dict_path='D:\\Project\\PaddleOCR-release-2.0\\ppocr\\utils\\ppocr_keys_v1.txt',
  186. use_space_char=True,
  187. drop_score=0.5,
  188. cls_model_dir='D:\\Project\\PaddleOCR-release-2.0\\pretrained_model\\ch_ppocr_mobile_v2.0_cls_infer',
  189. cls_image_shape="3, 48, 192",
  190. label_list=['0', '180'],
  191. cls_batch_num=30,
  192. cls_thresh=0.9,
  193. enable_mkldnn=False,
  194. use_zero_copy_run=False,
  195. use_pdserving=False,
  196. lang='ch',
  197. det=False,
  198. rec=True,
  199. use_angle_cls=False)
  200. class PaddleOCR(predict_system.TextSystem):
  201. def __init__(self, **kwargs):
  202. """
  203. paddleocr package
  204. args:
  205. **kwargs: other params show in paddleocr --help
  206. """
  207. postprocess_params = parse_args(mMain=False, add_help=False)
  208. postprocess_params.__dict__.update(**kwargs)
  209. self.use_angle_cls = postprocess_params.use_angle_cls
  210. lang = postprocess_params.lang
  211. assert lang in model_urls[
  212. 'rec'], 'param lang must in {}, but got {}'.format(
  213. model_urls['rec'].keys(), lang)
  214. if postprocess_params.rec_char_dict_path is None:
  215. postprocess_params.rec_char_dict_path = model_urls['rec'][lang][
  216. 'dict_path']
  217. # init model dir
  218. if postprocess_params.det_model_dir is None:
  219. postprocess_params.det_model_dir = os.path.join(
  220. BASE_DIR, '{}/det'.format(VERSION))
  221. if postprocess_params.rec_model_dir is None:
  222. postprocess_params.rec_model_dir = os.path.join(
  223. BASE_DIR, '{}/rec/{}'.format(VERSION, lang))
  224. if postprocess_params.cls_model_dir is None:
  225. postprocess_params.cls_model_dir = os.path.join(
  226. BASE_DIR, '{}/cls'.format(VERSION))
  227. print(postprocess_params)
  228. # download model
  229. maybe_download(postprocess_params.det_model_dir, model_urls['det'])
  230. maybe_download(postprocess_params.rec_model_dir,
  231. model_urls['rec'][lang]['url'])
  232. maybe_download(postprocess_params.cls_model_dir, model_urls['cls'])
  233. if postprocess_params.det_algorithm not in SUPPORT_DET_MODEL:
  234. logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL))
  235. sys.exit(0)
  236. if postprocess_params.rec_algorithm not in SUPPORT_REC_MODEL:
  237. logger.error('rec_algorithm must in {}'.format(SUPPORT_REC_MODEL))
  238. sys.exit(0)
  239. postprocess_params.rec_char_dict_path = str(
  240. Path(__file__).parent / postprocess_params.rec_char_dict_path)
  241. # init det_model and rec_model
  242. super().__init__(postprocess_params)
  243. def ocr(self, img, det=True, rec=True, cls=False):
  244. """
  245. ocr with paddleocr
  246. args:
  247. img: img for ocr, support ndarray, img_path and list or ndarray
  248. det: use text detection or not, if false, only rec will be exec. default is True
  249. rec: use text recognition or not, if false, only det will be exec. default is True
  250. """
  251. print(det, rec, cls)
  252. assert isinstance(img, (np.ndarray, list, str))
  253. if isinstance(img, list) and det == True:
  254. logger.error('When input a list of images, det must be false')
  255. exit(0)
  256. self.use_angle_cls = cls
  257. if isinstance(img, str):
  258. # download net image
  259. if img.startswith('http'):
  260. download_with_progressbar(img, 'tmp.jpg')
  261. img = 'tmp.jpg'
  262. image_file = img
  263. img, flag = check_and_read_gif(image_file)
  264. if not flag:
  265. with open(image_file, 'rb') as f:
  266. np_arr = np.frombuffer(f.read(), dtype=np.uint8)
  267. img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
  268. if img is None:
  269. logger.error("error in loading image:{}".format(image_file))
  270. return None
  271. if isinstance(img, np.ndarray) and len(img.shape) == 2:
  272. img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
  273. if det and rec:
  274. dt_boxes, rec_res = self.__call__(img)
  275. return [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
  276. elif det and not rec:
  277. dt_boxes, elapse = self.text_detector(img)
  278. if dt_boxes is None:
  279. return None
  280. return [box.tolist() for box in dt_boxes]
  281. else:
  282. if not isinstance(img, list):
  283. img = [img]
  284. if self.use_angle_cls:
  285. img, cls_res, elapse = self.text_classifier(img)
  286. if not rec:
  287. return cls_res
  288. rec_res, elapse = self.text_recognizer(img)
  289. return rec_res
  290. def main(mMain=True):
  291. # for cmd
  292. args = parse_args(mMain)
  293. # args = parse_args(mMain=True)
  294. # 图片是网络的还是本地路径
  295. image_dir = args.image_dir
  296. if image_dir.startswith('http'):
  297. download_with_progressbar(image_dir, 'tmp.jpg')
  298. image_file_list = ['tmp.jpg']
  299. else:
  300. image_file_list = get_image_file_list(args.image_dir)
  301. if len(image_file_list) == 0:
  302. logger.error('no images find in {}'.format(args.image_dir))
  303. return
  304. ocr_engine = PaddleOCR(**(args.__dict__))
  305. for img_path in image_file_list:
  306. logger.info('{}{}{}'.format('*' * 10, img_path, '*' * 10))
  307. result = ocr_engine.ocr(img_path,
  308. det=args.det,
  309. rec=args.rec,
  310. cls=args.use_angle_cls)
  311. if result is not None:
  312. for line in result:
  313. logger.info(line)
if __name__ == '__main__':
    # NOTE(review): main(False) does not read command-line flags; it runs with
    # the hard-coded configuration returned by parse_args(mMain=False).
    # Use main() to take options from sys.argv instead.
    main(False)