# demo.py
import argparse
import time

import cv2
import numpy as np
import torch
import yaml
from easydict import EasyDict as edict
from torch.autograd import Variable

import lib.config.alphabets as alphabets
import lib.models.crnn as crnn
import lib.utils.utils as utils
  12. def parse_arg():
  13. parser = argparse.ArgumentParser(description="demo")
  14. parser.add_argument('--cfg', help='experiment configuration filename', type=str, default='lib/config/360CC_config.yaml')
  15. parser.add_argument('--image_path', type=str, default='images/test.png', help='the path to your image')
  16. parser.add_argument('--checkpoint', type=str, default='output/checkpoints/mixed_second_finetune_acc_97P7.pth',
  17. help='the path to your checkpoints')
  18. args = parser.parse_args()
  19. with open(args.cfg, 'r') as f:
  20. config = yaml.load(f)
  21. config = edict(config)
  22. config.DATASET.ALPHABETS = alphabets.alphabet
  23. config.MODEL.NUM_CLASSES = len(config.DATASET.ALPHABETS)
  24. return config, args
  25. def recognition(config, img, model, converter, device):
  26. # github issues: https://github.com/Sierkinhane/CRNN_Chinese_Characters_Rec/issues/211
  27. h, w = img.shape
  28. # fisrt step: resize the height and width of image to (32, x)
  29. img = cv2.resize(img, (0, 0), fx=config.MODEL.IMAGE_SIZE.H / h, fy=config.MODEL.IMAGE_SIZE.H / h, interpolation=cv2.INTER_CUBIC)
  30. # second step: keep the ratio of image's text same with training
  31. h, w = img.shape
  32. w_cur = int(img.shape[1] / (config.MODEL.IMAGE_SIZE.OW / config.MODEL.IMAGE_SIZE.W))
  33. img = cv2.resize(img, (0, 0), fx=w_cur / w, fy=1.0, interpolation=cv2.INTER_CUBIC)
  34. img = np.reshape(img, (config.MODEL.IMAGE_SIZE.H, w_cur, 1))
  35. # normalize
  36. img = img.astype(np.float32)
  37. img = (img / 255. - config.DATASET.MEAN) / config.DATASET.STD
  38. img = img.transpose([2, 0, 1])
  39. img = torch.from_numpy(img)
  40. img = img.to(device)
  41. img = img.view(1, *img.size())
  42. model.eval()
  43. preds = model(img)
  44. print(preds.shape)
  45. _, preds = preds.max(2)
  46. preds = preds.transpose(1, 0).contiguous().view(-1)
  47. preds_size = Variable(torch.IntTensor([preds.size(0)]))
  48. sim_pred = converter.decode(preds.data, preds_size.data, raw=False)
  49. print('results: {0}'.format(sim_pred))
  50. if __name__ == '__main__':
  51. config, args = parse_arg()
  52. device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
  53. model = crnn.get_crnn(config).to(device)
  54. print('loading pretrained model from {0}'.format(args.checkpoint))
  55. model.load_state_dict(torch.load(args.checkpoint))
  56. started = time.time()
  57. img = cv2.imread(args.image_path)
  58. img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  59. converter = utils.strLabelConverter(config.DATASET.ALPHABETS)
  60. recognition(config, img, model, converter, device)
  61. finished = time.time()
  62. print('elapsed time: {0}'.format(finished - started))