# -*- coding: utf-8 -*-
# @Time    : 2020/6/16 10:57
# @Author  : zhoujun
"""Batch text-recognition inference for PytorchOCR checkpoints.

``RecInfer`` loads a checkpoint (weights plus the training ``cfg`` stored with
them) and turns image crops into text. Run as a script, the module evaluates a
checkpoint against a label file (see the ``__main__`` block at the bottom).
"""
import os
import sys
import pathlib

# Add the torchocr package location to the Python import path so the script can
# be launched from any working directory. NOTE: ``__dir__`` is the path of this
# *file*, so ``__dir__.parent.parent`` is the directory one level above the one
# containing this script.
__dir__ = pathlib.Path(os.path.abspath(__file__))
import numpy as np

sys.path.append(str(__dir__))
sys.path.append(str(__dir__.parent.parent))

import torch
from torch import nn

from torchocr.networks import build_model
from torchocr.datasets.RecDataSet import RecDataProcess
from torchocr.utils import CTCLabelConverter


def _default_device():
    """Return the device ``RecInfer`` uses when none is passed explicitly.

    Prefers ``cuda:1`` (the historical default of this script), falls back to
    ``cuda:0`` when only one GPU is visible (``cuda:1`` does not exist there),
    and to the CPU when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return torch.device('cpu')
    return torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')


class RecInfer:
    """Text recognizer wrapping a trained PytorchOCR recognition checkpoint."""

    def __init__(self, model_path, batch_size=16, device=None):
        """Load the checkpoint and prepare the model for inference.

        Args:
            model_path: path to a checkpoint file; it must contain ``'cfg'``
                (the training config) and ``'state_dict'`` (the weights).
            batch_size: maximum number of images per forward pass (>= 1).
            device: optional ``torch.device`` or device string. When omitted,
                ``cuda:1`` is used if more than one GPU is visible, ``cuda:0``
                with a single GPU, and the CPU otherwise.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError('batch_size must be >= 1, got %r' % (batch_size,))
        ckpt = torch.load(model_path, map_location='cpu')
        cfg = ckpt['cfg']
        self.model = build_model(cfg['model'])
        # Drop the leading 'module.' that DataParallel-wrapped models put on
        # every key. Only the prefix is stripped, so a submodule that happens to
        # be named 'module' deeper in the network keeps its keys intact.
        prefix = 'module.'
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in ckpt['state_dict'].items()
        }
        self.model.load_state_dict(state_dict)
        self.device = torch.device(device) if device is not None else _default_device()
        self.model.to(self.device)
        self.model.eval()
        # Reuse the training-set preprocessing so inference matches training.
        self.process = RecDataProcess(cfg['dataset']['train']['dataset'])
        self.converter = CTCLabelConverter(cfg['dataset']['alphabet'])
        self.batch_size = batch_size

    def predict(self, imgs):
        """Recognize the text in one image or a list of images.

        Args:
            imgs: a single image array or a list of them. Each is passed through
                ``RecDataProcess.resize_with_specific_height`` and
                ``normalize_img``; presumably HxWxC arrays in the colour order
                used during training (the ``__main__`` driver feeds RGB) --
                confirm against the training pipeline.

        Returns:
            A list with one entry per input image, in input order; each entry
            is whatever ``CTCLabelConverter.decode`` returns for that image.
        """
        if not isinstance(imgs, list):
            imgs = [imgs]
        # Preprocess exactly as during training.
        imgs = [self.process.normalize_img(self.process.resize_with_specific_height(img))
                for img in imgs]
        # Process images in order of increasing width so each batch needs as
        # little right-padding as possible.
        widths = np.array([img.shape[1] for img in imgs])
        order = np.argsort(widths)
        txts = []
        for start in range(0, len(imgs), self.batch_size):
            batch_idxs = order[start:start + self.batch_size]
            # Widths ascend within the sorted order, so the last image is the widest.
            max_width = imgs[batch_idxs[-1]].shape[1]
            batch_imgs = np.stack([self.process.width_pad_img(imgs[i], max_width)
                                   for i in batch_idxs])
            # (N, H, W, C) -> (N, C, H, W) for the network.
            tensor = torch.from_numpy(batch_imgs.transpose([0, 3, 1, 2])).float()
            tensor = tensor.to(self.device)
            with torch.no_grad():
                out = self.model(tensor)
            # out[1] is the final output the decoder consumes (softmax taken
            # over dim 2, presumably the class dimension -- confirm).
            out = out[1].softmax(dim=2).cpu().numpy()
            txts.extend(self.converter.decode(np.expand_dims(txt, 0)) for txt in out)
        # Undo the width sort so results line up with the input order.
        inverse = np.argsort(order)
        return [txts[i] for i in inverse]


