123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141 |
- # -*- coding: utf-8 -*-
- # @Time : 2020/6/16 10:57
- # @Author : zhoujun
- import os
- import sys
- import pathlib
# Add the torchocr directory to the Python path (sys.path)
- __dir__ = pathlib.Path(os.path.abspath(__file__))
- import numpy as np
- sys.path.append(str(__dir__))
- sys.path.append(str(__dir__.parent.parent))
- import torch
- from torch import nn
- from torchocr.networks import build_model
- from torchocr.datasets.RecDataSet import RecDataProcess
- from torchocr.utils import CTCLabelConverter
class RecInfer:
    """CRNN text-recognition inference wrapper.

    Loads a trained checkpoint (which embeds its training config under
    'cfg') and exposes batched prediction over cropped text-line images.
    """

    def __init__(self, model_path, batch_size=16):
        """
        :param model_path: checkpoint path containing 'cfg' and 'state_dict'
        :param batch_size: maximum number of images per forward pass
        """
        ckpt = torch.load(model_path, map_location='cpu')
        cfg = ckpt['cfg']
        self.model = build_model(cfg['model'])
        # Strip the 'module.' prefix left by DataParallel training so the
        # weights load into a plain (non-parallel) model.
        state_dict = {k.replace('module.', ''): v
                      for k, v in ckpt['state_dict'].items()}
        self.model.load_state_dict(state_dict)
        # NOTE(review): GPU index 1 is hard-coded; this fails on
        # single-GPU hosts — confirm the intended deployment target.
        self.device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        self.model.eval()
        # Reuse the training-time preprocessing so inference matches training.
        self.process = RecDataProcess(cfg['dataset']['train']['dataset'])
        self.converter = CTCLabelConverter(cfg['dataset']['alphabet'])
        self.batch_size = batch_size

    def predict(self, imgs):
        """Recognize text in one image or a list of images.

        :param imgs: a single HxWxC numpy image, or a list of them
        :return: decoded texts, in the same order as the input images
        """
        if not isinstance(imgs, list):
            imgs = [imgs]
        # Preprocess exactly as during training: fixed height, normalized.
        imgs = [self.process.normalize_img(self.process.resize_with_specific_height(img))
                for img in imgs]
        # Sort by width so each batch only pads up to its own widest image.
        widths = np.array([img.shape[1] for img in imgs])
        order = np.argsort(widths)
        txts = []
        for start in range(0, len(imgs), self.batch_size):
            batch_idxs = order[start:start + self.batch_size]
            # Indices are sorted by ascending width, so the last one is
            # the widest image of this batch — pad everything to it.
            max_w = imgs[batch_idxs[-1]].shape[1]
            batch = np.stack([self.process.width_pad_img(imgs[i], max_w)
                              for i in batch_idxs])
            # NHWC -> NCHW for the model.
            tensor = torch.from_numpy(batch.transpose([0, 3, 1, 2])).float()
            tensor = tensor.to(self.device)
            with torch.no_grad():
                out = self.model(tensor)
            # out[1] is the final sequence output (logits over the alphabet).
            probs = out[1].softmax(dim=2).cpu().numpy()
            txts.extend(self.converter.decode(np.expand_dims(p, 0)) for p in probs)
        # Undo the width sort so results align with the caller's input order.
        inverse = np.argsort(order)
        return [txts[i] for i in inverse]
def init_args():
    """Build and parse the command-line arguments for the infer demo."""
    import argparse

    parser = argparse.ArgumentParser(description='PytorchOCR infer')
    parser.add_argument(
        '--model_path', required=False, type=str,
        default="C:\\Users\\Administrator\\Desktop\\OCR_pytorch\\PytorchOCR1\\best.pth",
        help='rec model path')
    parser.add_argument(
        '--img_path', required=False, type=str,
        default="C:\\Users\\Administrator\\Desktop\\OCR_pytorch\\PytorchOCR1\\test_image/Snipaste_2023-08-21_17-07-58.jpg",
        help='img path for predict')
    return parser.parse_args()
if __name__ == '__main__':
    # Evaluation driver: runs recognition over a labelled crop list and
    # reports accuracy.  Expects each line of the label file to look like
    #   <image_path> <label text> [[x1,y1],[x2,y2],...]
    # where the label itself may contain spaces.
    import ast
    import cv2
    import re
    from unicodedata import normalize

    model_path = '/data2/znj/PytorchOCR/tools/output/CRNN/checkpoint_resnet3/best.pth'
    model = RecInfer(model_path)

    cnt = 0        # samples successfully recognized
    right_cnt = 0  # exact matches after Unicode normalization
    error_cnt = 0  # lines that failed to parse or process

    with open("/data2/znj/ocr_data/image2.txt", mode='r') as f:
        for line in f:
            line = line.strip()
            try:
                # Split path first, then split label from box on the
                # " [[" that starts the box (the label may contain spaces).
                image_path, rest = re.split(r" ", line, maxsplit=1)
                text, box = re.split(r" \[\[", rest, maxsplit=1)
                box = '[[' + box
                img = cv2.imread(image_path)
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                # literal_eval safely parses the "[[x,y],...]" list syntax;
                # eval would execute arbitrary code from the label file.
                bbox = ast.literal_eval(box)
                # Crop the axis-aligned bounding rectangle of the quad.
                x1 = int(min(p[0] for p in bbox))
                x2 = int(max(p[0] for p in bbox))
                y1 = int(min(p[1] for p in bbox))
                y2 = int(max(p[1] for p in bbox))
                img = img[y1:y2, x1:x2]
                out = model.predict(img)[0][0][0]
                cnt += 1
                # Compare after NFKD normalization so compatibility
                # characters (e.g. fullwidth forms) don't count as errors.
                if normalize('NFKD', out) == normalize('NFKD', text):
                    right_cnt += 1
                elif not re.search(r"\s", out) and len(out) > 0:
                    # Log clean-looking mismatches for inspection.
                    print(image_path + " " + text + " " + box + " rec_res:" + out)
            except Exception:
                error_cnt += 1
            # Stop once enough mismatches have been collected.
            if cnt - right_cnt >= 400000:
                break

    print('count_num:', cnt)
    print('right_num:', right_cnt)
    # Guard against an empty or unreadable label file.
    print('right_%:', right_cnt / cnt if cnt else 0.0)
    print('process_error_cnt:', error_cnt)
|