"""Single-image text recognition demo for a CRNN model.

Loads a YAML experiment config and a checkpoint, reads one image from disk,
runs it through the network and prints the decoded text and elapsed time.
"""
import argparse
import time

import cv2
import numpy as np
import torch
import yaml
from easydict import EasyDict as edict
from torch.autograd import Variable

import lib.config.alphabets as alphabets
import lib.models.crnn as crnn
import lib.utils.utils as utils


def parse_arg():
    """Parse command-line options and load the experiment configuration.

    Returns:
        A ``(config, args)`` tuple. ``config`` is the YAML file given by
        ``--cfg`` wrapped in an ``EasyDict``, with ``DATASET.ALPHABETS`` set
        to ``alphabets.alphabet`` and ``MODEL.NUM_CLASSES`` set to its
        length. ``args`` is the parsed ``argparse`` namespace.
    """
    parser = argparse.ArgumentParser(description="demo")

    parser.add_argument('--cfg', help='experiment configuration filename', type=str,
                        default='lib/config/360CC_config.yaml')
    parser.add_argument('--image_path', type=str, default='images/test.png',
                        help='the path to your image')
    parser.add_argument('--checkpoint', type=str,
                        default='output/checkpoints/mixed_second_finetune_acc_97P7.pth',
                        help='the path to your checkpoints')

    args = parser.parse_args()

    with open(args.cfg, 'r') as f:
        # safe_load: a bare yaml.load() without a Loader raises TypeError on
        # PyYAML >= 6 and can build arbitrary objects on older versions.
        config = yaml.safe_load(f)
    config = edict(config)

    config.DATASET.ALPHABETS = alphabets.alphabet
    config.MODEL.NUM_CLASSES = len(config.DATASET.ALPHABETS)

    return config, args


def recognition(config, img, model, converter, device):
    """Run the model on one grayscale image and print the decoded text.

    Args:
        config: Experiment config; reads ``MODEL.IMAGE_SIZE.{H, W, OW}`` and
            ``DATASET.{MEAN, STD}``.
        img: 2-D (height, width) image array. Values are divided by 255
            before normalisation, so an 8-bit range is presumably expected.
        model: Network to evaluate; it is switched to eval mode here.
        converter: Object whose ``decode(preds, preds_size, raw=False)``
            turns the predicted class indices into text.
        device: ``torch.device`` the input tensor is moved to.

    Returns:
        The decoded prediction as returned by ``converter.decode``.
    """
    # Ratio resize: width is scaled by W / OW, height is forced to H.
    w_cur = int(img.shape[1] / (config.MODEL.IMAGE_SIZE.OW / config.MODEL.IMAGE_SIZE.W))
    h, w = img.shape
    img = cv2.resize(img, (0, 0), fx=w_cur / w, fy=config.MODEL.IMAGE_SIZE.H / h,
                     interpolation=cv2.INTER_CUBIC)
    img = np.reshape(img, (config.MODEL.IMAGE_SIZE.H, w_cur, 1))

    # Normalize to the dataset statistics.
    img = img.astype(np.float32)
    img = (img / 255. - config.DATASET.MEAN) / config.DATASET.STD

    # HWC -> CHW, then add a leading batch dimension of 1.
    img = img.transpose([2, 0, 1])
    img = torch.from_numpy(img)
    img = img.to(device)
    img = img.view(1, *img.size())

    model.eval()
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        preds = model(img)

    # Arg-max over dim 2 (presumably the class dimension), then flatten the
    # index tensor for the decoder.
    _, preds = preds.max(2)
    preds = preds.transpose(1, 0).contiguous().view(-1)

    preds_size = Variable(torch.IntTensor([preds.size(0)]))
    sim_pred = converter.decode(preds.data, preds_size.data, raw=False)

    print('results: {0}'.format(sim_pred))
    return sim_pred


def main():
    """Load the model and image named on the command line and recognise it."""
    config, args = parse_arg()
    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

    model = crnn.get_crnn(config).to(device)
    print('loading pretrained model from {0}'.format(args.checkpoint))
    # map_location lets a checkpoint saved from a GPU load on a CPU-only host.
    model.load_state_dict(torch.load(args.checkpoint, map_location=device))

    started = time.time()

    img = cv2.imread(args.image_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(
            'cannot read image: {0}'.format(args.image_path))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    converter = utils.strLabelConverter(config.DATASET.ALPHABETS)

    recognition(config, img, model, converter, device)

    finished = time.time()
    print('elapsed time: {0}'.format(finished - started))


if __name__ == '__main__':
    main()