utility.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import sys
import cv2
import numpy as np
import json
from PIL import Image, ImageDraw, ImageFont
import math

os.environ['FLAGS_eager_delete_tensor_gb'] = '0'

from paddle import inference


def parse_args(return_parse=False):
    def str2bool(v):
        return v.lower() in ("true", "t", "1")

    parser = argparse.ArgumentParser()
    # params for prediction engine
    parser.add_argument("--use_gpu", type=str2bool, default=True)
    parser.add_argument("--ir_optim", type=str2bool, default=True)
    parser.add_argument("--use_tensorrt", type=str2bool, default=False)
    parser.add_argument("--use_fp16", type=str2bool, default=False)
    parser.add_argument("--gpu_mem", type=int, default=500)

    # params for text detector
    parser.add_argument("--image_dir", type=str)
    parser.add_argument("--det_algorithm", type=str, default='DB')
    parser.add_argument("--det_model_dir", type=str)
    parser.add_argument("--det_limit_side_len", type=float, default=960)
    parser.add_argument("--det_limit_type", type=str, default='max')

    # DB params
    parser.add_argument("--det_db_thresh", type=float, default=0.3)
    parser.add_argument("--det_db_box_thresh", type=float, default=0.5)
    parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6)
    parser.add_argument("--max_batch_size", type=int, default=10)
    parser.add_argument("--use_dilation", type=bool, default=False)

    # EAST params
    parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
    parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
    parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)

    # SAST params
    parser.add_argument("--det_sast_score_thresh", type=float, default=0.5)
    parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2)
    parser.add_argument("--det_sast_polygon", type=bool, default=False)

    # params for text recognizer
    parser.add_argument("--rec_algorithm", type=str, default='CRNN')
    parser.add_argument("--rec_model_dir", type=str)
    parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
    parser.add_argument("--rec_char_type", type=str, default='ch')
    parser.add_argument("--rec_batch_num", type=int, default=6)
    parser.add_argument("--max_text_length", type=int, default=25)
    parser.add_argument(
        "--rec_char_dict_path",
        type=str,
        default="./ppocr/utils/ppocr_keys_v1.txt")
    parser.add_argument("--use_space_char", type=str2bool, default=True)
    parser.add_argument(
        "--vis_font_path", type=str, default="./doc/fonts/simfang.ttf")
    parser.add_argument("--drop_score", type=float, default=0.5)

    # params for text classifier
    parser.add_argument("--use_angle_cls", type=str2bool, default=False)
    parser.add_argument("--cls_model_dir", type=str)
    parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
    parser.add_argument("--label_list", type=list, default=['0', '180'])
    parser.add_argument("--cls_batch_num", type=int, default=6)
    parser.add_argument("--cls_thresh", type=float, default=0.9)

    parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
    parser.add_argument("--use_pdserving", type=str2bool, default=False)

    if return_parse:
        return parser
    return parser.parse_args()
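

# Illustrative usage sketch (not part of the original file). parse_args() returns
# the parsed namespace by default; with return_parse=True it returns the parser
# itself, which is useful for embedding these flags in another CLI. The flag
# values below are placeholders.
#
#   args = parse_args(return_parse=True).parse_args(
#       ["--det_model_dir", "./inference/det", "--use_gpu", "false"])
#   print(args.det_algorithm, args.use_gpu)   # -> DB False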


def create_predictor(args, mode, logger):
    if mode == "det":
        model_dir = args.det_model_dir
    elif mode == 'cls':
        model_dir = args.cls_model_dir
    else:
        model_dir = args.rec_model_dir

    if model_dir is None:
        logger.info("not find {} model file path {}".format(mode, model_dir))
        sys.exit(0)
    model_file_path = model_dir + "/inference.pdmodel"
    params_file_path = model_dir + "/inference.pdiparams"
    if not os.path.exists(model_file_path):
        logger.info("not find model file path {}".format(model_file_path))
        sys.exit(0)
    if not os.path.exists(params_file_path):
        logger.info("not find params file path {}".format(params_file_path))
        sys.exit(0)

    config = inference.Config(model_file_path, params_file_path)

    if args.use_gpu:
        config.enable_use_gpu(args.gpu_mem, 0)
        if args.use_tensorrt:
            config.enable_tensorrt_engine(
                precision_mode=inference.PrecisionType.Half
                if args.use_fp16 else inference.PrecisionType.Float32,
                max_batch_size=args.max_batch_size)
    else:
        config.disable_gpu()
        config.set_cpu_math_library_num_threads(1)
        if args.enable_mkldnn:
            # cache 10 different shapes for mkldnn to avoid memory leak
            config.set_mkldnn_cache_capacity(1)
            config.enable_mkldnn()
            # TODO LDOUBLEV: fix mkldnn bug when batch_size > 1
            # config.set_mkldnn_op({'conv2d', 'depthwise_conv2d', 'pool2d', 'batch_norm'})
            args.rec_batch_num = 1

    # config.disable_gpu()
    # config.enable_use_gpu(args.gpu_mem, 0)

    config.enable_memory_optim()
    config.set_cpu_math_library_num_threads(1)
    config.disable_glog_info()

    config.switch_use_feed_fetch_ops(False)
    config.switch_specify_input_names(True)

    config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
    config.delete_pass('conv_elementwise_add_act_fuse_pass')
    config.delete_pass('conv_elementwise_add2_act_fuse_pass')
    # config.switch_use_feed_fetch_ops(False)

    # create predictor
    predictor = inference.create_predictor(config)
    input_names = predictor.get_input_names()
    for name in input_names:
        input_tensor = predictor.get_input_handle(name)
    output_names = predictor.get_output_names()
    output_tensors = []
    for output_name in output_names:
        output_tensor = predictor.get_output_handle(output_name)
        output_tensors.append(output_tensor)
    return predictor, input_tensor, output_tensors
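

# Illustrative sketch (an assumption, not part of the original file): a minimal
# inference call built on the handles returned by create_predictor(). args and
# logger come from parse_args() and a standard logging logger; preprocessing is
# omitted and `img` is assumed to already be an NCHW float32 array.
#
#   predictor, input_tensor, output_tensors = create_predictor(args, 'det', logger)
#   input_tensor.copy_from_cpu(img)      # feed the input blob
#   predictor.run()                      # run the Paddle inference engine
#   outputs = [t.copy_to_cpu() for t in output_tensors]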


def draw_text_det_res(dt_boxes, img_path):
    src_im = cv2.imread(img_path)
    for box in dt_boxes:
        box = np.array(box).astype(np.int32).reshape(-1, 2)
        cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
    return src_im


def resize_img(img, input_size=600):
    """
    resize img and limit the longest side of the image to input_size
    """
    img = np.array(img)
    im_shape = img.shape
    im_size_max = np.max(im_shape[0:2])
    im_scale = float(input_size) / float(im_size_max)
    img = cv2.resize(img, None, None, fx=im_scale, fy=im_scale)
    return img


def draw_ocr(image,
             boxes,
             txts=None,
             scores=None,
             drop_score=0.5,
             font_path="./doc/simfang.ttf"):
    """
    Visualize the results of OCR detection and recognition.
    args:
        image(Image|array): RGB image
        boxes(list): boxes with shape(N, 4, 2)
        txts(list): the recognized texts
        scores(list): the corresponding score of each text
        drop_score(float): only results with a score greater than drop_score will be visualized
        font_path: the path of the font used to draw text
    return(array):
        the visualized img
    """
    if scores is None:
        scores = [1] * len(boxes)
    box_num = len(boxes)
    for i in range(box_num):
        if scores is not None and (scores[i] < drop_score or
                                   math.isnan(scores[i])):
            continue
        box = np.reshape(np.array(boxes[i]), [-1, 1, 2]).astype(np.int64)
        image = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 2)
    if txts is not None:
        img = np.array(resize_img(image, input_size=600))
        txt_img = text_visual(
            txts,
            scores,
            img_h=img.shape[0],
            img_w=600,
            threshold=drop_score,
            font_path=font_path)
        img = np.concatenate([np.array(img), np.array(txt_img)], axis=1)
        return img
    return image
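

# Illustrative usage sketch (not part of the original file); the image path, box
# coordinates, text and score are placeholders.
#
#   image = Image.open("./doc/imgs/demo.jpg").convert('RGB')
#   boxes = [[[10, 10], [200, 10], [200, 50], [10, 50]]]
#   vis = draw_ocr(image, boxes, txts=["hello"], scores=[0.98])
#   cv2.imwrite("vis.jpg", vis[:, :, ::-1])  # draw_ocr keeps RGB order; OpenCV expects BGR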


def draw_ocr_box_txt(image,
                     boxes,
                     txts,
                     scores=None,
                     drop_score=0.5,
                     font_path="./doc/simfang.ttf"):
    h, w = image.height, image.width
    img_left = image.copy()
    img_right = Image.new('RGB', (w, h), (255, 255, 255))

    import random
    random.seed(0)

    draw_left = ImageDraw.Draw(img_left)
    draw_right = ImageDraw.Draw(img_right)
    for idx, (box, txt) in enumerate(zip(boxes, txts)):
        if scores is not None and scores[idx] < drop_score:
            continue
        color = (random.randint(0, 255), random.randint(0, 255),
                 random.randint(0, 255))
        draw_left.polygon(box, fill=color)
        draw_right.polygon(
            [
                box[0][0], box[0][1], box[1][0], box[1][1], box[2][0],
                box[2][1], box[3][0], box[3][1]
            ],
            outline=color)
        box_height = math.sqrt((box[0][0] - box[3][0])**2 +
                               (box[0][1] - box[3][1])**2)
        box_width = math.sqrt((box[0][0] - box[1][0])**2 +
                              (box[0][1] - box[1][1])**2)
        if box_height > 2 * box_width:
            # vertical text: draw one character per line
            font_size = max(int(box_width * 0.9), 10)
            font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
            cur_y = box[0][1]
            for c in txt:
                char_size = font.getsize(c)
                draw_right.text(
                    (box[0][0] + 3, cur_y), c, fill=(0, 0, 0), font=font)
                cur_y += char_size[1]
        else:
            font_size = max(int(box_height * 0.8), 10)
            font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
            draw_right.text(
                [box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font)
    img_left = Image.blend(image, img_left, 0.5)
    img_show = Image.new('RGB', (w * 2, h), (255, 255, 255))
    img_show.paste(img_left, (0, 0, w, h))
    img_show.paste(img_right, (w, 0, w * 2, h))
    return np.array(img_show)


def str_count(s):
    """
    Count the display width of a string: a single English character
    or digit counts as half the width of a Chinese character.
    args:
        s(string): the input string
    return(int):
        the display width measured in Chinese characters
    """
    import string
    count_zh = count_pu = 0
    s_len = len(s)
    en_dg_count = 0
    for c in s:
        if c in string.ascii_letters or c.isdigit() or c.isspace():
            en_dg_count += 1
        elif c.isalpha():
            count_zh += 1
        else:
            count_pu += 1
    return s_len - math.ceil(en_dg_count / 2)
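

# Example of the width heuristic above (illustrative, not part of the original
# file): "OCR识别" has length 5 with 3 ASCII characters, so the estimated width
# is 5 - ceil(3 / 2) = 3, i.e. roughly three full-width characters.
#
#   assert str_count("OCR识别") == 3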


def text_visual(texts,
                scores,
                img_h=400,
                img_w=600,
                threshold=0.,
                font_path="./doc/simfang.ttf"):
    """
    Create a new blank img and draw txt on it.
    args:
        texts(list): the texts to be drawn
        scores(list|None): the corresponding score of each text
        img_h(int): the height of the blank img
        img_w(int): the width of the blank img
        font_path: the path of the font used to draw text
    return(array):
        the drawn img
    """
    if scores is not None:
        assert len(texts) == len(
            scores), "The number of txts and corresponding scores must match"

    def create_blank_img():
        blank_img = np.ones(shape=[img_h, img_w], dtype=np.int8) * 255
        blank_img[:, img_w - 1:] = 0
        blank_img = Image.fromarray(blank_img).convert("RGB")
        draw_txt = ImageDraw.Draw(blank_img)
        return blank_img, draw_txt

    blank_img, draw_txt = create_blank_img()

    font_size = 20
    txt_color = (0, 0, 0)
    font = ImageFont.truetype(font_path, font_size, encoding="utf-8")

    gap = font_size + 5
    txt_img_list = []
    count, index = 1, 0
    for idx, txt in enumerate(texts):
        index += 1
        if scores[idx] < threshold or math.isnan(scores[idx]):
            index -= 1
            continue
        first_line = True
        # wrap texts that are too long for one line
        while str_count(txt) >= img_w // font_size - 4:
            tmp = txt
            txt = tmp[:img_w // font_size - 4]
            if first_line:
                new_txt = str(index) + ': ' + txt
                first_line = False
            else:
                new_txt = ' ' + txt
            draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
            txt = tmp[img_w // font_size - 4:]
            if count >= img_h // gap - 1:
                txt_img_list.append(np.array(blank_img))
                blank_img, draw_txt = create_blank_img()
                count = 0
            count += 1
        if first_line:
            new_txt = str(index) + ': ' + txt + ' ' + '%.3f' % (scores[idx])
        else:
            new_txt = " " + txt + " " + '%.3f' % (scores[idx])
        draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
        # whether to start a new blank img or not
        if count >= img_h // gap - 1 and idx + 1 < len(texts):
            txt_img_list.append(np.array(blank_img))
            blank_img, draw_txt = create_blank_img()
            count = 0
        count += 1
    txt_img_list.append(np.array(blank_img))
    if len(txt_img_list) == 1:
        blank_img = np.array(txt_img_list[0])
    else:
        blank_img = np.concatenate(txt_img_list, axis=1)
    return np.array(blank_img)


def base64_to_cv2(b64str):
    import base64
    data = base64.b64decode(b64str.encode('utf8'))
    # np.frombuffer replaces the deprecated np.fromstring for binary data
    data = np.frombuffer(data, np.uint8)
    data = cv2.imdecode(data, cv2.IMREAD_COLOR)
    return data
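

# Round-trip sketch (illustrative, not part of the original file): encode an
# image file to base64 and decode it back with base64_to_cv2(). The path is a
# placeholder.
#
#   import base64
#   raw = open("./doc/imgs/demo.jpg", "rb").read()
#   b64 = base64.b64encode(raw).decode('utf8')
#   img = base64_to_cv2(b64)             # BGR ndarray, same as cv2.imread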


def draw_boxes(image, boxes, scores=None, drop_score=0.5):
    if scores is None:
        scores = [1] * len(boxes)
    for (box, score) in zip(boxes, scores):
        if score < drop_score:
            continue
        box = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
        image = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 2)
    return image


if __name__ == '__main__':
    test_img = "./doc/test_v2"
    predict_txt = "./doc/predict.txt"
    f = open(predict_txt, 'r')
    data = f.readlines()
    img_path, anno = data[0].strip().split('\t')
    img_name = os.path.basename(img_path)
    img_path = os.path.join(test_img, img_name)
    image = Image.open(img_path)

    data = json.loads(anno)
    boxes, txts, scores = [], [], []
    for dic in data:
        boxes.append(dic['points'])
        txts.append(dic['transcription'])
        scores.append(round(dic['scores'], 3))

    new_img = draw_ocr(image, boxes, txts, scores)
    cv2.imwrite(img_name, new_img)