det_infer.py 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. # -*- coding: utf-8 -*-
  2. # @Time : 2020/6/16 10:57
  3. # @Author : zhoujun
  4. import os
  5. import sys
  6. import pathlib
  7. # 将 torchocr路径加到python陆经里
  8. __dir__ = pathlib.Path(os.path.abspath(__file__))
  9. sys.path.append(str(__dir__))
  10. sys.path.append(str(__dir__.parent.parent))
  11. import torch
  12. from torch import nn
  13. from torchvision import transforms
  14. from torchocr.networks import build_model
  15. from torchocr.datasets.det_modules import ResizeShortSize, ResizeFixedSize
  16. from torchocr.postprocess import build_post_process
  17. class DetInfer:
  18. def __init__(self, model_path):
  19. ckpt = torch.load(model_path, map_location='cpu')
  20. cfg = ckpt['cfg']
  21. self.model = build_model(cfg['model'])
  22. state_dict = {}
  23. for k, v in ckpt['state_dict'].items():
  24. state_dict[k.replace('module.', '')] = v
  25. self.model.load_state_dict(state_dict)
  26. self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
  27. self.model.to(self.device)
  28. self.model.eval()
  29. self.resize = ResizeFixedSize(736, False)
  30. self.post_process = build_post_process(cfg['post_process'])
  31. self.transform = transforms.Compose([
  32. transforms.ToTensor(),
  33. transforms.Normalize(mean=cfg['dataset']['train']['dataset']['mean'], std=cfg['dataset']['train']['dataset']['std'])
  34. ])
  35. def predict(self, img):
  36. # 预处理根据训练来
  37. data = {'img': img, 'shape': [img.shape[:2]], 'text_polys': []}
  38. data = self.resize(data)
  39. tensor = self.transform(data['img'])
  40. tensor = tensor.unsqueeze(dim=0)
  41. tensor = tensor.to(self.device)
  42. with torch.no_grad():
  43. out = self.model(tensor)
  44. out = out.cpu().numpy()
  45. box_list, score_list = self.post_process(out, data['shape'])
  46. box_list, score_list = box_list[0], score_list[0]
  47. if len(box_list) > 0:
  48. idx = [x.sum() > 0 for x in box_list]
  49. box_list = [box_list[i] for i, v in enumerate(idx) if v]
  50. score_list = [score_list[i] for i, v in enumerate(idx) if v]
  51. else:
  52. box_list, score_list = [], []
  53. return box_list, score_list
  54. def init_args():
  55. import argparse
  56. parser = argparse.ArgumentParser(description='PytorchOCR infer')
  57. parser.add_argument('--model_path', required=True, type=str, help='rec model path')
  58. parser.add_argument('--img_path', required=True, type=str, help='img dir for predict')
  59. args = parser.parse_args()
  60. return args
  61. if __name__ == '__main__':
  62. import cv2
  63. import time
  64. from matplotlib import pyplot as plt
  65. from torchocr.utils import draw_bbox
  66. args = init_args()
  67. model = DetInfer(args.model_path)
  68. names = next(os.walk(args.img_path))[2]
  69. st = time.time()
  70. for name in names:
  71. path = os.path.join(args.img_path, name)
  72. img = cv2.imread(path)
  73. box_list, score_list = model.predict(img)
  74. out_path = os.path.join(args.img_path, 'res', name)
  75. img = draw_bbox(img, box_list)
  76. cv2.imwrite(out_path[:-4] + '_res.jpg', img)
  77. print((time.time() - st) / len(names))