def init_args():
    """Parse command-line arguments for single-image inference.

    Both flags are optional; their defaults are hard-coded Windows paths, so
    pass ``--model_path`` / ``--img_path`` explicitly on other machines.
    """
    import argparse
    parser = argparse.ArgumentParser(description='PytorchOCR infer')
    parser.add_argument('--model_path', required=False, type=str,
                        default="C:\\Users\\Administrator\\Desktop\\OCR_pytorch\\PytorchOCR1\\best.pth",
                        help='rec model path')
    parser.add_argument('--img_path', required=False, type=str,
                        default="C:\\Users\\Administrator\\Desktop\\OCR_pytorch\\PytorchOCR1\\test_image/Snipaste_2023-08-21_17-07-58.jpg",
                        help='img path for predict')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    import ast
    import re
    from unicodedata import normalize

    import cv2

    # Evaluate a checkpoint against a label file. Each line is parsed as
    # "<image_path> <text> [[...], [...], ...]": everything up to the first
    # space is the image path, everything up to " [[" is the expected text,
    # and the rest is a list of points indexed as (x, y). The points' axis-
    # aligned bounding box is cropped, recognized, and compared with the label
    # after NFKD normalization.
    model_path = '/data2/znj/PytorchOCR/tools/output/CRNN/checkpoint_resnet3/best.pth'
    label_path = '/data2/znj/ocr_data/image2.txt'
    max_mismatches = 400000  # stop early once this many samples were misrecognized
    whitespace = re.compile(r'\s')

    model = RecInfer(model_path)
    cnt = 0        # samples recognized without error
    right_cnt = 0  # samples whose prediction matched the label
    error_cnt = 0  # samples that could not be parsed, loaded or recognized

    with open(label_path, mode='r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                image_path, rest = line.split(' ', 1)
                text, box = re.split(r' \[\[', rest, maxsplit=1)
                box = '[[' + box
                img = cv2.imread(image_path)
                if img is None:
                    raise IOError('cannot read image %r' % image_path)
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                # literal_eval accepts only Python literals, so a malformed or
                # hostile label line cannot execute code the way eval() could.
                bbox = ast.literal_eval(box)
                # Clamp at 0: negative indices would wrap around when slicing.
                x1 = max(0, int(min(p[0] for p in bbox)))
                x2 = max(0, int(max(p[0] for p in bbox)))
                y1 = max(0, int(min(p[1] for p in bbox)))
                y2 = max(0, int(max(p[1] for p in bbox)))
                img = img[y1:y2, x1:x2]
                out = model.predict(img)
                # First (only) image -> first decoded entry -> its text.
                # Assumes decode returns [(text, score), ...]; confirm against
                # CTCLabelConverter.
                out = out[0][0][0]
                cnt += 1
                if normalize('NFKD', out) == normalize('NFKD', text):
                    right_cnt += 1
                elif len(out) > 0 and not whitespace.search(out):
                    print(image_path + " " + text + " " + box + " rec_res:" + out)
            except Exception as e:
                # Best-effort evaluation: count and report the bad sample, keep going.
                error_cnt += 1
                print(f'skipped line {line!r}: {e!r}', file=sys.stderr)
            if cnt - right_cnt >= max_mismatches:
                break

    print('count_num:', cnt)
    print('right_num:', right_cnt)
    print('right_%:', right_cnt / cnt if cnt else 0.0)
    print('process_error_cnt:', error_cnt)