123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303 |
- import torch
- from torch import nn
- from torch.nn import functional as F
- from collections import OrderedDict
class HSwish(nn.Module):
    """Hard-swish activation computed as ``x * relu6(x + 3) / 6``."""

    def forward(self, x):
        # ``x + 3`` allocates a fresh tensor, so the in-place relu6 never
        # modifies the caller's input.
        gate = F.relu6(x + 3, inplace=True)
        return x * gate / 6
class HardSigmoid(nn.Module):
    """Piecewise-linear approximation of the sigmoid.

    Args:
        type: ``'paddle'`` selects ``clamp(1.2 * x + 3, 0, 6) / 6``; any other
            value selects the standard ``relu6(x + 3) / 6``.
    """

    def __init__(self, type):
        super().__init__()
        # The parameter/attribute name shadows the builtin but is kept because
        # callers pass it by keyword (``HardSigmoid(type='paddle')``).
        self.type = type

    def forward(self, x):
        if self.type == 'paddle':
            # ``1.2 * x`` allocates a new tensor, so the chained in-place ops
            # below never touch the caller's input.
            return (1.2 * x).add_(3.).clamp_(0., 6.).div_(6.)
        # ``x + 3`` is a new tensor, so the in-place relu6 is safe here too.
        # (The original had a stray argument-less ``F.hardsigmoid()`` call on
        # this path, which raised TypeError on every call.)
        return F.relu6(x + 3, inplace=True) / 6
class HSigmoid(nn.Module):
    """Paddle-style hard sigmoid: ``clamp(1.2 * x + 3, 0, 6) / 6``."""

    def forward(self, x):
        # ``1.2 * x`` is a fresh tensor; the in-place steps only mutate it.
        y = 1.2 * x
        y.add_(3.)
        y.clamp_(0., 6.)
        y.div_(6.)
        return y
class ConvBNACT(nn.Module):
    """Bias-free ``Conv2d`` followed by ``BatchNorm2d`` and an optional activation.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, groups:
            forwarded to ``nn.Conv2d`` (which is built with ``bias=False``).
        act: ``'relu'``, ``'hard_swish'`` or ``None`` (no activation).

    Raises:
        ValueError: if ``act`` is not one of the supported values. Previously an
            unknown value left ``self.act`` unset and only failed later, inside
            ``forward``, with an AttributeError.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, groups=1, act=None):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                              stride=stride, padding=padding, groups=groups,
                              bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        if act == 'relu':
            self.act = nn.ReLU()
        elif act == 'hard_swish':
            self.act = HSwish()
        elif act is None:
            self.act = None
        else:
            raise ValueError(
                f"unsupported activation {act!r}; expected 'relu', 'hard_swish' or None"
            )

    def forward(self, x):
        x = self.bn(self.conv(x))
        if self.act is not None:
            x = self.act(x)
        return x
class SEBlock(nn.Module):
    """Squeeze-and-excitation gate that rescales each input channel.

    Args:
        in_channels: channels of the input feature map.
        ratio: reduction factor; the bottleneck has ``in_channels // ratio`` channels.
    """

    def __init__(self, in_channels, ratio=4):
        super().__init__()
        squeezed = in_channels // ratio
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=squeezed, kernel_size=1, bias=True)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(in_channels=squeezed, out_channels=in_channels, kernel_size=1, bias=True)
        # Named ``relu2`` (kept for state/attribute compatibility) but it is the
        # paddle-style hard-sigmoid gate.
        self.relu2 = HardSigmoid(type='paddle')

    def forward(self, x):
        squeeze = self.relu1(self.conv1(self.pool(x)))
        excite = self.relu2(self.conv2(squeeze))
        return x * excite
def global_avg_pool(x: torch.Tensor) -> torch.Tensor:
    """Average every channel over its spatial extent.

    Args:
        x: 4-D tensor laid out as (N, C, H, W); it need not be contiguous.

    Returns:
        Tensor of shape (N, C, 1, 1) holding the per-channel spatial mean.
    """
    n, c, h, w = x.shape
    # ``reshape`` (not ``view``) so non-contiguous inputs such as transposed or
    # permuted tensors are copied when necessary instead of raising.
    totals = x.reshape(n, c, h * w).sum(dim=2)
    return totals[:, :, None, None] / (h * w)
def global_max_pool(x: torch.Tensor) -> torch.Tensor:
    """Take the maximum of every channel over its spatial extent.

    Args:
        x: 4-D tensor laid out as (N, C, H, W); it need not be contiguous.

    Returns:
        Tensor of shape (N, C, 1, 1) holding the per-channel spatial maximum.
    """
    n, c, h, w = x.shape
    # ``reshape`` (not ``view``) so non-contiguous inputs do not raise.
    peaks = x.reshape(n, c, h * w).max(dim=2).values
    return peaks[:, :, None, None]
class ChannelAttention(nn.Module):
    """Channel-attention map (used by ``CBAM``) with values in (0, 1).

    Average- and max-pooled descriptors pass through one shared 1x1-conv MLP;
    the two results are summed and squashed with a sigmoid.

    Args:
        channels: channels of the input feature map.
        ratio: reduction factor; the MLP hidden width is ``channels // ratio``.
    """

    def __init__(self, channels, ratio=16):
        super().__init__()
        hidden = channels // ratio
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.max_pool = nn.AdaptiveMaxPool2d((1, 1))
        # Shared between both pooled descriptors.
        self.fc = nn.Sequential(
            nn.Conv2d(channels, hidden, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden, channels, 1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_logits = self.fc(self.avg_pool(x))
        max_logits = self.fc(self.max_pool(x))
        return self.sigmoid(avg_logits + max_logits)
class SpatialAttention(nn.Module):
    """Spatial-attention map (used by ``CBAM``) with values in (0, 1).

    The channel-wise mean and max maps are stacked into two channels and fed
    through a single convolution followed by a sigmoid.

    Args:
        kernel_size: 3 or 7; padding is chosen so H and W are preserved.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        # For the two allowed odd sizes this gives padding 3 (k=7) or 1 (k=3).
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        mean_map = x.mean(dim=1, keepdim=True)
        max_map = x.max(dim=1, keepdim=True).values
        stacked = torch.cat((mean_map, max_map), dim=1)
        return self.sigmoid(self.conv1(stacked))
