# demo.py — CRNN single-image text recognition demo
  1. import numpy as np
  2. import time
  3. import cv2
  4. import torch
  5. from torch.autograd import Variable
  6. import lib.utils.utils as utils
  7. import lib.models.crnn as crnn
  8. import lib.config.alphabets as alphabets
  9. import yaml
  10. from easydict import EasyDict as edict
  11. import argparse
  12. def parse_arg():
  13. parser = argparse.ArgumentParser(description="demo")
  14. parser.add_argument('--cfg', help='experiment configuration filename', type=str, default='lib/config/360CC_config.yaml')
  15. # parser.add_argument('--image_path', type=str, default='images/test.png', help='the path to your image')
  16. # parser.add_argument('--image_path', type=str, default='images/test_2.png', help='the path to your image')
  17. parser.add_argument('--image_path', type=str, default='images/Snipaste_2023-06-12_16-30-19.jpg', help='the path to your image')
  18. # parser.add_argument('--checkpoint', type=str, default='output/checkpoints/mixed_second_finetune_acc_97P7.pth',
  19. # help='the path to your checkpoints')
  20. # parser.add_argument('--checkpoint', type=str, default='output/360CC/crnn/2023-05-25-10-10/checkpoints/checkpoint_2_acc_0.9707.pth',
  21. # help='the path to your checkpoints')
  22. parser.add_argument('--checkpoint', type=str, default='output/360CC/crnn/2023-06-08-14-23/checkpoints/checkpoint_8_acc_0.8455.pth',
  23. help='the path to your checkpoints')
  24. args = parser.parse_args()
  25. with open(args.cfg, 'r') as f:
  26. config = yaml.load(f,Loader=yaml.FullLoader)
  27. config = edict(config)
  28. config.DATASET.ALPHABETS = alphabets.alphabet
  29. config.MODEL.NUM_CLASSES = len(config.DATASET.ALPHABETS)
  30. print("dict_num ",config.MODEL.NUM_CLASSES)
  31. return config, args
  32. def recognition(config, img, model, converter, device):
  33. # github issues: https://github.com/Sierkinhane/CRNN_Chinese_Characters_Rec/issues/211
  34. h, w = img.shape
  35. # fisrt step: resize the height and width of image to (32, x)
  36. img = cv2.resize(img, (0, 0), fx=config.MODEL.IMAGE_SIZE.H / h, fy=config.MODEL.IMAGE_SIZE.H / h, interpolation=cv2.INTER_CUBIC)
  37. # second step: keep the ratio of image's text same with training
  38. h, w = img.shape
  39. w_cur = int(img.shape[1] / (config.MODEL.IMAGE_SIZE.OW / config.MODEL.IMAGE_SIZE.W))
  40. img = cv2.resize(img, (0, 0), fx=w_cur / w, fy=1.0, interpolation=cv2.INTER_CUBIC)
  41. img = np.reshape(img, (config.MODEL.IMAGE_SIZE.H, w_cur, 1))
  42. # normalize
  43. img = img.astype(np.float32)
  44. img = (img / 255. - config.DATASET.MEAN) / config.DATASET.STD
  45. img = img.transpose([2, 0, 1])
  46. img = torch.from_numpy(img)
  47. img = img.to(device)
  48. img = img.view(1, *img.size())
  49. model.eval()
  50. preds = model(img)
  51. print(preds.shape)
  52. _, preds = preds.max(2)
  53. preds = preds.transpose(1, 0).contiguous().view(-1)
  54. preds_size = Variable(torch.IntTensor([preds.size(0)]))
  55. sim_pred = converter.decode(preds.data, preds_size.data, raw=False)
  56. print('results: {0}'.format(sim_pred))
  57. if __name__ == '__main__':
  58. config, args = parse_arg()
  59. device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
  60. model = crnn.get_crnn(config).to(device)
  61. print('loading pretrained model from {0}'.format(args.checkpoint))
  62. checkpoint = torch.load(args.checkpoint,map_location=torch.device('cpu'))
  63. if 'state_dict' in checkpoint.keys():
  64. model.load_state_dict(checkpoint['state_dict'])
  65. else:
  66. model.load_state_dict(checkpoint)
  67. started = time.time()
  68. from image_test import AddGauss,addNoise,AddNoiseSingleChannel,tfactor
  69. img = cv2.imread(args.image_path)
  70. # img = AddNoiseSingleChannel(img)
  71. img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  72. # img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
  73. # img = cv2.bitwise_not(img)
  74. converter = utils.strLabelConverter(config.DATASET.ALPHABETS)
  75. recognition(config, img, model, converter, device)
  76. finished = time.time()
  77. print('elapsed time: {0}'.format(finished - started))