demo.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. import numpy as np
  2. import time
  3. import cv2
  4. import torch
  5. from torch.autograd import Variable
  6. import lib.utils.utils as utils
  7. import lib.models.crnn as crnn
  8. import lib.config.alphabets as alphabets
  9. import yaml
  10. from easydict import EasyDict as edict
  11. import argparse
  def parse_arg():
      """Parse command-line options and load the experiment configuration.

      Command-line options:
          --cfg         path of the YAML experiment configuration
          --image_path  path of the image to recognize
          --checkpoint  path of the model weights to load

      Returns:
          tuple: ``(config, args)``. ``config`` is the YAML file wrapped in an
          EasyDict, with ``DATASET.ALPHABETS`` overridden by
          ``alphabets.alphabet`` and ``MODEL.NUM_CLASSES`` set to the length of
          that alphabet. ``args`` is the parsed ``argparse.Namespace``.
      """
      parser = argparse.ArgumentParser(description="demo")
      parser.add_argument('--cfg', help='experiment configuration filename', type=str,
                          default='lib/config/360CC_config.yaml')
      parser.add_argument('--image_path', type=str, default='images/test.png',
                          help='the path to your image')
      parser.add_argument('--checkpoint', type=str,
                          default='output/checkpoints/mixed_second_finetune_acc_97P7.pth',
                          help='the path to your checkpoints')
      args = parser.parse_args()

      # yaml.load() without an explicit Loader is deprecated since PyYAML 5.1 and
      # is a TypeError from PyYAML 6 on. safe_load handles plain mappings/lists/
      # scalars and never constructs arbitrary Python objects.
      # NOTE(review): if the config ever needs python/* tags, use
      # yaml.load(f, Loader=yaml.FullLoader) instead.
      with open(args.cfg, 'r', encoding='utf-8') as f:
          config = edict(yaml.safe_load(f))

      # Assumes the YAML defines DATASET and MODEL sections; attribute access
      # below fails otherwise.
      config.DATASET.ALPHABETS = alphabets.alphabet
      config.MODEL.NUM_CLASSES = len(config.DATASET.ALPHABETS)
      return config, args
  def recognition(config, img, model, converter, device):
      """Recognize the text in a single-channel image and print the result.

      Args:
          config: experiment config; reads ``MODEL.IMAGE_SIZE.H``,
              ``MODEL.IMAGE_SIZE.W``, ``MODEL.IMAGE_SIZE.OW``,
              ``DATASET.MEAN`` and ``DATASET.STD``.
          img: 2-D image array (``h, w = img.shape`` requires exactly two
              dimensions). Presumably uint8 in 0..255, given the ``/ 255.``
              scaling; confirm for other inputs.
          model: recognition network. It is switched to eval mode and called on
              a ``(1, 1, H, w_cur)`` float32 tensor.
          converter: object whose ``decode(preds, sizes, raw=False)`` maps the
              flattened argmax index sequence to a string.
          device: torch device the input tensor is moved to. It should match
              the device the model lives on.

      Returns:
          The decoded string, which is also printed as ``results: ...``.
      """
      target_h = config.MODEL.IMAGE_SIZE.H
      h, w = img.shape

      # Ratio resize: scale the width by W / OW and force the height to H.
      w_cur = int(w / (config.MODEL.IMAGE_SIZE.OW / config.MODEL.IMAGE_SIZE.W))
      # Very narrow images could round down to 0 columns, which cv2.resize rejects.
      w_cur = max(w_cur, 1)
      # An explicit dsize guarantees the output is exactly (target_h, w_cur), so
      # the reshape below cannot trip over fx/fy rounding.
      img = cv2.resize(img, (w_cur, target_h), interpolation=cv2.INTER_CUBIC)
      img = np.reshape(img, (target_h, w_cur, 1))

      # Normalize. The final cast keeps the tensor float32 even if MEAN/STD are
      # ever given as float64 arrays or NumPy scalars, which would otherwise
      # promote the result.
      img = img.astype(np.float32)
      img = ((img / 255. - config.DATASET.MEAN) / config.DATASET.STD).astype(np.float32)

      # HWC -> CHW, then add a batch dimension: (1, 1, target_h, w_cur).
      img = img.transpose([2, 0, 1])
      img = torch.from_numpy(img).to(device).unsqueeze(0)

      model.eval()
      # Pure inference: do not build or retain an autograd graph.
      with torch.no_grad():
          preds = model(img)

      # Greedy decoding: argmax over dim 2, then flatten after swapping dims 0/1.
      # NOTE(review): this presumes the model output is (T, batch, classes);
      # confirm against crnn.get_crnn.
      _, preds = preds.max(2)
      preds = preds.transpose(1, 0).contiguous().view(-1)
      preds_size = torch.IntTensor([preds.size(0)])
      sim_pred = converter.decode(preds.data, preds_size.data, raw=False)
      print('results: {0}'.format(sim_pred))
      return sim_pred
  if __name__ == '__main__':
      # Script entry point: load config and model, then recognize one image and
      # report the wall-clock time of image loading plus recognition.
      config, args = parse_arg()
      device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

      model = crnn.get_crnn(config).to(device)
      print('loading pretrained model from {0}'.format(args.checkpoint))
      # map_location remaps tensors saved on another device (e.g. a GPU) onto
      # the selected device. Without it, the CPU fallback above fails for
      # GPU-saved checkpoints.
      model.load_state_dict(torch.load(args.checkpoint, map_location=device))

      started = time.perf_counter()
      img = cv2.imread(args.image_path)
      # cv2.imread signals failure by returning None instead of raising.
      if img is None:
          raise SystemExit('cannot read image: {0}'.format(args.image_path))
      img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
      converter = utils.strLabelConverter(config.DATASET.ALPHABETS)
      recognition(config, img, model, converter, device)
      finished = time.perf_counter()
      print('elapsed time: {0}'.format(finished - started))