# rec_infer_att_test.py
  1. # -*- coding: utf-8 -*-
  2. # @Time : 2020/6/16 10:57
  3. # @Author : zhoujun
  4. import os
  5. import sys
  6. import pathlib
  7. # 将 torchocr路径加到python陆经里
  8. __dir__ = pathlib.Path(os.path.abspath(__file__))
  9. import numpy as np
  10. sys.path.append(str(__dir__))
  11. sys.path.append(str(__dir__.parent.parent))
  12. import torch
  13. from torch import nn
  14. from torchocr.networks import build_model
  15. from torchocr.datasets.RecDataSet import RecDataProcess
  16. from torchocr.utils import CTCLabelConverter
  17. class RecInfer:
  18. def __init__(self, model_path, batch_size=16):
  19. ckpt = torch.load(model_path, map_location='cpu')
  20. cfg = ckpt['cfg']
  21. self.model = build_model(cfg['model'])
  22. state_dict = {}
  23. for k, v in ckpt['state_dict'].items():
  24. state_dict[k.replace('module.', '')] = v
  25. self.model.load_state_dict(state_dict)
  26. self.device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
  27. self.model.to(self.device)
  28. self.model.eval()
  29. self.process = RecDataProcess(cfg['dataset']['train']['dataset'])
  30. self.converter = CTCLabelConverter(cfg['dataset']['alphabet'])
  31. # self.converter = CTCLabelConverter("C:\\Users\\Administrator\\Desktop\\OCR_pytorch\\PytorchOCR1\\char_std_7782.txt")
  32. self.batch_size = batch_size
  33. def predict(self, imgs):
  34. # 预处理根据训练来
  35. if not isinstance(imgs,list):
  36. imgs = [imgs]
  37. imgs = [self.process.normalize_img(self.process.resize_with_specific_height(img)) for img in imgs]
  38. widths = np.array([img.shape[1] for img in imgs])
  39. idxs = np.argsort(widths)
  40. txts = []
  41. for idx in range(0, len(imgs), self.batch_size):
  42. batch_idxs = idxs[idx:min(len(imgs), idx+self.batch_size)]
  43. batch_imgs = [self.process.width_pad_img(imgs[idx], imgs[batch_idxs[-1]].shape[1]) for idx in batch_idxs]
  44. batch_imgs = np.stack(batch_imgs)
  45. tensor = torch.from_numpy(batch_imgs.transpose([0,3, 1, 2])).float()
  46. tensor = tensor.to(self.device)
  47. with torch.no_grad():
  48. out = self.model(tensor)
  49. # print(out)
  50. # out[1] 为最后输出
  51. out = out[1]
  52. out = out.softmax(dim=2)
  53. out = out.cpu().numpy()
  54. txts.extend([self.converter.decode(np.expand_dims(txt, 0)) for txt in out])
  55. #按输入图像的顺序排序
  56. idxs = np.argsort(idxs)
  57. out_txts = [txts[idx] for idx in idxs]
  58. return out_txts
  59. def init_args():
  60. import argparse
  61. parser = argparse.ArgumentParser(description='PytorchOCR infer')
  62. # parser.add_argument('--model_path', required=True, type=str, help='rec model path')
  63. parser.add_argument('--model_path', required=False, type=str,
  64. default="C:\\Users\\Administrator\\Desktop\\OCR_pytorch\\PytorchOCR1\\best.pth", help='rec model path')
  65. # parser.add_argument('--img_path', required=True, type=str, help='img path for predict')
  66. parser.add_argument('--img_path', required=False, type=str,
  67. default="C:\\Users\\Administrator\\Desktop\\OCR_pytorch\\PytorchOCR1\\test_image/Snipaste_2023-08-21_17-07-58.jpg", help='img path for predict')
  68. args = parser.parse_args()
  69. return args
  70. if __name__ == '__main__':
  71. import cv2
  72. import re
  73. from unicodedata import normalize
  74. # args = init_args()
  75. # img = cv2.imread(args.img_path)
  76. model_path = '/data2/znj/PytorchOCR/tools/output/CRNN/checkpoint_resnet3/best.pth'
  77. model = RecInfer(model_path)
  78. cnt = 0
  79. right_cnt = 0
  80. error_cnt = 0
  81. with open("/data2/znj/ocr_data/image2.txt",mode='r') as f:
  82. for line in f.readlines():
  83. line = line.strip()
  84. # line_split = line.split(" ")
  85. # line_split = re.split(" ",line,maxsplit=2)
  86. iamge_path,line_split2 = re.split(" ",line,maxsplit=1)
  87. text, box = re.split(" \[\[",line_split2,maxsplit=1)
  88. box = '[['+box
  89. # if len(line_split)==3 :
  90. try:
  91. if True :
  92. # iamge_path,text,box = line_split
  93. img = cv2.imread(iamge_path)
  94. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  95. bbox = eval(box)
  96. x1 = int(min([i[0] for i in bbox]))
  97. x2 = int(max([i[0] for i in bbox]))
  98. y1 = int(min([i[1] for i in bbox]))
  99. y2 = int(max([i[1] for i in bbox]))
  100. img = img[y1:y2, x1:x2]
  101. out = model.predict(img)
  102. out = out[0][0][0]
  103. cnt += 1
  104. # if out==text:
  105. if normalize('NFKD', out)==normalize('NFKD', text):
  106. right_cnt += 1
  107. else:
  108. if not re.search("\s",out) and len(out)>0:
  109. print(iamge_path+" "+text+" "+box+" rec_res:"+out)
  110. # print(out+" -> "+ text)
  111. except:
  112. error_cnt += 1
  113. pass
  114. # if cnt>= 500000:
  115. # break
  116. if cnt-right_cnt>= 400000:
  117. break
  118. print('count_num:',cnt)
  119. print('right_num:',right_cnt)
  120. print('right_%:',right_cnt/cnt)
  121. print('process_error_cnt:',error_cnt)
  122. pass