# predict_rec_pytorch.py
  1. # encoding=utf8
  2. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import io
  16. import logging
  17. import os
  18. import sys
  19. # __dir__ = os.path.dirname(os.path.abspath(__file__))
  20. import zlib
  21. import requests
  22. # sys.path.append(__dir__)
  23. # sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
  24. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../../")
  25. os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
  26. import cv2
  27. import numpy as np
  28. import math
  29. import time
  30. import traceback
  31. os.environ['FLAGS_eager_delete_tensor_gb'] = '0'
  32. import paddle
  33. import ocr.tools.infer.utility as utility
  34. from ocr.ppocr.postprocess import build_post_process
  35. from ocr.ppocr.utils.logging import get_logger
  36. from ocr.ppocr.utils.utility import get_image_file_list, check_and_read_gif
  37. from config.max_compute_config import MAX_COMPUTE
  38. from format_convert.utils import judge_error_code, log, namespace_to_dict, get_platform, file_lock, \
  39. get_gpu_memory_usage, get_current_process_gpu_id
  40. from format_convert import _global
  41. import torch
  42. from torch import nn
  43. from ocr.tools.infer.torch_rec_model import Rec_ResNet_34
  44. import gc
  45. logger = get_logger()
  46. class TextRecognizer(object):
  47. shrink_memory_count = 0
  48. def __init__(self, args):
  49. self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
  50. self.character_type = args.rec_char_type
  51. self.rec_batch_num = args.rec_batch_num
  52. self.rec_batch_num = 16
  53. print('self.rec_batch_num', self.rec_batch_num)
  54. self.rec_algorithm = args.rec_algorithm
  55. postprocess_params = {
  56. 'name': 'CTCLabelDecode',
  57. "character_type": args.rec_char_type,
  58. "character_dict_path": args.rec_char_dict_path,
  59. # "use_space_char": args.use_space_char
  60. "use_space_char": False
  61. }
  62. self.postprocess_op = build_post_process(postprocess_params)
  63. rec_model_path = args.rec_model_dir
  64. self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  65. model = Rec_ResNet_34()
  66. mode_state_dict = torch.load(rec_model_path, self.device)['state_dict']
  67. if str(self.device)=='cpu': # cpu处理时精度调整,加速推理
  68. for name, value in mode_state_dict.items():
  69. if get_platform() != "Windows":
  70. value = value.double()
  71. value = torch.where((value < 1.0e-23) & (value > 0.0), 1.0e-23, value)
  72. value = torch.where((value > -1.0e-23) & (value < 0.0), -1.0e-23, value)
  73. mode_state_dict[name] = value
  74. model.load_state_dict(mode_state_dict)
  75. self.predictor = model
  76. self.predictor.to(self.device)
  77. self.predictor.eval()
  78. if str(self.device) != 'cpu':
  79. self.gpu_id = get_current_process_gpu_id()
  80. else:
  81. self.gpu_id = None
  82. def resize_norm_img(self, img, max_wh_ratio):
  83. h, w = img.shape[:2]
  84. imgC, imgH, imgW = self.rec_image_shape
  85. assert imgC == img.shape[2]
  86. # print('max_wh_ratio', max_wh_ratio)
  87. # max_wh_ratio h是w的10倍,直接返回
  88. if max_wh_ratio < 0.1:
  89. # log('max_wh_ratio < 0.1', )
  90. resized_image = img.astype('float32')
  91. resized_image = resized_image.transpose((2, 0, 1)) / 255
  92. return resized_image
  93. else:
  94. if self.character_type == "ch":
  95. imgW = int((32 * max_wh_ratio))
  96. ratio = w / float(h)
  97. if math.ceil(imgH * ratio) > imgW:
  98. resized_w = imgW
  99. else:
  100. resized_w = int(math.ceil(imgH * ratio))
  101. try:
  102. resized_image = cv2.resize(img, (resized_w, imgH))
  103. except:
  104. log("predict_rec.py resize_norm_img resize shape " + str((resized_w, imgH, imgW, h, w, ratio, max_wh_ratio)) + ' ' + str(self.rec_image_shape))
  105. raise
  106. resized_image = resized_image.astype('float32')
  107. resized_image = resized_image.transpose((2, 0, 1)) / 255
  108. resized_image -= 0.5
  109. resized_image /= 0.5
  110. padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
  111. if resized_w is not None:
  112. padding_im[:, :, 0:resized_w] = resized_image
  113. return padding_im
  114. def predict(self, norm_img_batch):
  115. tensor = torch.from_numpy(norm_img_batch).float()
  116. # if norm_img.shape[3] >= 100 and get_platform() != "Windows" and not MAX_COMPUTE:
  117. if get_platform() != "Windows" and not MAX_COMPUTE:
  118. # 加锁
  119. time2 = time.time()
  120. lock_file_sub = 'ocr'
  121. lock_file = os.path.abspath(os.path.dirname(__file__)) + "/" + lock_file_sub + ".lock"
  122. f = file_lock(lock_file)
  123. log("rec get file_lock " + lock_file + " time " + str(time.time()-time2))
  124. try:
  125. time2 = time.time()
  126. if str(self.device) != 'cpu':
  127. torch.cuda.empty_cache()
  128. tensor = tensor.to(self.device)
  129. with torch.no_grad():
  130. out = self.predictor(tensor)
  131. log("get file_lock run rec" + " time " + str(time.time()-time2))
  132. except RuntimeError:
  133. log("ocr/tools/infer/predict_rec.py predict.run error! maybe no gpu memory!")
  134. log("rec predictor shrink memory! ori_im.shape " + str(norm_img_batch.shape))
  135. get_gpu_memory_usage()
  136. raise RuntimeError
  137. finally:
  138. f.close()
  139. if str(self.device) != 'cpu':
  140. torch.cuda.empty_cache()
  141. gc.collect()
  142. else:
  143. tensor = tensor.to(self.device)
  144. with torch.no_grad():
  145. out = self.predictor(tensor)
  146. # logging.info("ocr model predict time - rec" + str(time.time()-start_time))
  147. out = out.cpu().numpy()
  148. preds = out
  149. return preds
  150. def predict_batch(self, batch_list):
  151. batch_out_list = []
  152. if get_platform() != "Windows" and not MAX_COMPUTE and self.gpu_id is not None:
  153. # 加锁
  154. time2 = time.time()
  155. lock_file_sub = f'ocr_{self.gpu_id}'
  156. lock_file = os.path.abspath(os.path.dirname(__file__)) + "/" + lock_file_sub + ".lock"
  157. f = file_lock(lock_file)
  158. log("rec get file_lock " + lock_file + " time " + str(time.time()-time2))
  159. try:
  160. time2 = time.time()
  161. if str(self.device) != 'cpu':
  162. torch.cuda.empty_cache()
  163. for sub_batch_list in batch_list:
  164. sub_batch_out = []
  165. for tensor in sub_batch_list:
  166. with torch.no_grad():
  167. out = self.predictor(tensor)
  168. out = out.cpu().numpy()
  169. sub_batch_out.append(out)
  170. # sub_batch_out = np.concatenate(sub_batch_out, axis=0)
  171. batch_out_list.append(sub_batch_out)
  172. log("get file_lock run rec" + " time " + str(time.time()-time2))
  173. except RuntimeError:
  174. log("ocr/tools/infer/predict_rec.py predict.run error! maybe no gpu memory!")
  175. log("rec predictor shrink memory! ori_im.shape " + str(tensor.shape))
  176. get_gpu_memory_usage()
  177. raise RuntimeError
  178. finally:
  179. f.close()
  180. if str(self.device) != 'cpu':
  181. torch.cuda.empty_cache()
  182. else:
  183. for sub_batch_list in batch_list:
  184. sub_batch_out = []
  185. for tensor in sub_batch_list:
  186. # print('tensor.shape', tensor.shape)
  187. with torch.no_grad():
  188. out = self.predictor(tensor)
  189. out = out.cpu().numpy()
  190. # print('out.shape', out.shape)
  191. sub_batch_out.append(out)
  192. # sub_batch_out = np.concatenate(sub_batch_out, axis=0)
  193. batch_out_list.append(sub_batch_out)
  194. # 转为numpy
  195. for bi, sub_batch_out in enumerate(batch_out_list):
  196. batch_out_list[bi] = np.concatenate(sub_batch_out, axis=0)
  197. return batch_out_list
  198. def __call__(self, img_list):
  199. start_time = time.time()
  200. # print('into TextRecognizer __call__')
  201. img_num = len(img_list)
  202. # 过滤图片比例异常的
  203. # print('rec len(img_list)', len(img_list))
  204. temp_list = []
  205. for img in img_list:
  206. if img.shape[0] == 0 or img.shape[1] == 0 \
  207. or img.shape[0] >= 10000 or img.shape[1] >= 10000 \
  208. or img.shape[1] / img.shape[0] <= 0.5 \
  209. or img.shape[1] / img.shape[0] >= 100:
  210. # print('rec img.shape[1] / img.shape[0] <= 0.5', img.shape)
  211. continue
  212. temp_list.append(img)
  213. if not temp_list:
  214. return None, 0
  215. img_list = temp_list
  216. # 按比例排序
  217. width_list = []
  218. i = 0
  219. for img in img_list:
  220. width_list.append(img.shape[1] / float(img.shape[0]))
  221. # Sorting can speed up the recognition process
  222. indices = np.argsort(np.array(width_list))
  223. # 分批预测
  224. # rec_res = []
  225. rec_res = [['', 0.0]] * img_num
  226. batch_num = self.rec_batch_num
  227. elapse = 0
  228. batch_list = []
  229. for beg_img_no in range(0, img_num, batch_num):
  230. end_img_no = min(img_num, beg_img_no + batch_num)
  231. norm_img_batch = []
  232. max_wh_ratio = 0
  233. # 取这个batch中比例最大的
  234. for ino in range(beg_img_no, end_img_no):
  235. # h, w = img_list[ino].shape[0:2]
  236. h, w = img_list[indices[ino]].shape[0:2]
  237. wh_ratio = w * 1.0 / h
  238. max_wh_ratio = max(max_wh_ratio, wh_ratio)
  239. # print('max_wh_ratio', max_wh_ratio)
  240. # resize image
  241. for ino in range(beg_img_no, end_img_no):
  242. # print('img_list[indices[ino]].shape', img_list[indices[ino]].shape)
  243. norm_img = self.resize_norm_img(img_list[indices[ino]],
  244. max_wh_ratio)
  245. # print('norm_img.shape', norm_img.shape)
  246. norm_img = norm_img[np.newaxis, :]
  247. norm_img_batch.append(norm_img)
  248. norm_img_batch = np.concatenate(norm_img_batch)
  249. norm_img_batch = norm_img_batch.copy()
  250. # 预测
  251. # starttime = time.time()
  252. # # 当图片很长时,降低batch,防止爆内存
  253. # # print('norm_img_batch.shape', norm_img_batch.shape)
  254. # preds = []
  255. # if norm_img_batch.shape[-1] >= 400:
  256. # if norm_img_batch.shape[-1] <= 1000:
  257. # mini_batch_size = 4
  258. # elif norm_img_batch.shape[-1] <= 3000:
  259. # mini_batch_size = 2
  260. # else:
  261. # mini_batch_size = 1
  262. # for bi in range(0, norm_img_batch.shape[0], mini_batch_size):
  263. # sub_batch = norm_img_batch[bi:bi+mini_batch_size]
  264. # sub_preds = self.predict(sub_batch)
  265. # preds.append(sub_preds)
  266. # # print('type(sub_preds), sub_preds.shape', type(sub_preds), sub_preds.shape)
  267. # preds = np.concatenate(preds, axis=0)
  268. # else:
  269. # preds = self.predict(norm_img_batch)
  270. # # print('type(preds), preds.shape', type(preds), preds.shape)
  271. #
  272. # # 后处理
  273. # rec_result = self.postprocess_op(preds)
  274. # for rno in range(len(rec_result)):
  275. # rec_res[indices[beg_img_no + rno]] = rec_result[rno]
  276. # elapse += time.time() - starttime
  277. # 根据长度,动态batch
  278. if norm_img_batch.shape[-1] >= 400:
  279. if norm_img_batch.shape[-1] <= 1000:
  280. mini_batch_size = 4
  281. elif norm_img_batch.shape[-1] <= 3000:
  282. mini_batch_size = 2
  283. else:
  284. mini_batch_size = 1
  285. sub_batch_list = []
  286. for bi in range(0, norm_img_batch.shape[0], mini_batch_size):
  287. sub_batch = norm_img_batch[bi:bi+mini_batch_size]
  288. tensor = torch.from_numpy(sub_batch).float()
  289. tensor = tensor.to(self.device)
  290. sub_batch_list.append(tensor)
  291. else:
  292. tensor = torch.from_numpy(norm_img_batch).float()
  293. tensor = tensor.to(self.device)
  294. sub_batch_list = [tensor]
  295. batch_list.append(sub_batch_list)
  296. # 预测
  297. batch_out_list = self.predict_batch(batch_list)
  298. # 后处理
  299. for bi, out in enumerate(batch_out_list):
  300. begin_img_no = bi * batch_num
  301. rec_result = self.postprocess_op(out)
  302. for ri in range(len(rec_result)):
  303. rec_res[indices[begin_img_no + ri]] = rec_result[ri]
  304. elapse += time.time() - start_time
  305. return rec_res, elapse
