utility.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import argparse
  15. import logging
  16. import os
  17. import sys
  18. import time
  19. import cv2
  20. import numpy as np
  21. import json
  22. from PIL import Image, ImageDraw, ImageFont
  23. import math
  24. os.environ['FLAGS_eager_delete_tensor_gb'] = '0'
  25. from paddle import inference
  26. def parse_args(return_parse=False):
  27. def str2bool(v):
  28. return v.lower() in ("true", "t", "1")
  29. parser = argparse.ArgumentParser()
  30. # params for prediction engine
  31. parser.add_argument("--use_gpu", type=str2bool, default=True)
  32. parser.add_argument("--ir_optim", type=str2bool, default=True)
  33. parser.add_argument("--use_tensorrt", type=str2bool, default=False)
  34. parser.add_argument("--use_fp16", type=str2bool, default=False)
  35. parser.add_argument("--gpu_mem", type=int, default=500)
  36. # params for text detector
  37. parser.add_argument("--image_dir", type=str)
  38. parser.add_argument("--det_algorithm", type=str, default='DB')
  39. parser.add_argument("--det_model_dir", type=str)
  40. parser.add_argument("--det_limit_side_len", type=float, default=960)
  41. parser.add_argument("--det_limit_type", type=str, default='max')
  42. # DB parmas
  43. parser.add_argument("--det_db_thresh", type=float, default=0.3)
  44. parser.add_argument("--det_db_box_thresh", type=float, default=0.5)
  45. parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6)
  46. parser.add_argument("--max_batch_size", type=int, default=10)
  47. parser.add_argument("--use_dilation", type=bool, default=False)
  48. # EAST parmas
  49. parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
  50. parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
  51. parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)
  52. # SAST parmas
  53. parser.add_argument("--det_sast_score_thresh", type=float, default=0.5)
  54. parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2)
  55. parser.add_argument("--det_sast_polygon", type=bool, default=False)
  56. # params for text recognizer
  57. parser.add_argument("--rec_algorithm", type=str, default='CRNN')
  58. parser.add_argument("--rec_model_dir", type=str)
  59. parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
  60. parser.add_argument("--rec_char_type", type=str, default='ch')
  61. parser.add_argument("--rec_batch_num", type=int, default=6)
  62. parser.add_argument("--max_text_length", type=int, default=25)
  63. parser.add_argument(
  64. "--rec_char_dict_path",
  65. type=str,
  66. default="./ppocr/utils/ppocr_keys_v1.txt")
  67. parser.add_argument("--use_space_char", type=str2bool, default=True)
  68. parser.add_argument(
  69. "--vis_font_path", type=str, default="./doc/fonts/simfang.ttf")
  70. parser.add_argument("--drop_score", type=float, default=0.5)
  71. # params for text classifier
  72. parser.add_argument("--use_angle_cls", type=str2bool, default=False)
  73. parser.add_argument("--cls_model_dir", type=str)
  74. parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
  75. parser.add_argument("--label_list", type=list, default=['0', '180'])
  76. parser.add_argument("--cls_batch_num", type=int, default=6)
  77. parser.add_argument("--cls_thresh", type=float, default=0.9)
  78. parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
  79. parser.add_argument("--use_pdserving", type=str2bool, default=False)
  80. if return_parse:
  81. return parser
  82. return parser.parse_args()
  83. def create_predictor(args, mode, logger):
  84. if mode == "det":
  85. model_dir = args.det_model_dir
  86. elif mode == 'cls':
  87. model_dir = args.cls_model_dir
  88. else:
  89. model_dir = args.rec_model_dir
  90. if model_dir is None:
  91. logger.info("not find {} model file path {}".format(mode, model_dir))
  92. sys.exit(0)
  93. model_file_path = model_dir + "/inference.pdmodel"
  94. params_file_path = model_dir + "/inference.pdiparams"
  95. if not os.path.exists(model_file_path):
  96. logger.info("not find model file path {}".format(model_file_path))
  97. sys.exit(0)
  98. if not os.path.exists(params_file_path):
  99. logger.info("not find params file path {}".format(params_file_path))
  100. sys.exit(0)
  101. config = inference.Config(model_file_path, params_file_path)
  102. if args.use_gpu:
  103. config.enable_use_gpu(args.gpu_mem, 0)
  104. if args.use_tensorrt:
  105. config.enable_tensorrt_engine(
  106. precision_mode=inference.PrecisionType.Half
  107. if args.use_fp16 else inference.PrecisionType.Float32,
  108. max_batch_size=args.max_batch_size)
  109. else:
  110. config.disable_gpu()
  111. config.set_cpu_math_library_num_threads(1)
  112. if args.enable_mkldnn:
  113. # cache 10 different shapes for mkldnn to avoid memory leak
  114. config.set_mkldnn_cache_capacity(1)
  115. config.enable_mkldnn()
  116. # TODO LDOUBLEV: fix mkldnn bug when bach_size > 1
  117. #config.set_mkldnn_op({'conv2d', 'depthwise_conv2d', 'pool2d', 'batch_norm'})
  118. args.rec_batch_num = 1
  119. # config.disable_gpu()
  120. # config.enable_use_gpu(args.gpu_mem, 0)
  121. config.enable_memory_optim()
  122. config.set_cpu_math_library_num_threads(1)
  123. config.disable_glog_info()
  124. config.switch_use_feed_fetch_ops(False)
  125. config.switch_specify_input_names(True)
  126. #
  127. config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
  128. config.delete_pass('conv_elementwise_add_act_fuse_pass')
  129. config.delete_pass('conv_elementwise_add2_act_fuse_pass')
  130. # config.switch_use_feed_fetch_ops(False)
  131. # create predictor
  132. start_time = time.time()
  133. predictor = inference.create_predictor(config)
  134. input_names = predictor.get_input_names()
  135. for name in input_names:
  136. input_tensor = predictor.get_input_handle(name)
  137. output_names = predictor.get_output_names()
  138. output_tensors = []
  139. for output_name in output_names:
  140. output_tensor = predictor.get_output_handle(output_name)
  141. output_tensors.append(output_tensor)
  142. logging.info("ocr model predict time " + str(time.time()-start_time))
  143. return predictor, input_tensor, output_tensors
  144. def draw_text_det_res(dt_boxes, img_path):
  145. src_im = cv2.imread(img_path)
  146. for box in dt_boxes:
  147. box = np.array(box).astype(np.int32).reshape(-1, 2)
  148. cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
  149. return src_im
  150. def resize_img(img, input_size=600):
  151. """
  152. resize img and limit the longest side of the image to input_size
  153. """
  154. img = np.array(img)
  155. im_shape = img.shape
  156. im_size_max = np.max(im_shape[0:2])
  157. im_scale = float(input_size) / float(im_size_max)
  158. img = cv2.resize(img, None, None, fx=im_scale, fy=im_scale)
  159. return img
  160. def draw_ocr(image,
  161. boxes,
  162. txts=None,
  163. scores=None,
  164. drop_score=0.5,
  165. font_path="./doc/simfang.ttf"):
  166. """
  167. Visualize the results of OCR detection and recognition
  168. args:
  169. image(Image|array): RGB image
  170. boxes(list): boxes with shape(N, 4, 2)
  171. txts(list): the texts
  172. scores(list): txxs corresponding scores
  173. drop_score(float): only scores greater than drop_threshold will be visualized
  174. font_path: the path of font which is used to draw text
  175. return(array):
  176. the visualized img
  177. """
  178. if scores is None:
  179. scores = [1] * len(boxes)
  180. box_num = len(boxes)
  181. for i in range(box_num):
  182. if scores is not None and (scores[i] < drop_score or
  183. math.isnan(scores[i])):
  184. continue
  185. box = np.reshape(np.array(boxes[i]), [-1, 1, 2]).astype(np.int64)
  186. image = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 2)
  187. if txts is not None:
  188. img = np.array(resize_img(image, input_size=600))
  189. txt_img = text_visual(
  190. txts,
  191. scores,
  192. img_h=img.shape[0],
  193. img_w=600,
  194. threshold=drop_score,
  195. font_path=font_path)
  196. img = np.concatenate([np.array(img), np.array(txt_img)], axis=1)
  197. return img
  198. return image
  199. def draw_ocr_box_txt(image,
  200. boxes,
  201. txts,
  202. scores=None,
  203. drop_score=0.5,
  204. font_path="./doc/simfang.ttf"):
  205. h, w = image.height, image.width
  206. img_left = image.copy()
  207. img_right = Image.new('RGB', (w, h), (255, 255, 255))
  208. import random
  209. random.seed(0)
  210. draw_left = ImageDraw.Draw(img_left)
  211. draw_right = ImageDraw.Draw(img_right)
  212. for idx, (box, txt) in enumerate(zip(boxes, txts)):
  213. if scores is not None and scores[idx] < drop_score:
  214. continue
  215. color = (random.randint(0, 255), random.randint(0, 255),
  216. random.randint(0, 255))
  217. draw_left.polygon(box, fill=color)
  218. draw_right.polygon(
  219. [
  220. box[0][0], box[0][1], box[1][0], box[1][1], box[2][0],
  221. box[2][1], box[3][0], box[3][1]
  222. ],
  223. outline=color)
  224. box_height = math.sqrt((box[0][0] - box[3][0])**2 + (box[0][1] - box[3][
  225. 1])**2)
  226. box_width = math.sqrt((box[0][0] - box[1][0])**2 + (box[0][1] - box[1][
  227. 1])**2)
  228. if box_height > 2 * box_width:
  229. font_size = max(int(box_width * 0.9), 10)
  230. font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
  231. cur_y = box[0][1]
  232. for c in txt:
  233. char_size = font.getsize(c)
  234. draw_right.text(
  235. (box[0][0] + 3, cur_y), c, fill=(0, 0, 0), font=font)
  236. cur_y += char_size[1]
  237. else:
  238. font_size = max(int(box_height * 0.8), 10)
  239. font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
  240. draw_right.text(
  241. [box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font)
  242. img_left = Image.blend(image, img_left, 0.5)
  243. img_show = Image.new('RGB', (w * 2, h), (255, 255, 255))
  244. img_show.paste(img_left, (0, 0, w, h))
  245. img_show.paste(img_right, (w, 0, w * 2, h))
  246. return np.array(img_show)
  247. def str_count(s):
  248. """
  249. Count the number of Chinese characters,
  250. a single English character and a single number
  251. equal to half the length of Chinese characters.
  252. args:
  253. s(string): the input of string
  254. return(int):
  255. the number of Chinese characters
  256. """
  257. import string
  258. count_zh = count_pu = 0
  259. s_len = len(s)
  260. en_dg_count = 0
  261. for c in s:
  262. if c in string.ascii_letters or c.isdigit() or c.isspace():
  263. en_dg_count += 1
  264. elif c.isalpha():
  265. count_zh += 1
  266. else:
  267. count_pu += 1
  268. return s_len - math.ceil(en_dg_count / 2)
  269. def text_visual(texts,
  270. scores,
  271. img_h=400,
  272. img_w=600,
  273. threshold=0.,
  274. font_path="./doc/simfang.ttf"):
  275. """
  276. create new blank img and draw txt on it
  277. args:
  278. texts(list): the text will be draw
  279. scores(list|None): corresponding score of each txt
  280. img_h(int): the height of blank img
  281. img_w(int): the width of blank img
  282. font_path: the path of font which is used to draw text
  283. return(array):
  284. """
  285. if scores is not None:
  286. assert len(texts) == len(
  287. scores), "The number of txts and corresponding scores must match"
  288. def create_blank_img():
  289. blank_img = np.ones(shape=[img_h, img_w], dtype=np.int8) * 255
  290. blank_img[:, img_w - 1:] = 0
  291. blank_img = Image.fromarray(blank_img).convert("RGB")
  292. draw_txt = ImageDraw.Draw(blank_img)
  293. return blank_img, draw_txt
  294. blank_img, draw_txt = create_blank_img()
  295. font_size = 20
  296. txt_color = (0, 0, 0)
  297. font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
  298. gap = font_size + 5
  299. txt_img_list = []
  300. count, index = 1, 0
  301. for idx, txt in enumerate(texts):
  302. index += 1
  303. if scores[idx] < threshold or math.isnan(scores[idx]):
  304. index -= 1
  305. continue
  306. first_line = True
  307. while str_count(txt) >= img_w // font_size - 4:
  308. tmp = txt
  309. txt = tmp[:img_w // font_size - 4]
  310. if first_line:
  311. new_txt = str(index) + ': ' + txt
  312. first_line = False
  313. else:
  314. new_txt = ' ' + txt
  315. draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
  316. txt = tmp[img_w // font_size - 4:]
  317. if count >= img_h // gap - 1:
  318. txt_img_list.append(np.array(blank_img))
  319. blank_img, draw_txt = create_blank_img()
  320. count = 0
  321. count += 1
  322. if first_line:
  323. new_txt = str(index) + ': ' + txt + ' ' + '%.3f' % (scores[idx])
  324. else:
  325. new_txt = " " + txt + " " + '%.3f' % (scores[idx])
  326. draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
  327. # whether add new blank img or not
  328. if count >= img_h // gap - 1 and idx + 1 < len(texts):
  329. txt_img_list.append(np.array(blank_img))
  330. blank_img, draw_txt = create_blank_img()
  331. count = 0
  332. count += 1
  333. txt_img_list.append(np.array(blank_img))
  334. if len(txt_img_list) == 1:
  335. blank_img = np.array(txt_img_list[0])
  336. else:
  337. blank_img = np.concatenate(txt_img_list, axis=1)
  338. return np.array(blank_img)
  339. def base64_to_cv2(b64str):
  340. import base64
  341. data = base64.b64decode(b64str.encode('utf8'))
  342. data = np.fromstring(data, np.uint8)
  343. data = cv2.imdecode(data, cv2.IMREAD_COLOR)
  344. return data
  345. def draw_boxes(image, boxes, scores=None, drop_score=0.5):
  346. if scores is None:
  347. scores = [1] * len(boxes)
  348. for (box, score) in zip(boxes, scores):
  349. if score < drop_score:
  350. continue
  351. box = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
  352. image = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 2)
  353. return image
  354. if __name__ == '__main__':
  355. test_img = "./doc/test_v2"
  356. predict_txt = "./doc/predict.txt"
  357. f = open(predict_txt, 'r')
  358. data = f.readlines()
  359. img_path, anno = data[0].strip().split('\t')
  360. img_name = os.path.basename(img_path)
  361. img_path = os.path.join(test_img, img_name)
  362. image = Image.open(img_path)
  363. data = json.loads(anno)
  364. boxes, txts, scores = [], [], []
  365. for dic in data:
  366. boxes.append(dic['points'])
  367. txts.append(dic['transcription'])
  368. scores.append(round(dic['scores'], 3))
  369. new_img = draw_ocr(image, boxes, txts, scores)
  370. cv2.imwrite(img_name, new_img)