DetMobilenetV3.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. from __future__ import absolute_import
  2. from __future__ import division
  3. from __future__ import print_function
  4. import logging
  5. import os
  6. import torch
  7. from torch import nn
  8. from torchocr.networks.CommonModules import ConvBNACT, SEBlock
  9. from collections import OrderedDict
  10. class ResidualUnit(nn.Module):
  11. def __init__(self, num_in_filter, num_mid_filter, num_out_filter, stride, kernel_size, act=None, use_se=False):
  12. super().__init__()
  13. self.conv0 = ConvBNACT(in_channels=num_in_filter, out_channels=num_mid_filter, kernel_size=1, stride=1,
  14. padding=0, act=act)
  15. self.conv1 = ConvBNACT(in_channels=num_mid_filter, out_channels=num_mid_filter, kernel_size=kernel_size,
  16. stride=stride,
  17. padding=int((kernel_size - 1) // 2), act=act, groups=num_mid_filter)
  18. if use_se:
  19. self.se = SEBlock(in_channels=num_mid_filter)
  20. else:
  21. self.se = None
  22. self.conv2 = ConvBNACT(in_channels=num_mid_filter, out_channels=num_out_filter, kernel_size=1, stride=1,
  23. padding=0)
  24. self.not_add = num_in_filter != num_out_filter or stride != 1
  25. def forward(self, x):
  26. y = self.conv0(x)
  27. y = self.conv1(y)
  28. if self.se is not None:
  29. y = self.se(y)
  30. y = self.conv2(y)
  31. if not self.not_add:
  32. y = x + y
  33. return y
  34. class MobileNetV3(nn.Module):
  35. def __init__(self, in_channels, pretrained=True, **kwargs):
  36. """
  37. the MobilenetV3 backbone network for detection module.
  38. Args:
  39. params(dict): the super parameters for build network
  40. """
  41. super().__init__()
  42. self.scale = kwargs.get('scale', 0.5)
  43. model_name = kwargs.get('model_name', 'large')
  44. self.disable_se = kwargs.get('disable_se', True)
  45. self.inplanes = 16
  46. if model_name == "large":
  47. self.cfg = [
  48. # k, exp, c, se, nl, s,
  49. [3, 16, 16, False, 'relu', 1],
  50. [3, 64, 24, False, 'relu', 2],
  51. [3, 72, 24, False, 'relu', 1],
  52. [5, 72, 40, True, 'relu', 2],
  53. [5, 120, 40, True, 'relu', 1],
  54. [5, 120, 40, True, 'relu', 1],
  55. [3, 240, 80, False, 'hard_swish', 2],
  56. [3, 200, 80, False, 'hard_swish', 1],
  57. [3, 184, 80, False, 'hard_swish', 1],
  58. [3, 184, 80, False, 'hard_swish', 1],
  59. [3, 480, 112, True, 'hard_swish', 1],
  60. [3, 672, 112, True, 'hard_swish', 1],
  61. [5, 672, 160, True, 'hard_swish', 2],
  62. [5, 960, 160, True, 'hard_swish', 1],
  63. [5, 960, 160, True, 'hard_swish', 1],
  64. ]
  65. self.cls_ch_squeeze = 960
  66. self.cls_ch_expand = 1280
  67. elif model_name == "small":
  68. self.cfg = [
  69. # k, exp, c, se, nl, s,
  70. [3, 16, 16, True, 'relu', 2],
  71. [3, 72, 24, False, 'relu', 2],
  72. [3, 88, 24, False, 'relu', 1],
  73. [5, 96, 40, True, 'hard_swish', 2],
  74. [5, 240, 40, True, 'hard_swish', 1],
  75. [5, 240, 40, True, 'hard_swish', 1],
  76. [5, 120, 48, True, 'hard_swish', 1],
  77. [5, 144, 48, True, 'hard_swish', 1],
  78. [5, 288, 96, True, 'hard_swish', 2],
  79. [5, 576, 96, True, 'hard_swish', 1],
  80. [5, 576, 96, True, 'hard_swish', 1],
  81. ]
  82. self.cls_ch_squeeze = 576
  83. self.cls_ch_expand = 1280
  84. else:
  85. raise NotImplementedError("mode[" + model_name +
  86. "_model] is not implemented!")
  87. supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]
  88. assert self.scale in supported_scale, \
  89. "supported scale are {} but input scale is {}".format(supported_scale, self.scale)
  90. scale = self.scale
  91. inplanes = self.inplanes
  92. cfg = self.cfg
  93. cls_ch_squeeze = self.cls_ch_squeeze
  94. # conv1
  95. self.conv1 = ConvBNACT(in_channels=in_channels,
  96. out_channels=self.make_divisible(inplanes * scale),
  97. kernel_size=3,
  98. stride=2,
  99. padding=1,
  100. groups=1,
  101. act='hard_swish')
  102. i = 0
  103. inplanes = self.make_divisible(inplanes * scale)
  104. self.stages = nn.ModuleList()
  105. block_list = []
  106. self.out_channels = []
  107. for layer_cfg in cfg:
  108. se = layer_cfg[3] and not self.disable_se
  109. if layer_cfg[5] == 2 and i > 2:
  110. self.out_channels.append(inplanes)
  111. self.stages.append(nn.Sequential(*block_list))
  112. block_list = []
  113. block = ResidualUnit(num_in_filter=inplanes,
  114. num_mid_filter=self.make_divisible(scale * layer_cfg[1]),
  115. num_out_filter=self.make_divisible(scale * layer_cfg[2]),
  116. act=layer_cfg[4],
  117. stride=layer_cfg[5],
  118. kernel_size=layer_cfg[0],
  119. use_se=se)
  120. block_list.append(block)
  121. inplanes = self.make_divisible(scale * layer_cfg[2])
  122. i += 1
  123. block_list.append(ConvBNACT(
  124. in_channels=inplanes,
  125. out_channels=self.make_divisible(scale * cls_ch_squeeze),
  126. kernel_size=1,
  127. stride=1,
  128. padding=0,
  129. groups=1,
  130. act='hard_swish'))
  131. self.stages.append(nn.Sequential(*block_list))
  132. self.out_channels.append(self.make_divisible(scale * cls_ch_squeeze))
  133. if pretrained:
  134. ckpt_path = f'./weights/MobileNetV3_{model_name}_x{str(scale).replace(".", "_")}.pth'
  135. logger = logging.getLogger('torchocr')
  136. if os.path.exists(ckpt_path):
  137. logger.info('load imagenet weights')
  138. dic_ckpt = torch.load(ckpt_path)
  139. filtered_dict = OrderedDict()
  140. for key in dic_ckpt.keys():
  141. flag = key.find('se') != -1
  142. if self.disable_se and flag:
  143. continue
  144. filtered_dict[key] = dic_ckpt[key]
  145. self.load_state_dict(filtered_dict)
  146. else:
  147. logger.info(f'{ckpt_path} not exists')
  148. def make_divisible(self, v, divisor=8, min_value=None):
  149. if min_value is None:
  150. min_value = divisor
  151. new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
  152. if new_v < 0.9 * v:
  153. new_v += divisor
  154. return new_v
  155. def forward(self, x):
  156. x = self.conv1(x)
  157. out = []
  158. for stage in self.stages:
  159. x = stage(x)
  160. out.append(x)
  161. return out