DetModel.py 2.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. # -*- coding: utf-8 -*-
  2. # @Time : 2020/5/21 14:23
  3. # @Author : zhoujun
  4. from torch import nn
  5. from addict import Dict as AttrDict
  6. from torchocr.networks.backbones.DetMobilenetV3 import MobileNetV3
  7. from torchocr.networks.backbones.DetResNetvd import ResNet
  8. from torchocr.networks.necks.DB_fpn import DB_fpn, RSEFPN, LKPAN
  9. from torchocr.networks.necks.FCE_Fpn import FCEFPN
  10. from torchocr.networks.necks.pse_fpn import PSEFpn
  11. from torchocr.networks.necks.DB_ASF import DB_Asf
  12. from torchocr.networks.heads.DetDbHead import DBHead
  13. from torchocr.networks.heads.FCEHead import FCEHead
  14. from torchocr.networks.heads.DetPseHead import PseHead
  15. from torchocr.networks.backbones.DetGhostNet import GhostNet
  16. from torchocr.networks.backbones.Transformer import *
  17. from torchocr.networks.backbones.ConvNext import ConvNeXt
  18. backbone_dict = {'MobileNetV3': MobileNetV3,
  19. 'ResNet': ResNet,
  20. 'GhostNet': GhostNet,
  21. 'SwinTransformer': SwinTransformer,
  22. 'ConvNeXt': ConvNeXt
  23. }
  24. neck_dict = {'DB_fpn': DB_fpn,
  25. 'pse_fpn': PSEFpn,
  26. 'ASF': DB_Asf,
  27. 'RSEFPN': RSEFPN,
  28. 'LKPAN': LKPAN,
  29. 'FCEFPN': FCEFPN
  30. }
  31. head_dict = {'DBHead': DBHead,
  32. 'PseHead': PseHead,
  33. 'FCEHead': FCEHead
  34. }
  35. class DetModel(nn.Module):
  36. def __init__(self, config):
  37. super().__init__()
  38. assert 'in_channels' in config, 'in_channels must in model config'
  39. backbone_type = config.backbone.pop('type')
  40. assert backbone_type in backbone_dict, f'backbone.type must in {backbone_dict}'
  41. self.backbone = backbone_dict[backbone_type](config.in_channels, **config.backbone)
  42. neck_type = config.neck.pop('type')
  43. assert neck_type in neck_dict, f'neck.type must in {neck_dict}'
  44. self.neck = neck_dict[neck_type](self.backbone.out_channels, **config.neck)
  45. head_type = config.head.pop('type')
  46. assert head_type in head_dict, f'head.type must in {head_dict}'
  47. self.head = head_dict[head_type](self.neck.out_channels, **config.head)
  48. self.name = f'DetModel_{backbone_type}_{neck_type}_{head_type}'
  49. def forward(self, x):
  50. x = self.backbone(x)
  51. x = self.neck(x)
  52. x = self.head(x)
  53. return x
  54. if __name__ == '__main__':
  55. import torch
  56. # db_config = AttrDict(
  57. # in_channels=3,
  58. # backbone=AttrDict(type='MobileNetV3', layers=50, model_name='large',pretrained=True),
  59. # neck=AttrDict(type='FPN', out_channels=256),
  60. # head=AttrDict(type='DBHead')
  61. # )
  62. # x = torch.zeros(1, 3, 640, 640)
  63. # model = DetModel(db_config)
  64. db_config = AttrDict(
  65. in_channels=3,
  66. backbone=AttrDict(type='ResNet', layers=50, pretrained=True),
  67. neck=AttrDict(type='pse_fpn', out_channels=256),
  68. head=AttrDict(type='PseHead', H=640, W=640, scale=1)
  69. )
  70. x = torch.zeros(1, 3, 640, 640)
  71. model = DetModel(db_config)