# CommonModules.py

import torch
from torch import nn
from torch.nn import functional as F


class HSwish(nn.Module):
    """Hard-swish activation: x * relu6(x + 3) / 6."""

    def forward(self, x):
        return x * F.relu6(x + 3, inplace=True) / 6


class HardSigmoid(nn.Module):
    """Hard-sigmoid in two variants: 'paddle' computes clamp(1.2x + 3, 0, 6) / 6,
    anything else falls back to relu6(x + 3) / 6."""

    def __init__(self, type):
        super().__init__()
        self.type = type

    def forward(self, x):
        if self.type == 'paddle':
            x = (1.2 * x).add_(3.).clamp_(0., 6.).div_(6.)
        else:
            x = F.relu6(x + 3, inplace=True) / 6
        return x


class HSigmoid(nn.Module):
    """Paddle-style hard-sigmoid: clamp(1.2x + 3, 0, 6) / 6."""

    def forward(self, x):
        return (1.2 * x).add_(3.).clamp_(0., 6.).div_(6.)
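

# Illustrative sketch (not part of the original file): exercises the hard
# activations above and checks the clamped output range. Wrapped in a
# function so importing this module stays side-effect free.
def _demo_hard_activations():
    x = torch.linspace(-6, 6, steps=13)
    assert HSwish()(x).shape == x.shape
    s = HardSigmoid(type='paddle')(x)
    assert s.min() >= 0. and s.max() <= 1.  # hard-sigmoid output lives in [0, 1]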


class ConvBNACT(nn.Module):
    """Conv2d + BatchNorm2d + optional activation ('relu' or 'hard_swish')."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, groups=1, act=None):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                              stride=stride, padding=padding, groups=groups, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        if act == 'relu':
            self.act = nn.ReLU()
        elif act == 'hard_swish':
            self.act = HSwish()
        elif act is None:
            self.act = None
        else:
            raise NotImplementedError(f'activation {act!r} is not supported')

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        if self.act is not None:
            x = self.act(x)
        return x
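

# Illustrative sketch (added, with assumed sizes): a stride-2 ConvBNACT halves
# the spatial resolution while changing the channel count.
def _demo_conv_bn_act():
    block = ConvBNACT(in_channels=3, out_channels=16, kernel_size=3,
                      stride=2, padding=1, act='hard_swish')
    y = block(torch.randn(1, 3, 32, 32))
    assert y.shape == (1, 16, 16, 16)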


class SEBlock(nn.Module):
    """Squeeze-and-Excitation: global pool -> 1x1 conv bottleneck -> channel re-weighting."""

    def __init__(self, in_channels, ratio=4):
        super().__init__()
        num_mid_filter = in_channels // ratio
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=num_mid_filter, kernel_size=1, bias=True)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(in_channels=num_mid_filter, out_channels=in_channels, kernel_size=1, bias=True)
        self.relu2 = HardSigmoid(type='paddle')

    def forward(self, x):
        attn = self.pool(x)
        attn = self.conv1(attn)
        attn = self.relu1(attn)
        attn = self.conv2(attn)
        attn = self.relu2(attn)
        return x * attn
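

# Illustrative sketch (added): SEBlock re-weights channels and preserves the
# input shape; the sizes below are assumptions for the demo.
def _demo_se_block():
    se = SEBlock(in_channels=64, ratio=4)
    x = torch.randn(2, 64, 32, 32)
    assert se(x).shape == x.shape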


def global_avg_pool(x: torch.Tensor) -> torch.Tensor:
    """Global average pooling to shape (N, C, 1, 1)."""
    N, C, H, W = x.shape
    y = x.view(N, C, H * W).mean(dim=2)
    return y.unsqueeze(2).unsqueeze(3)


def global_max_pool(x: torch.Tensor) -> torch.Tensor:
    """Global max pooling to shape (N, C, 1, 1)."""
    N, C, H, W = x.shape
    y = x.view(N, C, H * W).max(dim=2).values
    return y.unsqueeze(2).unsqueeze(3)
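

# Illustrative sketch (added): the hand-rolled pools above should agree with
# torch's adaptive pooling to output size 1.
def _demo_global_pools():
    x = torch.randn(2, 8, 5, 7)
    assert torch.allclose(global_avg_pool(x), F.adaptive_avg_pool2d(x, 1), atol=1e-6)
    assert torch.allclose(global_max_pool(x), F.adaptive_max_pool2d(x, 1))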


class ChannelAttention(nn.Module):
    """CBAM channel attention: a shared MLP over avg- and max-pooled descriptors."""

    def __init__(self, channels, ratio=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.max_pool = nn.AdaptiveMaxPool2d((1, 1))
        self.fc = nn.Sequential(
            nn.Conv2d(channels, channels // ratio, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(channels // ratio, channels, 1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        y1 = self.fc(self.avg_pool(x))
        y2 = self.fc(self.max_pool(x))
        return self.sigmoid(y1 + y2)


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super().__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)


class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel attention followed by spatial attention."""

    def __init__(self, in_channels, ratio=16):
        super().__init__()
        self.cam = ChannelAttention(in_channels, ratio)
        self.sam = SpatialAttention()

    def forward(self, x):
        x = x * self.cam(x)
        x = x * self.sam(x)
        return x
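

# Illustrative sketch (added): CBAM applies channel then spatial gating and
# keeps the tensor shape; the channel count is an assumption for the demo.
def _demo_cbam():
    cbam = CBAM(in_channels=32, ratio=16)
    x = torch.randn(2, 32, 16, 16)
    assert cbam(x).shape == x.shape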


class eca_layer(nn.Module):
    """Efficient Channel Attention (ECA): a k-sized 1-D conv over the channel descriptor."""

    def __init__(self, channel, k_size=3):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Feature descriptor on the global spatial information
        y = self.avg_pool(x)
        # 1-D convolution across channels, then restore the (N, C, 1, 1) shape
        y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
        # Channel gate in [0, 1]
        y = self.sigmoid(y)
        return x * y.expand_as(x)
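

# Illustrative sketch (added): ECA gates channels without reducing them, so the
# output shape matches the input; sizes are assumptions for the demo.
def _demo_eca_layer():
    eca = eca_layer(channel=64, k_size=3)
    x = torch.randn(2, 64, 14, 14)
    assert eca(x).shape == x.shape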


class ScaleChannelAttention(nn.Module):
    """Per-scale channel attention: global pool -> 1x1 convs -> softmax over scales."""

    def __init__(self, in_planes, out_planes, num_features, init_weight=True):
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(in_planes, out_planes, 1, bias=False)
        self.bn = nn.BatchNorm2d(out_planes)
        self.fc2 = nn.Conv2d(out_planes, num_features, 1, bias=False)
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        global_x = self.avgpool(x)
        global_x = self.fc1(global_x)
        global_x = F.relu(self.bn(global_x))
        global_x = self.fc2(global_x)
        return F.softmax(global_x, dim=1)
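

# Illustrative sketch (added, sizes assumed): the softmax output is a per-scale
# weight vector of shape (N, num_features, 1, 1) that sums to 1 over scales.
def _demo_scale_channel_attention():
    attn = ScaleChannelAttention(in_planes=256, out_planes=128, num_features=4)
    w = attn(torch.randn(2, 256, 20, 20))
    assert w.shape == (2, 4, 1, 1)
    assert torch.allclose(w.sum(dim=1), torch.ones(2, 1, 1))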


class ScaleChannelSpatialAttention(nn.Module):
    """Channel-then-spatial attention producing per-scale weight maps."""

    def __init__(self, in_planes, out_planes, num_features, init_weight=True):
        super().__init__()
        self.channel_wise = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_planes, out_planes, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(out_planes, in_planes, 1, bias=False)
        )
        self.spatial_wise = nn.Sequential(
            # Nx1xHxW -> Nx1xHxW
            nn.Conv2d(1, 1, 3, bias=False, padding=1),
            nn.ReLU(),
            nn.Conv2d(1, 1, 1, bias=False),
            nn.Sigmoid()
        )
        self.attention_wise = nn.Sequential(
            nn.Conv2d(in_planes, num_features, 1, bias=False),
            nn.Sigmoid()
        )
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # NxCx1x1 channel gate, broadcast-added to the NxCxHxW input
        global_x = self.channel_wise(x).sigmoid()
        global_x = global_x + x
        # Nx1xHxW spatial gate, broadcast-added back
        x = torch.mean(global_x, dim=1, keepdim=True)
        global_x = self.spatial_wise(x) + global_x
        # Nxnum_featuresxHxW per-scale attention maps
        global_x = self.attention_wise(global_x)
        return global_x
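

# Illustrative sketch (added, sizes assumed): the channel+spatial variant emits
# one full-resolution attention map per scale.
def _demo_scale_channel_spatial_attention():
    attn = ScaleChannelSpatialAttention(in_planes=64, out_planes=16, num_features=4)
    w = attn(torch.randn(2, 64, 20, 20))
    assert w.shape == (2, 4, 20, 20)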


class ScaleSpatialAttention(nn.Module):
    """Spatial-only attention producing per-scale weight maps."""

    def __init__(self, in_planes, out_planes, num_features, init_weight=True):
        super().__init__()
        self.spatial_wise = nn.Sequential(
            # Nx1xHxW -> Nx1xHxW
            nn.Conv2d(1, 1, 3, bias=False, padding=1),
            nn.ReLU(),
            nn.Conv2d(1, 1, 1, bias=False),
            nn.Sigmoid()
        )
        self.attention_wise = nn.Sequential(
            nn.Conv2d(in_planes, num_features, 1, bias=False),
            nn.Sigmoid()
        )
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        global_x = torch.mean(x, dim=1, keepdim=True)
        global_x = self.spatial_wise(global_x) + x
        global_x = self.attention_wise(global_x)
        return global_x
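

# Illustrative sketch (added, sizes assumed): the spatial-only variant likewise
# maps an NxCxHxW input to per-scale maps of shape Nxnum_featuresxHxW.
def _demo_scale_spatial_attention():
    attn = ScaleSpatialAttention(in_planes=64, out_planes=16, num_features=4)
    w = attn(torch.randn(2, 64, 20, 20))
    assert w.shape == (2, 4, 20, 20)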


class ScaleFeatureSelection(nn.Module):
    """Fuses multi-scale features: attention weights computed from the concatenated
    map re-weight each per-scale feature before they are concatenated again."""

    def __init__(self, in_channels, inter_channels, out_features_num=4, attention_type='scale_spatial'):
        super().__init__()
        self.in_channels = in_channels
        self.inter_channels = inter_channels
        self.out_features_num = out_features_num
        self.conv = nn.Conv2d(in_channels, inter_channels, 3, padding=1)
        self.type = attention_type
        if self.type == 'scale_spatial':
            self.enhanced_attention = ScaleSpatialAttention(inter_channels, inter_channels // 4, out_features_num)
        elif self.type == 'scale_channel_spatial':
            self.enhanced_attention = ScaleChannelSpatialAttention(inter_channels, inter_channels // 4, out_features_num)
        elif self.type == 'scale_channel':
            self.enhanced_attention = ScaleChannelAttention(inter_channels, inter_channels // 2, out_features_num)
        else:
            raise NotImplementedError(f'attention type {attention_type!r} is not supported')

    def _initialize_weights(self, m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            nn.init.kaiming_normal_(m.weight.data)
        elif classname.find('BatchNorm') != -1:
            m.weight.data.fill_(1.)
            m.bias.data.fill_(1e-4)

    def forward(self, concat_x, features_list):
        concat_x = self.conv(concat_x)
        score = self.enhanced_attention(concat_x)
        assert len(features_list) == self.out_features_num
        if self.type not in ['scale_channel_spatial', 'scale_spatial']:
            # Channel-only attention yields Nxnum_featuresx1x1 weights; upsample
            # them to the spatial size of the features before weighting.
            shape = features_list[0].shape[2:]
            score = F.interpolate(score, size=shape, mode='bilinear')
        x = []
        for i in range(self.out_features_num):
            x.append(score[:, i:i + 1] * features_list[i])
        return torch.cat(x, dim=1)
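

if __name__ == '__main__':
    # Smoke test (illustrative; channel sizes and resolutions are assumptions):
    # fuse four same-resolution feature maps with ScaleFeatureSelection.
    inter = 64
    concat = torch.randn(2, 4 * inter, 40, 40)
    feats = [torch.randn(2, inter, 40, 40) for _ in range(4)]
    fuse = ScaleFeatureSelection(4 * inter, inter, out_features_num=4,
                                 attention_type='scale_spatial')
    print(fuse(concat, feats).shape)  # expected: torch.Size([2, 256, 40, 40])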