class CBAM(nn.Module):
    """Convolutional Block Attention Module: a channel gate, then a spatial gate.

    Args:
        in_channels: channels of the input feature map.
        ratio: reduction factor passed to ``ChannelAttention``.
    """

    def __init__(self, in_channels, ratio=16):
        super().__init__()
        self.cam = ChannelAttention(in_channels, ratio)
        self.sam = SpatialAttention()

    def forward(self, x):
        refined = x * self.cam(x)
        # The spatial gate is computed from the channel-refined features.
        return refined * self.sam(refined)
class eca_layer(nn.Module):
    """Efficient channel attention: a 1-D convolution across pooled channels.

    Args:
        channel: accepted for API compatibility; not used by this layer.
        k_size: kernel size of the 1-D convolution over the channel axis.
    """

    def __init__(self, channel, k_size=3):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # For a 4-D (N, C, H, W) input: pooled is (N, C, 1, 1).
        pooled = self.avg_pool(x)
        # Treat the channel axis as a 1-D sequence: (N, 1, C).
        sequence = pooled.squeeze(-1).transpose(-1, -2)
        # Convolve across neighbouring channels, then restore (N, C, 1, 1).
        logits = self.conv(sequence).transpose(-1, -2).unsqueeze(-1)
        gate = self.sigmoid(logits)
        return x * gate.expand_as(x)
class ScaleChannelAttention(nn.Module):
    """Per-sample softmax weights over ``num_features`` outputs from a pooled descriptor.

    Args:
        in_planes: channels of the input feature map.
        out_planes: hidden width of the 1x1-conv bottleneck.
        num_features: number of output weights, softmax-normalised over dim 1.
        init_weight: if True, apply Kaiming-normal (convs) and constant (batch
            norm) initialisation.
    """

    def __init__(self, in_planes, out_planes, num_features, init_weight=True):
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # (The original printed ``self.avgpool`` here: a leftover debug print
        # that wrote to stdout on every construction.)
        self.fc1 = nn.Conv2d(in_planes, out_planes, 1, bias=False)
        self.bn = nn.BatchNorm2d(out_planes)
        self.fc2 = nn.Conv2d(out_planes, num_features, 1, bias=False)
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # For 4-D input the result is (N, num_features, 1, 1).
        global_x = self.avgpool(x)
        global_x = F.relu(self.bn(self.fc1(global_x)))
        return F.softmax(self.fc2(global_x), dim=1)
