# -*- coding: utf-8 -*-
# @Time : 2020/6/16 10:57
# @Author : zhoujun
import os
import sys
import pathlib

# Add the torchocr path to the Python path.
__dir__ = pathlib.Path(os.path.abspath(__file__))

import numpy as np

sys.path.append(str(__dir__))
sys.path.append(str(__dir__.parent.parent))

import torch
from torch import nn
from torchocr.networks import build_model
from torchocr.datasets.RecDataSet import RecDataProcess
from torchocr.utils import CTCLabelConverter

# Alphabet file used when the caller does not supply one. This is the path the
# script has always used; it is kept as the default so existing callers behave
# exactly as before.
DEFAULT_ALPHABET_PATH = "C:\\Users\\Administrator\\Desktop\\OCR_pytorch\\PytorchOCR1\\char_std_7782.txt"


class RecInfer:
    """Run text recognition with a trained checkpoint.

    The checkpoint is expected to be a mapping with a ``'cfg'`` entry (model and
    dataset configuration) and a ``'state_dict'`` entry (model weights).
    """

    def __init__(self, model_path, batch_size=16, alphabet_path=None):
        """Load the checkpoint and prepare the model for inference.

        Args:
            model_path: Path of the checkpoint passed to ``torch.load``.
            batch_size: Maximum number of images fed to the model at once.
            alphabet_path: Value handed to ``CTCLabelConverter``. When ``None``
                the historical default ``DEFAULT_ALPHABET_PATH`` is used.
        """
        # NOTE(review): torch.load unpickles the file -- only load checkpoints
        # from a trusted source.
        ckpt = torch.load(model_path, map_location='cpu')
        cfg = ckpt['cfg']
        self.model = build_model(cfg['model'])

        # Drop every 'module.' fragment from the parameter names (presumably
        # left by a DataParallel wrapper -- confirm against the training code).
        state_dict = {k.replace('module.', ''): v for k, v in ckpt['state_dict'].items()}
        self.model.load_state_dict(state_dict)

        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        self.model.eval()

        self.process = RecDataProcess(cfg['dataset']['train']['dataset'])
        # An earlier version built the converter from cfg['dataset']['alphabet'];
        # the current code reads the alphabet from a path instead.
        if alphabet_path is None:
            alphabet_path = DEFAULT_ALPHABET_PATH
        self.converter = CTCLabelConverter(alphabet_path)
        self.batch_size = batch_size

    def predict(self, imgs):
        """Recognize the text in one image or a list of images.

        Args:
            imgs: A single image or a list of images. Anything that is not a
                ``list`` is treated as a single image.

        Returns:
            A list with one ``self.converter.decode`` result per input image,
            in the same order as the inputs.
        """
        # Preprocessing follows what was used for training.
        if not isinstance(imgs, list):
            imgs = [imgs]
        imgs = [
            self.process.normalize_img(self.process.resize_with_specific_height(img))
            for img in imgs
        ]

        # Process images from narrowest to widest so each batch needs as little
        # padding as possible.
        widths = np.array([img.shape[1] for img in imgs])
        order = np.argsort(widths)

        txts = []
        for start in range(0, len(imgs), self.batch_size):
            # Slicing clamps at the end of the array, so no explicit bound is needed.
            batch_idxs = order[start:start + self.batch_size]
            # The last index of the batch is its widest image (sorted order).
            max_width = imgs[batch_idxs[-1]].shape[1]
            batch_imgs = np.stack(
                [self.process.width_pad_img(imgs[i], max_width) for i in batch_idxs]
            )
            # Move the last axis to position 1 (presumably NHWC -> NCHW -- confirm
            # against the model's expected input layout).
            tensor = torch.from_numpy(batch_imgs.transpose([0, 3, 1, 2])).float()
            tensor = tensor.to(self.device)
            with torch.no_grad():
                out = self.model(tensor)
                out = out.softmax(dim=2)
            out = out.cpu().numpy()
            # Decode sample by sample, restoring a leading axis of size 1.
            txts.extend([self.converter.decode(np.expand_dims(txt, 0)) for txt in out])

        # txts is in width-sorted order; invert the permutation to restore the
        # order of the input images.
        inverse = np.argsort(order)
        return [txts[i] for i in inverse]


def init_args():
    """Parse the command-line arguments of the inference script.

    Returns:
        The ``argparse.Namespace`` with ``model_path``, ``img_path`` and
        ``alphabet_path``.
    """
    import argparse

    parser = argparse.ArgumentParser(description='PytorchOCR infer')
    # model_path and img_path used to be required arguments; they now default
    # to a local development setup.
    parser.add_argument('--model_path', required=False, type=str,
                        default="C:\\Users\\Administrator\\Desktop\\OCR_pytorch\\PytorchOCR1\\best2.pth",
                        help='rec model path')
    parser.add_argument('--img_path', required=False, type=str,
                        default="C:\\Users\\Administrator\\Desktop\\OCR_pytorch\\PytorchOCR1\\test_image/Snipaste_2023-09-08_10-45-17.jpg",
                        help='img path for predict')
    parser.add_argument('--alphabet_path', required=False, type=str,
                        default=DEFAULT_ALPHABET_PATH,
                        help='alphabet file passed to the label converter')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    import cv2

    args = init_args()
    img = cv2.imread(args.img_path)
    # cv2.imread signals failure by returning None rather than raising; fail
    # here with a clear message instead of inside cvtColor.
    if img is None:
        raise FileNotFoundError(f'cannot read image: {args.img_path}')
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    model = RecInfer(args.model_path, alphabet_path=args.alphabet_path)
    out = model.predict(img)
    print(out)