class TextRecognizer2(object):
    """Text recognizer that offloads inference to a remote GPU service.

    Same preprocessing and CTC post-processing as TextRecognizer, but
    instead of holding a local model it ships each normalized batch to a
    GPU worker via ``from_gpu_interface_redis`` and decodes the returned
    predictions.
    """

    # Kept for parity with TextRecognizer; not read in this class.
    shrink_memory_count = 0

    def __init__(self, args):
        # Model input shape as [C, H, W].
        self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
        self.character_type = args.rec_char_type
        self.rec_batch_num = args.rec_batch_num
        self.rec_algorithm = args.rec_algorithm
        postprocess_params = {
            'name': 'CTCLabelDecode',
            "character_type": args.rec_char_type,
            "character_dict_path": args.rec_char_dict_path,
            "use_space_char": args.use_space_char
        }
        self.postprocess_op = build_post_process(postprocess_params)
        # args are serialized and forwarded to the GPU service on every call.
        self.args = args
        # self.predictor, self.input_tensor, self.output_tensors = \
        #     utility.create_predictor(args, 'rec', logger)

    def resize_norm_img(self, img, max_wh_ratio):
        """Resize ``img`` to the model input height, normalize to [-1, 1].

        Returns a CxHxW float32 array right-padded with zeros up to the
        batch-wide target width derived from ``max_wh_ratio``.
        """
        imgC, imgH, imgW = self.rec_image_shape
        assert imgC == img.shape[2]
        if self.character_type == "ch":
            imgW = int((32 * max_wh_ratio))
        h, w = img.shape[:2]
        ratio = w / float(h)
        if math.ceil(imgH * ratio) > imgW:
            resized_w = imgW
        else:
            resized_w = int(math.ceil(imgH * ratio))
        # print("predict_rec.py resize_norm_img resize shape", (resized_w, imgH))
        resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype('float32')
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    def __call__(self, img_list):
        """Recognize text in ``img_list`` via the remote GPU interface.

        Returns:
            (rec_res, elapse): ``rec_res`` is a list of [text, confidence]
            pairs in the original image order; ``elapse`` is the summed
            round-trip time.

        Raises:
            requests.exceptions.RequestException: when the GPU interface
                returns an error code.
        """
        # Imported lazily, presumably to avoid a circular import at module
        # load time — TODO confirm.
        from format_convert.convert_need_interface import from_gpu_interface_redis
        img_num = len(img_list)
        # Calculate the aspect ratio of all text bars
        width_list = []
        for img in img_list:
            width_list.append(img.shape[1] / float(img.shape[0]))
        # Sorting can speed up the recognition process
        indices = np.argsort(np.array(width_list))
        rec_res = [['', 0.0]] * img_num
        batch_num = self.rec_batch_num
        elapse = 0
        all_gpu_time = 0
        for beg_img_no in range(0, img_num, batch_num):
            # Preprocess: pad every image in the batch to the widest ratio.
            end_img_no = min(img_num, beg_img_no + batch_num)
            norm_img_batch = []
            max_wh_ratio = 0
            for ino in range(beg_img_no, end_img_no):
                h, w = img_list[indices[ino]].shape[0:2]
                wh_ratio = w * 1.0 / h
                max_wh_ratio = max(max_wh_ratio, wh_ratio)
            for ino in range(beg_img_no, end_img_no):
                norm_img = self.resize_norm_img(img_list[indices[ino]],
                                                max_wh_ratio)
                norm_img = norm_img[np.newaxis, :]
                norm_img_batch.append(norm_img)
            norm_img_batch = np.concatenate(norm_img_batch)
            norm_img_batch = norm_img_batch.copy()
            starttime = time.time()
            # # Compress the numpy payload (disabled)
            # compressed_array = io.BytesIO()
            # np.savez_compressed(compressed_array, norm_img_batch)
            # compressed_array.seek(0)
            # norm_img_batch = compressed_array.read()
            # Call the GPU interface
            _dict = {"inputs": norm_img_batch, "args": str(namespace_to_dict(self.args)), "md5": _global.get("md5")}
            result = from_gpu_interface_redis(_dict, model_type="ocr", predictor_type="rec")
            if judge_error_code(result):
                logging.error("from_gpu_interface failed! " + str(result))
                raise requests.exceptions.RequestException
            preds = result.get("preds")
            gpu_time = result.get("gpu_time")
            all_gpu_time += round(gpu_time, 2)
            # # Decompress the numpy payload (disabled)
            # decompressed_array = io.BytesIO()
            # decompressed_array.write(preds)
            # decompressed_array.seek(0)
            # preds = np.load(decompressed_array, allow_pickle=True)['arr_0']
            # log("inputs.shape" + str(preds.shape))
            # Post-process: CTC decode, writing back in original order.
            rec_result = self.postprocess_op(preds)
            for rno in range(len(rec_result)):
                rec_res[indices[beg_img_no + rno]] = rec_result[rno]
            elapse += time.time() - starttime
        log("ocr model predict time - rec - time " + str(all_gpu_time) + " - num " + str(img_num))
        return rec_res, elapse
  400. def main(args):
  401. image_file_list = get_image_file_list(args.image_dir)
  402. text_recognizer = TextRecognizer(args)
  403. valid_image_file_list = []
  404. img_list = []
  405. for image_file in image_file_list:
  406. img, flag = check_and_read_gif(image_file)
  407. if not flag:
  408. img = cv2.imread(image_file)
  409. if img is None:
  410. logger.info("error in loading image:{}".format(image_file))
  411. continue
  412. valid_image_file_list.append(image_file)
  413. img_list.append(img)
  414. try:
  415. rec_res, predict_time = text_recognizer(img_list)
  416. except:
  417. logger.info(traceback.format_exc())
  418. logger.info(
  419. "ERROR!!!! \n"
  420. "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
  421. "If your model has tps module: "
  422. "TPS does not support variable shape.\n"
  423. "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
  424. exit()
  425. for ino in range(len(img_list)):
  426. logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
  427. rec_res[ino]))
  428. logger.info("Total predict time for {} images, cost: {:.3f}".format(
  429. len(img_list), predict_time))
  430. if __name__ == "__main__":
  431. main(utility.parse_args())