RecModel.py 1.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. # -*- coding: utf-8 -*-
  2. # @Time : 2020/5/16 11:18
  3. # @Author : zhoujun
  4. from torch import nn
  5. from torchocr.networks.backbones.RecMobileNetV3 import MobileNetV3
  6. from torchocr.networks.backbones.RecResNetvd import ResNet
  7. from torchocr.networks.necks.RNN import SequenceEncoder, Im2Seq
  8. from torchocr.networks.heads.RecCTCHead import CTC
  9. backbone_dict = {'MobileNetV3': MobileNetV3, 'ResNet': ResNet}
  10. neck_dict = {'PPaddleRNN': SequenceEncoder, 'None': Im2Seq}
  11. head_dict = {'CTC': CTC}
  12. class RecModel(nn.Module):
  13. def __init__(self, config):
  14. super().__init__()
  15. assert 'in_channels' in config, 'in_channels must in model config'
  16. backbone_type = config.backbone.pop('type')
  17. assert backbone_type in backbone_dict, f'backbone.type must in {backbone_dict}'
  18. self.backbone = backbone_dict[backbone_type](config.in_channels, **config.backbone)
  19. neck_type = config.neck.pop('type')
  20. assert neck_type in neck_dict, f'neck.type must in {neck_dict}'
  21. self.neck = neck_dict[neck_type](self.backbone.out_channels, **config.neck)
  22. head_type = config.head.pop('type')
  23. assert head_type in head_dict, f'head.type must in {head_dict}'
  24. self.head = head_dict[head_type](self.neck.out_channels, **config.head)
  25. self.name = f'RecModel_{backbone_type}_{neck_type}_{head_type}'
  26. def forward(self, x):
  27. x = self.backbone(x)
  28. x = self.neck(x)
  29. features = x
  30. x = self.head(x)
  31. # return features,x
  32. return x