12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788 |
- # -*- coding: utf-8 -*-
- # @Time : 2020/6/16 10:57
- # @Author : zhoujun
- import os
- import sys
- import pathlib
- # Add the torchocr package location to the Python import path
- __dir__ = pathlib.Path(os.path.abspath(__file__))
- sys.path.append(str(__dir__))
- sys.path.append(str(__dir__.parent.parent))
- import torch
- from torch import nn
- from torchvision import transforms
- from torchocr.networks import build_model
- from torchocr.datasets.det_modules import ResizeShortSize, ResizeFixedSize
- from torchocr.postprocess import build_post_process
class DetInfer:
    """Run a trained text-detection model on single images.

    The checkpoint is expected to contain a ``cfg`` dict (its ``model``,
    ``post_process`` and ``dataset`` sections are read) and a ``state_dict``.
    """

    def __init__(self, model_path):
        """Load the checkpoint and set up model, pre- and post-processing.

        Args:
            model_path: Path to a checkpoint file readable by ``torch.load``.
        """
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # that come from a trusted source.
        ckpt = torch.load(model_path, map_location='cpu')
        cfg = ckpt['cfg']
        self.model = build_model(cfg['model'])
        self.model.load_state_dict(self._strip_module_prefix(ckpt['state_dict']))
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        self.model.eval()

        # The resize settings are fixed here rather than read from cfg.
        self.resize = ResizeFixedSize(736, False)
        self.post_process = build_post_process(cfg['post_process'])
        # Normalize with the mean/std stored for the training dataset.
        train_dataset_cfg = cfg['dataset']['train']['dataset']
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=train_dataset_cfg['mean'], std=train_dataset_cfg['std'])
        ])

    @staticmethod
    def _strip_module_prefix(state_dict):
        """Return a copy of ``state_dict`` with a leading ``module.`` removed from keys.

        Only the prefix is stripped. Replacing every occurrence would corrupt
        keys that merely contain the substring (e.g. ``se_module.fc.weight``).
        """
        prefix = 'module.'
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

    def predict(self, img):
        """Detect text boxes in one image.

        Args:
            img: Image array; ``img.shape[:2]`` is recorded as the original
                shape before resizing.
                (presumably an H x W x C array as returned by ``cv2.imread``
                -- TODO confirm)

        Returns:
            A ``(box_list, score_list)`` pair of lists for this image. Only
            boxes whose coordinates sum to a value greater than zero are kept,
            together with their scores; both lists are empty when the
            post-process yields no boxes.
        """
        # Pre-processing must follow what was done during training.
        data = {'img': img, 'shape': [img.shape[:2]], 'text_polys': []}
        data = self.resize(data)
        tensor = self.transform(data['img'])
        tensor = tensor.unsqueeze(dim=0)  # add the batch dimension
        tensor = tensor.to(self.device)
        with torch.no_grad():
            out = self.model(tensor)
        out = out.cpu().numpy()

        box_list, score_list = self.post_process(out, data['shape'])
        # The batch holds a single image, so take its results.
        box_list, score_list = box_list[0], score_list[0]
        # len() rather than truthiness: box_list may be a numpy array.
        if len(box_list) == 0:
            return [], []

        kept = [(box, score) for box, score in zip(box_list, score_list) if box.sum() > 0]
        return [box for box, _ in kept], [score for _, score in kept]
def init_args():
    """Parse the command-line options of the detection inference script.

    Returns:
        The ``argparse.Namespace`` parsed from ``sys.argv`` with two required
        string attributes, ``model_path`` and ``img_path``.
    """
    import argparse

    parser = argparse.ArgumentParser(description='PytorchOCR infer')
    # This script loads a detection model, so the help text says "det".
    parser.add_argument('--model_path', required=True, type=str, help='det model path')
    parser.add_argument('--img_path', required=True, type=str, help='img dir for predict')
    return parser.parse_args()
if __name__ == '__main__':
    # Script entry point: run detection on every file directly inside
    # --img_path and write "<name>_res.jpg" visualisations to <img_path>/res.
    import cv2
    import time
    from torchocr.utils import draw_bbox

    args = init_args()
    model = DetInfer(args.model_path)
    # Only the files in the top level of img_path are used (no recursion).
    names = next(os.walk(args.img_path))[2]

    # cv2.imwrite fails silently when the target directory does not exist.
    res_dir = os.path.join(args.img_path, 'res')
    os.makedirs(res_dir, exist_ok=True)

    st = time.time()
    processed = 0
    for name in names:
        path = os.path.join(args.img_path, name)
        img = cv2.imread(path)
        if img is None:
            # cv2.imread returns None for files it cannot decode.
            print('skip unreadable image: {}'.format(path))
            continue
        box_list, score_list = model.predict(img)
        img = draw_bbox(img, box_list)
        # splitext copes with extensions of any length (".jpeg", none, ...).
        stem = os.path.splitext(name)[0]
        cv2.imwrite(os.path.join(res_dir, stem + '_res.jpg'), img)
        processed += 1

    # Average seconds per processed image; skipped when nothing was processed.
    if processed:
        print((time.time() - st) / processed)
|