paddleocr.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys

__dir__ = os.path.dirname(__file__)
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
project_path = os.path.abspath(__dir__)
# project_path = ""

import cv2
import numpy as np
from pathlib import Path
import tarfile
import requests
from tqdm import tqdm

os.environ['FLAGS_eager_delete_tensor_gb'] = '0'

from ocr.tools.infer import predict_system
from ocr.ppocr.utils.logging import get_logger

logger = get_logger()

from ocr.ppocr.utils.utility import check_and_read_gif, get_image_file_list

__all__ = ['PaddleOCR']
model_urls = {
    'det':
    'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
    'rec': {
        'ch': {
            'url':
            'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar',
            'dict_path': './ppocr/utils/ppocr_keys_v1.txt'
        },
        'en': {
            'url':
            'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar',
            'dict_path': './ppocr/utils/dict/en_dict.txt'
        },
        'french': {
            'url':
            'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar',
            'dict_path': './ppocr/utils/dict/french_dict.txt'
        },
        'german': {
            'url':
            'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar',
            'dict_path': './ppocr/utils/dict/german_dict.txt'
        },
        'korean': {
            'url':
            'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar',
            'dict_path': './ppocr/utils/dict/korean_dict.txt'
        },
        'japan': {
            'url':
            'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar',
            'dict_path': './ppocr/utils/dict/japan_dict.txt'
        }
    },
    'cls':
    'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar'
}

SUPPORT_DET_MODEL = ['DB']
VERSION = 2.0
SUPPORT_REC_MODEL = ['CRNN']
# BASE_DIR = os.path.expanduser("~/.paddleocr/")
BASE_DIR = project_path + "/model/"


def download_with_progressbar(url, save_path):
    response = requests.get(url, stream=True)
    total_size_in_bytes = int(response.headers.get('content-length', 0))
    block_size = 1024  # 1 Kibibyte
    progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
    with open(save_path, 'wb') as file:
        for data in response.iter_content(block_size):
            progress_bar.update(len(data))
            file.write(data)
    progress_bar.close()
    if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes:
        logger.error("Something went wrong while downloading models")
        sys.exit(0)


def maybe_download(model_storage_directory, url):
    # using custom model
    tar_file_name_list = [
        'inference.pdiparams', 'inference.pdiparams.info', 'inference.pdmodel'
    ]
    if not os.path.exists(
            os.path.join(model_storage_directory, 'inference.pdiparams')
    ) or not os.path.exists(
            os.path.join(model_storage_directory, 'inference.pdmodel')):
        tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
        print('download {} to {}'.format(url, tmp_path))
        os.makedirs(model_storage_directory, exist_ok=True)
        download_with_progressbar(url, tmp_path)
        with tarfile.open(tmp_path, 'r') as tarObj:
            for member in tarObj.getmembers():
                filename = None
                for tar_file_name in tar_file_name_list:
                    if tar_file_name in member.name:
                        filename = tar_file_name
                if filename is None:
                    continue
                file = tarObj.extractfile(member)
                with open(
                        os.path.join(model_storage_directory, filename),
                        'wb') as f:
                    f.write(file.read())
        os.remove(tmp_path)


def parse_args(mMain=True, add_help=True):
    import argparse

    def str2bool(v):
        return v.lower() in ("true", "t", "1")

    if mMain:
        parser = argparse.ArgumentParser(add_help=add_help)
        # params for prediction engine
        parser.add_argument("--use_gpu", type=str2bool, default=True)
        parser.add_argument("--ir_optim", type=str2bool, default=True)
        parser.add_argument("--use_tensorrt", type=str2bool, default=False)
        parser.add_argument("--gpu_mem", type=int, default=8000)
        # params for text detector
        parser.add_argument("--image_dir", type=str)
        parser.add_argument("--det_algorithm", type=str, default='DB')
        parser.add_argument("--det_model_dir", type=str, default=None)
        parser.add_argument("--det_limit_side_len", type=float, default=960)
        parser.add_argument("--det_limit_type", type=str, default='max')
        # DB params
        parser.add_argument("--det_db_thresh", type=float, default=0.1)
        parser.add_argument("--det_db_box_thresh", type=float, default=0.1)
        parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6)
        # use str2bool for boolean flags: with type=bool, argparse would treat
        # any non-empty string (including "False") as True
        parser.add_argument("--use_dilation", type=str2bool, default=False)
        # EAST params
        parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
        parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
        parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)
        # params for text recognizer
        parser.add_argument("--rec_algorithm", type=str, default='CRNN')
        parser.add_argument("--rec_model_dir", type=str, default=None)
        parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
        parser.add_argument("--rec_char_type", type=str, default='ch')
        parser.add_argument("--rec_batch_num", type=int, default=30)
        parser.add_argument("--max_text_length", type=int, default=25)
        parser.add_argument("--rec_char_dict_path", type=str, default=None)
        parser.add_argument("--use_space_char", type=str2bool, default=True)
        parser.add_argument("--drop_score", type=float, default=0.5)
        # params for text classifier
        parser.add_argument("--cls_model_dir", type=str, default=None)
        parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
        parser.add_argument("--label_list", type=list, default=['0', '180'])
        parser.add_argument("--cls_batch_num", type=int, default=30)
        parser.add_argument("--cls_thresh", type=float, default=0.9)
        parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
        parser.add_argument("--use_zero_copy_run", type=str2bool, default=False)
        parser.add_argument("--use_pdserving", type=str2bool, default=False)
        parser.add_argument("--lang", type=str, default='ch')
        parser.add_argument("--det", type=str2bool, default=True)
        parser.add_argument("--rec", type=str2bool, default=True)
        parser.add_argument("--use_angle_cls", type=str2bool, default=False)
        return parser.parse_args()
    else:
        return argparse.Namespace(
            use_gpu=True,
            ir_optim=True,
            use_tensorrt=False,
            gpu_mem=8000,
            image_dir='',
            det_algorithm='DB',
            det_model_dir=None,
            det_limit_side_len=1280,
            det_limit_type='max',
            det_db_thresh=0.1,
            # lower det_db_box_thresh if text lines are being missed
            det_db_box_thresh=0.1,
            # det_db_unclip_ratio controls how tightly boxes wrap the detected text
            det_db_unclip_ratio=2.5,
            # dilate the detected text regions
            use_dilation=False,
            det_east_score_thresh=0.8,
            det_east_cover_thresh=0.1,
            det_east_nms_thresh=0.2,
            rec_algorithm='CRNN',
            rec_model_dir=None,
            rec_image_shape="3, 32, 1000",
            rec_char_type='ch',
            rec_batch_num=30,
            max_text_length=128,
            rec_char_dict_path='ocr/ppocr/utils/ppocr_keys_v1.txt',
            use_space_char=True,
            drop_score=0.5,
            cls_model_dir=None,
            cls_image_shape="3, 32, 1000",
            label_list=['0', '180'],
            cls_batch_num=30,
            cls_thresh=0.9,
            enable_mkldnn=False,
            use_zero_copy_run=True,
            use_pdserving=False,
            lang='ch',
            det=True,
            rec=True,
            use_angle_cls=False)


class PaddleOCR(predict_system.TextSystem):
    def __init__(self, **kwargs):
        """
        paddleocr package
        args:
            **kwargs: other params shown by `paddleocr --help`
        """
        postprocess_params = parse_args(mMain=False, add_help=False)
        postprocess_params.__dict__.update(**kwargs)
        self.use_angle_cls = postprocess_params.use_angle_cls
        lang = postprocess_params.lang
        assert lang in model_urls[
            'rec'], 'param lang must be in {}, but got {}'.format(
                model_urls['rec'].keys(), lang)
        if postprocess_params.rec_char_dict_path is None:
            postprocess_params.rec_char_dict_path = model_urls['rec'][lang][
                'dict_path']
        # init model dir
        if postprocess_params.det_model_dir is None:
            postprocess_params.det_model_dir = os.path.join(
                BASE_DIR, '{}/det'.format(VERSION))
        if postprocess_params.rec_model_dir is None:
            postprocess_params.rec_model_dir = os.path.join(
                BASE_DIR, '{}/rec/{}'.format(VERSION, lang))
        if postprocess_params.cls_model_dir is None:
            postprocess_params.cls_model_dir = os.path.join(
                BASE_DIR, '{}/cls'.format(VERSION))
        print(postprocess_params)
        # download models if they are not already present locally
        maybe_download(postprocess_params.det_model_dir, model_urls['det'])
        maybe_download(postprocess_params.rec_model_dir,
                       model_urls['rec'][lang]['url'])
        maybe_download(postprocess_params.cls_model_dir, model_urls['cls'])
        if postprocess_params.det_algorithm not in SUPPORT_DET_MODEL:
            logger.error('det_algorithm must be in {}'.format(SUPPORT_DET_MODEL))
            sys.exit(0)
        if postprocess_params.rec_algorithm not in SUPPORT_REC_MODEL:
            logger.error('rec_algorithm must be in {}'.format(SUPPORT_REC_MODEL))
            sys.exit(0)
        postprocess_params.rec_char_dict_path = str(
            Path(__file__).parent.parent /
            postprocess_params.rec_char_dict_path)
        # init det_model and rec_model
        super().__init__(postprocess_params)
    def ocr(self, img, det=True, rec=True, cls=False):
        """
        ocr with paddleocr
        args:
            img: image for OCR; supports an ndarray, an image path/URL, or a list of ndarrays
            det: use text detection or not; if False, only rec is executed. Default is True.
            rec: use text recognition or not; if False, only det is executed. Default is True.
            cls: use angle classification (0/180 degrees) or not. Default is False.
        See the commented usage sketch after this class definition.
        """
        print(det, rec, cls)
        assert isinstance(img, (np.ndarray, list, str))
        if isinstance(img, list) and det == True:
            logger.error('When the input is a list of images, det must be False')
            exit(0)
        self.use_angle_cls = cls
        if isinstance(img, str):
            # download web image
            if img.startswith('http'):
                download_with_progressbar(img, 'tmp.jpg')
                img = 'tmp.jpg'
            image_file = img
            img, flag = check_and_read_gif(image_file)
            if not flag:
                with open(image_file, 'rb') as f:
                    np_arr = np.frombuffer(f.read(), dtype=np.uint8)
                    img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
            if img is None:
                logger.error("error in loading image: {}".format(image_file))
                return None
        if isinstance(img, np.ndarray) and len(img.shape) == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        if det and rec:
            dt_boxes, rec_res = self.__call__(img)
            # print("paddleocr.py dt_boxes", len(dt_boxes))
            # print("paddleocr.py rec_res", len(rec_res))
            return [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
        elif det and not rec:
            dt_boxes, elapse = self.text_detector(img)
            if dt_boxes is None:
                return None
            return [box.tolist() for box in dt_boxes]
        else:
            if not isinstance(img, list):
                img = [img]
            if self.use_angle_cls:
                img, cls_res, elapse = self.text_classifier(img)
                if not rec:
                    return cls_res
            rec_res, elapse = self.text_recognizer(img)
            return rec_res
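

# Usage sketch for library callers (a minimal example, not part of the module's
# logic). It assumes this file is importable as `paddleocr` and that
# 'my_image.jpg' is a placeholder path to be replaced with a real image:
#
#     from paddleocr import PaddleOCR
#     engine = PaddleOCR(lang='ch', use_angle_cls=True)  # models download on first use
#     result = engine.ocr('my_image.jpg', det=True, rec=True, cls=True)
#     for box, (text, score) in result:
#         print(box, text, score)
#
# With det=True and rec=True each entry is [box, (text, score)]; with rec=False
# only the boxes are returned; with det=False, img must be an ndarray (or a list
# of ndarrays) and only recognition, plus optional angle classification, runs.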


def main(mMain=True):
    # for cmd
    args = parse_args(mMain)
    # args = parse_args(mMain=True)
    # the image can be given as a web URL or a local path
    image_dir = args.image_dir
    if image_dir.startswith('http'):
        download_with_progressbar(image_dir, 'tmp.jpg')
        image_file_list = ['tmp.jpg']
    else:
        image_file_list = get_image_file_list(args.image_dir)
    if len(image_file_list) == 0:
        logger.error('no images found in {}'.format(args.image_dir))
        return
    ocr_engine = PaddleOCR(**(args.__dict__))
    for img_path in image_file_list:
        logger.info('{}{}{}'.format('*' * 10, img_path, '*' * 10))
        result = ocr_engine.ocr(img_path,
                                det=args.det,
                                rec=args.rec,
                                cls=args.use_angle_cls)
        if result is not None:
            for line in result:
                logger.info(line)


if __name__ == '__main__':
    main(False)