1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192 |
- # -*- coding: utf-8 -*-
- # @Time : 2020/6/16 10:57
- # @Author : zhoujun
- import os
- import sys
- import pathlib
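# Make the torchocr package importable when run as a script: add the grandparent
# directory of this file (presumably the repository root) to sys.path.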
- __dir__ = pathlib.Path(os.path.abspath(__file__))
- import numpy as np
- sys.path.append(str(__dir__))
- sys.path.append(str(__dir__.parent.parent))
- import torch
- from torch import nn
- from torchocr.networks import build_model
- from torchocr.datasets.RecDataSet import RecDataProcess
- from torchocr.utils import CTCLabelConverter
class RecInfer:
    """Batch text-recognition inference wrapper around a trained PytorchOCR checkpoint.

    The checkpoint must be a dict holding ``'cfg'`` (the training config) and
    ``'state_dict'`` (the model weights); both keys are read in ``__init__``.
    """

    # Alphabet file the original script hard-coded (a path on the author's machine).
    # Kept as the default so existing callers behave identically; pass
    # ``alphabet_path`` to use a different character dictionary / machine.
    DEFAULT_ALPHABET_PATH = "C:\\Users\\Administrator\\Desktop\\OCR_pytorch\\PytorchOCR1\\char_std_7782.txt"

    def __init__(self, model_path, batch_size=16, alphabet_path=None):
        """Load the checkpoint and build the model plus pre/post-processing helpers.

        Args:
            model_path: path to a checkpoint file readable by ``torch.load``.
            batch_size: maximum number of images per forward pass in ``predict``.
            alphabet_path: character-dictionary file passed to ``CTCLabelConverter``;
                ``None`` (default) means ``DEFAULT_ALPHABET_PATH``.
        """
        # NOTE(security): torch.load unpickles the file -- only load trusted checkpoints.
        ckpt = torch.load(model_path, map_location='cpu')
        cfg = ckpt['cfg']
        self.model = build_model(cfg['model'])
        self.model.load_state_dict(self._strip_module_prefix(ckpt['state_dict']))
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        self.model.eval()
        # Preprocessing must mirror training, so it is built from the training dataset config.
        self.process = RecDataProcess(cfg['dataset']['train']['dataset'])
        if alphabet_path is None:
            alphabet_path = self.DEFAULT_ALPHABET_PATH
        self.converter = CTCLabelConverter(alphabet_path)
        self.batch_size = batch_size

    @staticmethod
    def _strip_module_prefix(state_dict):
        """Return a copy of *state_dict* with one leading ``'module.'`` removed per key.

        Weights saved from a ``DataParallel``/``DistributedDataParallel``-wrapped model
        carry a ``'module.'`` prefix. Only that leading prefix is removed, so submodules
        legitimately named ``module`` keep their names (``str.replace`` would strip
        every occurrence and break such keys).
        """
        prefix = 'module.'
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

    def predict(self, imgs):
        """Recognize text in one image or a list/tuple of images.

        Images are resized and normalized with the training preprocessing, processed in
        order of increasing width (so each batch is padded only to its own widest
        member), run through the model, softmaxed over dim 2, decoded, and returned in
        the caller's original order.

        Args:
            imgs: a single image or a list/tuple of images. Each image is assumed to be
                HWC: the stacked batch is transposed with ``[0, 3, 1, 2]`` for the model.

        Returns:
            list: one ``converter.decode`` result per input image, in input order.
        """
        if not isinstance(imgs, (list, tuple)):
            imgs = [imgs]
        imgs = [self.process.normalize_img(self.process.resize_with_specific_height(img)) for img in imgs]

        # Sort by width; a stable sort keeps the batching deterministic for equal widths.
        order = np.argsort([img.shape[1] for img in imgs], kind='stable')
        txts = []
        for start in range(0, len(imgs), self.batch_size):
            batch_idxs = order[start:start + self.batch_size]
            # The batch is width-sorted, so its last member is the widest.
            max_width = imgs[batch_idxs[-1]].shape[1]
            batch = np.stack([self.process.width_pad_img(imgs[i], max_width) for i in batch_idxs])
            tensor = torch.from_numpy(batch.transpose([0, 3, 1, 2])).float().to(self.device)
            with torch.no_grad():
                out = self.model(tensor)
            probs = out.softmax(dim=2).cpu().numpy()
            # The converter is fed one sample at a time, with a leading batch axis of 1.
            txts.extend(self.converter.decode(np.expand_dims(p, 0)) for p in probs)

        # Undo the width sort: output position k corresponds to input image k.
        inverse = np.argsort(order)
        return [txts[i] for i in inverse]
def init_args(argv=None):
    """Parse command-line options for the recognition demo.

    Args:
        argv: list of argument strings to parse. ``None`` (default) parses
            ``sys.argv[1:]``, exactly as before; passing a list lets the function be
            driven from code or tests.

    Returns:
        argparse.Namespace with ``model_path`` and ``img_path`` attributes.
    """
    import argparse
    parser = argparse.ArgumentParser(description='PytorchOCR infer')
    # Both defaults point at files on the original author's machine; pass the options
    # explicitly anywhere else.
    parser.add_argument('--model_path', required=False, type=str,
                        default="C:\\Users\\Administrator\\Desktop\\OCR_pytorch\\PytorchOCR1\\best2.pth",
                        help='rec model path')
    parser.add_argument('--img_path', required=False, type=str,
                        default="C:\\Users\\Administrator\\Desktop\\OCR_pytorch\\PytorchOCR1\\test_image/Snipaste_2023-09-08_10-45-17.jpg",
                        help='img path for predict')
    return parser.parse_args(argv)
if __name__ == '__main__':
    # Script entry point: recognize the text in a single image and print the result.
    import cv2

    args = init_args()
    img = cv2.imread(args.img_path)
    # cv2.imread reports an unreadable/missing file by returning None rather than
    # raising; fail here with a clear message instead of an opaque cv2.error below.
    if img is None:
        raise FileNotFoundError(f"could not read image: {args.img_path}")
    # cv2 decodes to BGR channel order; convert to RGB before recognition.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    model = RecInfer(args.model_path)
    out = model.predict(img)
    print(out)
|