# RecMobileNetV3.py — MobileNetV3 backbone for text recognition.
  1. from __future__ import absolute_import
  2. from __future__ import division
  3. from __future__ import print_function
  4. from torch import nn
  5. from torchocr.networks.CommonModules import ConvBNACT, SEBlock
  6. class ResidualUnit(nn.Module):
  7. def __init__(self, num_in_filter, num_mid_filter, num_out_filter, stride, kernel_size, act=None, use_se=False):
  8. super().__init__()
  9. self.expand_conv = ConvBNACT(in_channels=num_in_filter, out_channels=num_mid_filter, kernel_size=1, stride=1,
  10. padding=0, act=act)
  11. self.bottleneck_conv = ConvBNACT(in_channels=num_mid_filter, out_channels=num_mid_filter, kernel_size=kernel_size,
  12. stride=stride,
  13. padding=int((kernel_size - 1) // 2), act=act, groups=num_mid_filter)
  14. if use_se:
  15. self.se = SEBlock(in_channels=num_mid_filter)
  16. else:
  17. self.se = None
  18. self.linear_conv = ConvBNACT(in_channels=num_mid_filter, out_channels=num_out_filter, kernel_size=1, stride=1,
  19. padding=0)
  20. self.not_add = num_in_filter != num_out_filter or stride != 1
  21. def forward(self, x):
  22. y = self.expand_conv(x)
  23. y = self.bottleneck_conv(y)
  24. if self.se is not None:
  25. y = self.se(y)
  26. y = self.linear_conv(y)
  27. if not self.not_add:
  28. y = x + y
  29. return y
  30. class MobileNetV3(nn.Module):
  31. def __init__(self, in_channels=3, **kwargs):
  32. super().__init__()
  33. self.scale = kwargs.get('scale', 0.5)
  34. model_name = kwargs.get('model_name', 'small')
  35. self.inplanes = 16
  36. if model_name == "large":
  37. self.cfg = [
  38. # k, exp, c, se, nl, s,
  39. [3, 16, 16, False, 'relu', 1],
  40. [3, 64, 24, False, 'relu', (2, 1)],
  41. [3, 72, 24, False, 'relu', 1],
  42. [5, 72, 40, True, 'relu', (2, 1)],
  43. [5, 120, 40, True, 'relu', 1],
  44. [5, 120, 40, True, 'relu', 1],
  45. [3, 240, 80, False, 'hard_swish', 1],
  46. [3, 200, 80, False, 'hard_swish', 1],
  47. [3, 184, 80, False, 'hard_swish', 1],
  48. [3, 184, 80, False, 'hard_swish', 1],
  49. [3, 480, 112, True, 'hard_swish', 1],
  50. [3, 672, 112, True, 'hard_swish', 1],
  51. [5, 672, 160, True, 'hard_swish', (2, 1)],
  52. [5, 960, 160, True, 'hard_swish', 1],
  53. [5, 960, 160, True, 'hard_swish', 1],
  54. ]
  55. self.cls_ch_squeeze = 960
  56. self.cls_ch_expand = 1280
  57. elif model_name == "small":
  58. self.cfg = [
  59. # k, exp, c, se, nl, s,
  60. [3, 16, 16, True, 'relu', (1, 1)],
  61. [3, 72, 24, False, 'relu', (2, 1)],
  62. [3, 88, 24, False, 'relu', 1],
  63. [5, 96, 40, True, 'hard_swish', (2, 1)],
  64. [5, 240, 40, True, 'hard_swish', 1],
  65. [5, 240, 40, True, 'hard_swish', 1],
  66. [5, 120, 48, True, 'hard_swish', 1],
  67. [5, 144, 48, True, 'hard_swish', 1],
  68. [5, 288, 96, True, 'hard_swish', (2, 1)],
  69. [5, 576, 96, True, 'hard_swish', 1],
  70. [5, 576, 96, True, 'hard_swish', 1],
  71. ]
  72. self.cls_ch_squeeze = 576
  73. self.cls_ch_expand = 1280
  74. else:
  75. raise NotImplementedError("mode[" + model_name +
  76. "_model] is not implemented!")
  77. supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]
  78. assert self.scale in supported_scale, "supported scale are {} but input scale is {}".format(supported_scale,
  79. self.scale)
  80. scale = self.scale
  81. inplanes = self.inplanes
  82. cfg = self.cfg
  83. cls_ch_squeeze = self.cls_ch_squeeze
  84. # conv1
  85. self.conv1 = ConvBNACT(in_channels=in_channels,
  86. out_channels=self.make_divisible(inplanes * scale),
  87. kernel_size=3,
  88. stride=2,
  89. padding=1,
  90. groups=1,
  91. act='hard_swish')
  92. inplanes = self.make_divisible(inplanes * scale)
  93. block_list = []
  94. for layer_cfg in cfg:
  95. block = ResidualUnit(num_in_filter=inplanes,
  96. num_mid_filter=self.make_divisible(scale * layer_cfg[1]),
  97. num_out_filter=self.make_divisible(scale * layer_cfg[2]),
  98. act=layer_cfg[4],
  99. stride=layer_cfg[5],
  100. kernel_size=layer_cfg[0],
  101. use_se=layer_cfg[3])
  102. block_list.append(block)
  103. inplanes = self.make_divisible(scale * layer_cfg[2])
  104. self.blocks = nn.Sequential(*block_list)
  105. self.conv2 = ConvBNACT(in_channels=inplanes,
  106. out_channels=self.make_divisible(scale * cls_ch_squeeze),
  107. kernel_size=1,
  108. stride=1,
  109. padding=0,
  110. groups=1,
  111. act='hard_swish')
  112. self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
  113. self.out_channels = self.make_divisible(scale * cls_ch_squeeze)
  114. def make_divisible(self, v, divisor=8, min_value=None):
  115. if min_value is None:
  116. min_value = divisor
  117. new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
  118. if new_v < 0.9 * v:
  119. new_v += divisor
  120. return new_v
  121. def forward(self, x):
  122. x = self.conv1(x)
  123. x = self.blocks(x)
  124. x = self.conv2(x)
  125. x = self.pool(x)
  126. return x