rec_infer.py 3.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. # -*- coding: utf-8 -*-
  2. # @Time : 2020/6/16 10:57
  3. # @Author : zhoujun
  4. import os
  5. import sys
  6. import pathlib
  7. # Add the torchocr package location to the Python module search path (sys.path)
  8. __dir__ = pathlib.Path(os.path.abspath(__file__))
  9. import numpy as np
  10. sys.path.append(str(__dir__))
  11. sys.path.append(str(__dir__.parent.parent))
  12. import torch
  13. from torch import nn
  14. from torchocr.networks import build_model
  15. from torchocr.datasets.RecDataSet import RecDataProcess
  16. from torchocr.utils import CTCLabelConverter
  17. class RecInfer:
  18. def __init__(self, model_path, batch_size=16):
  19. ckpt = torch.load(model_path, map_location='cpu')
  20. cfg = ckpt['cfg']
  21. self.model = build_model(cfg['model'])
  22. state_dict = {}
  23. for k, v in ckpt['state_dict'].items():
  24. state_dict[k.replace('module.', '')] = v
  25. self.model.load_state_dict(state_dict)
  26. self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
  27. self.model.to(self.device)
  28. self.model.eval()
  29. self.process = RecDataProcess(cfg['dataset']['train']['dataset'])
  30. # self.converter = CTCLabelConverter(cfg['dataset']['alphabet'])
  31. self.converter = CTCLabelConverter("C:\\Users\\Administrator\\Desktop\\OCR_pytorch\\PytorchOCR1\\char_std_7782.txt")
  32. self.batch_size = batch_size
  33. def predict(self, imgs):
  34. # 预处理根据训练来
  35. if not isinstance(imgs,list):
  36. imgs = [imgs]
  37. imgs = [self.process.normalize_img(self.process.resize_with_specific_height(img)) for img in imgs]
  38. widths = np.array([img.shape[1] for img in imgs])
  39. idxs = np.argsort(widths)
  40. txts = []
  41. for idx in range(0, len(imgs), self.batch_size):
  42. batch_idxs = idxs[idx:min(len(imgs), idx+self.batch_size)]
  43. batch_imgs = [self.process.width_pad_img(imgs[idx], imgs[batch_idxs[-1]].shape[1]) for idx in batch_idxs]
  44. batch_imgs = np.stack(batch_imgs)
  45. tensor = torch.from_numpy(batch_imgs.transpose([0,3, 1, 2])).float()
  46. tensor = tensor.to(self.device)
  47. with torch.no_grad():
  48. out = self.model(tensor)
  49. # print(out)
  50. # out[1] 为最后输出
  51. out = out
  52. out = out.softmax(dim=2)
  53. out = out.cpu().numpy()
  54. txts.extend([self.converter.decode(np.expand_dims(txt, 0)) for txt in out])
  55. #按输入图像的顺序排序
  56. idxs = np.argsort(idxs)
  57. out_txts = [txts[idx] for idx in idxs]
  58. return out_txts
  59. def init_args():
  60. import argparse
  61. parser = argparse.ArgumentParser(description='PytorchOCR infer')
  62. # parser.add_argument('--model_path', required=True, type=str, help='rec model path')
  63. parser.add_argument('--model_path', required=False, type=str,
  64. default="C:\\Users\\Administrator\\Desktop\\OCR_pytorch\\PytorchOCR1\\best2.pth", help='rec model path')
  65. # parser.add_argument('--img_path', required=True, type=str, help='img path for predict')
  66. parser.add_argument('--img_path', required=False, type=str,
  67. default="C:\\Users\\Administrator\\Desktop\\OCR_pytorch\\PytorchOCR1\\test_image/Snipaste_2023-09-08_10-45-17.jpg", help='img path for predict')
  68. args = parser.parse_args()
  69. return args
  70. if __name__ == '__main__':
  71. import cv2
  72. args = init_args()
  73. img = cv2.imread(args.img_path)
  74. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  75. model = RecInfer(args.model_path)
  76. out = model.predict(img)
  77. print(out)