class ScaleChannelSpatialAttention(nn.Module):
    """Scale attention that fuses a channel gate and a spatial gate.

    Args:
        in_planes: channels of the input feature map.
        out_planes: hidden width of the channel-gate bottleneck.
        num_features: number of attention maps produced (values in (0, 1)).
        init_weight: if True, re-initialise conv/batch-norm parameters.
    """

    def __init__(self, in_planes, out_planes, num_features, init_weight=True):
        super().__init__()
        # Channel gate; forward applies ``.sigmoid()`` to its (N, in_planes, 1, 1) output.
        self.channel_wise = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_planes, out_planes, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(out_planes, in_planes, 1, bias=False),
        )
        # Spatial gate on a single-channel map: (N, 1, H, W) -> (N, 1, H, W).
        self.spatial_wise = nn.Sequential(
            nn.Conv2d(1, 1, 3, bias=False, padding=1),
            nn.ReLU(),
            nn.Conv2d(1, 1, 1, bias=False),
            nn.Sigmoid(),
        )
        # Projects the fused features to ``num_features`` maps in (0, 1).
        self.attention_wise = nn.Sequential(
            nn.Conv2d(in_planes, num_features, 1, bias=False),
            nn.Sigmoid(),
        )
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
                continue
            if not isinstance(module, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        # (N, in_planes, 1, 1) gate, broadcast-added back onto x: (N, C, H, W).
        fused = self.channel_wise(x).sigmoid() + x
        # Collapse channels to a single spatial descriptor: (N, 1, H, W).
        spatial_desc = torch.mean(fused, dim=1, keepdim=True)
        # The spatial gate is broadcast-added over every channel.
        fused = self.spatial_wise(spatial_desc) + fused
        return self.attention_wise(fused)
class ScaleSpatialAttention(nn.Module):
    """Scale attention driven by a spatial gate computed from the channel mean.

    Args:
        in_planes: channels of the input feature map.
        out_planes: accepted for signature parity with the other scale-attention
            classes; not used here.
        num_features: number of attention maps produced (values in (0, 1)).
        init_weight: if True, re-initialise conv/batch-norm parameters.
    """

    def __init__(self, in_planes, out_planes, num_features, init_weight=True):
        super().__init__()
        # Spatial gate on a single-channel map: (N, 1, H, W) -> (N, 1, H, W).
        self.spatial_wise = nn.Sequential(
            nn.Conv2d(1, 1, 3, bias=False, padding=1),
            nn.ReLU(),
            nn.Conv2d(1, 1, 1, bias=False),
            nn.Sigmoid(),
        )
        # Projects the fused features to ``num_features`` maps in (0, 1).
        self.attention_wise = nn.Sequential(
            nn.Conv2d(in_planes, num_features, 1, bias=False),
            nn.Sigmoid(),
        )
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
                continue
            if not isinstance(module, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        spatial_desc = torch.mean(x, dim=1, keepdim=True)
        # The (N, 1, H, W) gate is broadcast-added over every channel of x.
        fused = self.spatial_wise(spatial_desc) + x
        return self.attention_wise(fused)
class ScaleFeatureSelection(nn.Module):
    """Weight each feature map in a list by one channel of a learned attention score.

    Args:
        in_channels: channels of the concatenated input ``concat_x``.
        inter_channels: output channels of the 3x3 conv applied to ``concat_x``.
        out_features_num: number of feature maps expected in ``features_list``.
        attention_type: ``'scale_spatial'``, ``'scale_channel_spatial'`` or
            ``'scale_channel'``.

    Raises:
        ValueError: if ``attention_type`` is not one of the supported values.
            Previously an unknown value left ``enhanced_attention`` unset and
            only failed later, inside ``forward``, with an AttributeError.
    """

    def __init__(self, in_channels, inter_channels, out_features_num=4, attention_type='scale_spatial'):
        super().__init__()
        self.in_channels = in_channels
        self.inter_channels = inter_channels
        self.out_features_num = out_features_num
        self.conv = nn.Conv2d(in_channels, inter_channels, 3, padding=1)
        self.type = attention_type
        if self.type == 'scale_spatial':
            self.enhanced_attention = ScaleSpatialAttention(inter_channels, inter_channels // 4, out_features_num)
        elif self.type == 'scale_channel_spatial':
            self.enhanced_attention = ScaleChannelSpatialAttention(inter_channels, inter_channels // 4, out_features_num)
        elif self.type == 'scale_channel':
            self.enhanced_attention = ScaleChannelAttention(inter_channels, inter_channels // 2, out_features_num)
        else:
            raise ValueError(
                f"unsupported attention_type {attention_type!r}; expected "
                "'scale_spatial', 'scale_channel_spatial' or 'scale_channel'"
            )

    def _initialize_weights(self, m):
        """Init hook for ``Module.apply``: matches modules by class name.

        Classes whose name contains ``'Conv'`` get Kaiming-normal weights, and
        those containing ``'BatchNorm'`` get weight 1 and bias 1e-4. The method
        does not read ``self``.
        """
        classname = m.__class__.__name__
        if 'Conv' in classname:
            nn.init.kaiming_normal_(m.weight.data)
        elif 'BatchNorm' in classname:
            m.weight.data.fill_(1.)
            m.bias.data.fill_(1e-4)

    def forward(self, concat_x, features_list):
        # Validate before doing any convolution work.
        assert len(features_list) == self.out_features_num, (
            f"expected {self.out_features_num} feature maps, got {len(features_list)}"
        )
        score = self.enhanced_attention(self.conv(concat_x))
        if self.type not in ('scale_channel_spatial', 'scale_spatial'):
            # 'scale_channel' yields a (N, out_features_num, 1, 1) score; resize it
            # to the spatial size of the first feature map.
            score = F.interpolate(score, size=features_list[0].shape[2:], mode='bilinear', align_corners=False)
        # Feature i is scaled by score channel i; results are stacked along dim 1.
        weighted = [score[:, i:i + 1] * features_list[i] for i in range(self.out_features_num)]
        return torch.cat(weighted, dim=1)
|