import torch
from torch import nn
from torch.nn import functional as F


class HSwish(nn.Module):
    def forward(self, x):
        return x * F.relu6(x + 3, inplace=True) / 6


class HardSigmoid(nn.Module):
    def __init__(self, type):
        super().__init__()
        self.type = type

    def forward(self, x):
        if self.type == 'paddle':
            # PaddlePaddle-style hard sigmoid: clip(1.2x + 3, 0, 6) / 6
            x = (1.2 * x).add_(3.).clamp_(0., 6.).div_(6.)
        else:
            x = F.relu6(x + 3, inplace=True) / 6
        return x


class HSigmoid(nn.Module):
    def forward(self, x):
        return (1.2 * x).add_(3.).clamp_(0., 6.).div_(6.)


class ConvBNACT(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, groups=1, act=None):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                              stride=stride, padding=padding, groups=groups, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        if act == 'relu':
            self.act = nn.ReLU()
        elif act == 'hard_swish':
            self.act = HSwish()
        elif act is None:
            self.act = None
        else:
            raise NotImplementedError(f'unsupported activation: {act}')

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        if self.act is not None:
            x = self.act(x)
        return x


class SEBlock(nn.Module):
    def __init__(self, in_channels, ratio=4):
        super().__init__()
        num_mid_filter = in_channels // ratio
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=num_mid_filter, kernel_size=1, bias=True)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(in_channels=num_mid_filter, out_channels=in_channels, kernel_size=1, bias=True)
        self.relu2 = HardSigmoid(type='paddle')

    def forward(self, x):
        attn = self.pool(x)
        attn = self.conv1(attn)
        attn = self.relu1(attn)
        attn = self.conv2(attn)
        attn = self.relu2(attn)
        return x * attn


def global_avg_pool(x: torch.Tensor) -> torch.Tensor:
    N, C, H, W = x.shape
    # reshape (not view) so the helper also works on non-contiguous inputs
    y = x.reshape(N, C, H * W)
    y = y.sum(2)
    y = torch.unsqueeze(y, 2)
    y = torch.unsqueeze(y, 3)
    return y / (H * W)


def global_max_pool(x: torch.Tensor) -> torch.Tensor:
    N, C, H, W = x.shape
    y = x.reshape(N, C, H * W)
    y = torch.max(y, 2).values
    y = torch.unsqueeze(y, 2)
    y = torch.unsqueeze(y, 3)
    return y


class ChannelAttention(nn.Module):
    def __init__(self, channels, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.max_pool = nn.AdaptiveMaxPool2d((1, 1))
        # shared MLP applied to both the avg- and max-pooled descriptors
        self.fc = nn.Sequential(
            nn.Conv2d(channels, channels // ratio, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(channels // ratio, channels, 1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        y1 = self.fc(self.avg_pool(x))
        y2 = self.fc(self.max_pool(x))
        return self.sigmoid(y1 + y2)


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)


class CBAM(nn.Module):
    def __init__(self, in_channels, ratio=16):
        super(CBAM, self).__init__()
        self.cam = ChannelAttention(in_channels, ratio)
        self.sam = SpatialAttention()

    def forward(self, x):
        x = x * self.cam(x)
        x = x * self.sam(x)
        return x
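# Hedged sketch: a quick shape/consistency check for the blocks above. The
# batch size, channel count, and spatial size are arbitrary demo values, not
# taken from any particular model config; called from the __main__ block below.
def _check_attention_shapes():
    feat = torch.randn(2, 64, 32, 32)
    assert SEBlock(64)(feat).shape == feat.shape    # SE re-weights channels, shape preserved
    assert CBAM(64)(feat).shape == feat.shape       # CBAM: channel then spatial attention
    # the hand-rolled pooling helpers should agree with the built-in ops
    assert torch.allclose(global_avg_pool(feat), F.adaptive_avg_pool2d(feat, 1), atol=1e-5)
    assert torch.equal(global_max_pool(feat), F.adaptive_max_pool2d(feat, 1))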
class eca_layer(nn.Module):
    def __init__(self, channel, k_size=3):
        super(eca_layer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # feature descriptor on the global spatial information
        y = self.avg_pool(x)
        # 1-D conv across channels: N x C x 1 x 1 -> N x 1 x C -> conv -> back
        y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
        y = self.sigmoid(y)
        return x * y.expand_as(x)


class ScaleChannelAttention(nn.Module):
    def __init__(self, in_planes, out_planes, num_features, init_weight=True):
        super(ScaleChannelAttention, self).__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(in_planes, out_planes, 1, bias=False)
        self.bn = nn.BatchNorm2d(out_planes)
        self.fc2 = nn.Conv2d(out_planes, num_features, 1, bias=False)
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            if isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        global_x = self.avgpool(x)
        global_x = self.fc1(global_x)
        global_x = F.relu(self.bn(global_x))
        global_x = self.fc2(global_x)
        # softmax over the num_features scales so the weights sum to 1
        return F.softmax(global_x, 1)


class ScaleChannelSpatialAttention(nn.Module):
    def __init__(self, in_planes, out_planes, num_features, init_weight=True):
        super(ScaleChannelSpatialAttention, self).__init__()
        self.channel_wise = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_planes, out_planes, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(out_planes, in_planes, 1, bias=False)
        )
        self.spatial_wise = nn.Sequential(
            # operates on an N x 1 x H x W channel-mean map
            nn.Conv2d(1, 1, 3, bias=False, padding=1),
            nn.ReLU(),
            nn.Conv2d(1, 1, 1, bias=False),
            nn.Sigmoid()
        )
        self.attention_wise = nn.Sequential(
            nn.Conv2d(in_planes, num_features, 1, bias=False),
            nn.Sigmoid()
        )
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            if isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # channel attention: N x C x 1 x 1, broadcast-added back onto the input
        global_x = self.channel_wise(x).sigmoid()
        global_x = global_x + x
        # spatial attention over the channel-mean map: N x 1 x H x W
        x = torch.mean(global_x, dim=1, keepdim=True)
        global_x = self.spatial_wise(x) + global_x
        return self.attention_wise(global_x)


class ScaleSpatialAttention(nn.Module):
    def __init__(self, in_planes, out_planes, num_features, init_weight=True):
        super(ScaleSpatialAttention, self).__init__()
        self.spatial_wise = nn.Sequential(
            # operates on an N x 1 x H x W channel-mean map
            nn.Conv2d(1, 1, 3, bias=False, padding=1),
            nn.ReLU(),
            nn.Conv2d(1, 1, 1, bias=False),
            nn.Sigmoid()
        )
        self.attention_wise = nn.Sequential(
            nn.Conv2d(in_planes, num_features, 1, bias=False),
            nn.Sigmoid()
        )
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            if isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        global_x = torch.mean(x, dim=1, keepdim=True)
        global_x = self.spatial_wise(global_x) + x
        return self.attention_wise(global_x)
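# Hedged sketch of the channel-adaptive kernel-size rule described in the
# ECA-Net paper: k = |log2(C)/gamma + b/gamma| rounded up to the nearest odd
# number, with gamma=2 and b=1 as the paper's defaults. The helper name
# eca_kernel_size is ours, not part of any library API.
import math


def eca_kernel_size(channels: int, gamma: int = 2, b: int = 1) -> int:
    t = int(abs((math.log2(channels) + b) / gamma))
    return t if t % 2 else t + 1  # force an odd kernel so padding stays symmetric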
class ScaleFeatureSelection(nn.Module):
    def __init__(self, in_channels, inter_channels, out_features_num=4, attention_type='scale_spatial'):
        super(ScaleFeatureSelection, self).__init__()
        self.in_channels = in_channels
        self.inter_channels = inter_channels
        self.out_features_num = out_features_num
        self.conv = nn.Conv2d(in_channels, inter_channels, 3, padding=1)
        self.type = attention_type
        if self.type == 'scale_spatial':
            self.enhanced_attention = ScaleSpatialAttention(inter_channels, inter_channels // 4, out_features_num)
        elif self.type == 'scale_channel_spatial':
            self.enhanced_attention = ScaleChannelSpatialAttention(inter_channels, inter_channels // 4, out_features_num)
        elif self.type == 'scale_channel':
            self.enhanced_attention = ScaleChannelAttention(inter_channels, inter_channels // 2, out_features_num)
        else:
            raise NotImplementedError(f'unsupported attention type: {attention_type}')

    def _initialize_weights(self, m):
        # apply-style initializer (use via self.apply(self._initialize_weights))
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            nn.init.kaiming_normal_(m.weight.data)
        elif classname.find('BatchNorm') != -1:
            m.weight.data.fill_(1.)
            m.bias.data.fill_(1e-4)

    def forward(self, concat_x, features_list):
        concat_x = self.conv(concat_x)
        score = self.enhanced_attention(concat_x)
        assert len(features_list) == self.out_features_num
        if self.type not in ['scale_channel_spatial', 'scale_spatial']:
            # channel attention yields N x num_features x 1 x 1; upsample it to
            # the spatial size of the feature maps before weighting
            shape = features_list[0].shape[2:]
            score = F.interpolate(score, size=shape, mode='bilinear', align_corners=False)
        x = []
        for i in range(self.out_features_num):
            # the i-th attention map re-weights the i-th input feature map
            x.append(score[:, i:i + 1] * features_list[i])
        return torch.cat(x, dim=1)
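# Illustrative end-to-end check, assuming four same-resolution feature maps
# such as an FPN-style neck might produce; the channel counts and spatial
# sizes below are made-up demo values, not from any particular config.
if __name__ == '__main__':
    _check_attention_shapes()
    # ECA with a channel-adaptive kernel size, e.g. C=512 -> k_size=5
    eca = eca_layer(512, k_size=eca_kernel_size(512))
    assert eca(torch.randn(2, 512, 16, 16)).shape == (2, 512, 16, 16)
    # fuse four feature maps with ScaleFeatureSelection: each of the
    # out_features_num attention maps re-weights one input feature map
    inner_channels = 256
    feats = [torch.randn(2, inner_channels // 4, 40, 40) for _ in range(4)]
    concat = torch.cat(feats, dim=1)  # N x inner_channels x H x W
    asf = ScaleFeatureSelection(inner_channels, inner_channels // 4,
                                out_features_num=4, attention_type='scale_spatial')
    assert asf(concat, feats).shape == concat.shape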