import torch
import torch.nn as nn
from einops import rearrange


def conv_1x1_bn(inp, oup):
    # 1x1 convolution -> BatchNorm -> SiLU
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.SiLU()
    )


def conv_nxn_bn(inp, oup, kernel_size=3, stride=1):
    # n x n convolution -> BatchNorm -> SiLU (padding is fixed at 1, so this
    # is intended for 3x3 kernels)
    return nn.Sequential(
        nn.Conv2d(inp, oup, kernel_size, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.SiLU()
    )


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    # Multi-head self-attention over 4-D tensors of shape (b, p, n, d):
    # batch, pixels per patch, number of patches, embedding dim. Attention is
    # computed across the n patches at each pixel position.
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b p h n d -> b p n (h d)')
        return self.to_out(out)

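
# Quick shape check for the 4-D attention above (illustrative sizes, mirroring
# how MobileViTBlock configures it with heads=4, dim_head=8):
#
#     attn = Attention(dim=64, heads=4, dim_head=8)
#     x = torch.randn(2, 4, 256, 64)   # (batch, pixels per patch, patches, dim)
#     attn(x).shape                    # torch.Size([2, 4, 256, 64]) -- shape-preserving
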
class Transformer(nn.Module):
    # Stack of `depth` pre-norm attention + feed-forward blocks with residuals.
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x


class MV2Block(nn.Module):
    # MobileNetV2 inverted residual block.
    def __init__(self, inp, oup, stride=1, expansion=4):
        super().__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(inp * expansion)
        self.use_res_connect = self.stride == 1 and inp == oup

        if expansion == 1:
            self.conv = nn.Sequential(
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
        else:
            self.conv = nn.Sequential(
                # pw
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)

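
# Note: the skip connection above is active only for shape-preserving blocks
# (stride 1 and inp == oup), as in MobileNetV2. For illustration (sizes are
# hypothetical):
#
#     MV2Block(24, 24, stride=1)(torch.randn(1, 24, 32, 32))  # x + conv(x), (1, 24, 32, 32)
#     MV2Block(24, 48, stride=2)(torch.randn(1, 24, 32, 32))  # conv(x) only, (1, 48, 16, 16)
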
- class MobileViTBlock(nn.Module):
- def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):
- super().__init__()
- self.ph, self.pw = patch_size
- self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
- self.conv2 = conv_1x1_bn(channel, dim)
- self.transformer = Transformer(dim, depth, 4, 8, mlp_dim, dropout)
- self.conv3 = conv_1x1_bn(dim, channel)
- self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)
- def forward(self, x):
- y = x.clone()
- # Local representations
- x = self.conv1(x)
- x = self.conv2(x)
- # Global representations
- _, _, h, w = x.shape
- x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
- x = self.transformer(x)
- x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h // self.ph, w=w // self.pw, ph=self.ph,
- pw=self.pw)
- # Fusion
- x = self.conv3(x)
- x = torch.cat((x, y), 1)
- x = self.conv4(x)
- return x
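
# Shape trace for the unfold/fold above (a worked example; the 2x2 patch size
# and 32x32 map are illustrative): with x of shape (b, d, 32, 32) and
# ph = pw = 2, the first rearrange yields (b, 4, 256, d) -- 4 pixels per patch,
# 256 patches. The transformer then mixes information across the 256 patches
# at each of the 4 pixel positions, and the second rearrange inverts the
# layout exactly, restoring (b, d, 32, 32).
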
class MobileViT(nn.Module):
    def __init__(self, image_size, dims, channels, num_classes, expansion=4, kernel_size=3, patch_size=(2, 2), **kwargs):
        super().__init__()
        ih, iw = image_size
        ph, pw = patch_size
        assert ih % ph == 0 and iw % pw == 0

        L = [2, 4, 3]  # transformer depths of the three MobileViT blocks

        self.out_channels = [channels[3], channels[5], channels[7], channels[9]]

        self.conv1 = conv_nxn_bn(3, channels[0], stride=2)

        self.mv2 = nn.ModuleList([])
        self.mv2.append(MV2Block(channels[0], channels[1], 1, expansion))
        self.mv2.append(MV2Block(channels[1], channels[2], 2, expansion))
        self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
        # Repeated block; channels[2] == channels[3] in all provided configs,
        # so this stays shape-preserving
        self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
        self.mv2.append(MV2Block(channels[3], channels[4], 2, expansion))
        self.mv2.append(MV2Block(channels[5], channels[6], 2, expansion))
        self.mv2.append(MV2Block(channels[7], channels[8], 2, expansion))

        self.mvit = nn.ModuleList([])
        self.mvit.append(MobileViTBlock(dims[0], L[0], channels[5], kernel_size, patch_size, int(dims[0] * 2)))
        self.mvit.append(MobileViTBlock(dims[1], L[1], channels[7], kernel_size, patch_size, int(dims[1] * 4)))
        self.mvit.append(MobileViTBlock(dims[2], L[2], channels[9], kernel_size, patch_size, int(dims[2] * 4)))

        self.conv2 = conv_1x1_bn(channels[-2], channels[-1])

        self.pool = nn.AvgPool2d(ih // 32, 1)
        self.fc = nn.Linear(channels[-1], num_classes, bias=False)

    def forward(self, x):
        # Backbone mode: collects and returns the four intermediate feature
        # maps; the classification head at the bottom is commented out.
        out = []
        x = self.conv1(x)
        x = self.mv2[0](x)

        x = self.mv2[1](x)
        x = self.mv2[2](x)
        out.append(x)
        x = self.mv2[3](x)  # repeated block

        x = self.mv2[4](x)  # e.g. (B, 48, 32, 32) for xxs at 256x256 input
        x = self.mvit[0](x)
        out.append(x)

        x = self.mv2[5](x)  # e.g. (B, 64, 16, 16)
        x = self.mvit[1](x)
        out.append(x)

        x = self.mv2[6](x)  # e.g. (B, 80, 8, 8)
        x = self.mvit[2](x)
        out.append(x)

        return out

        # Classification head (disabled while the model is used as a backbone):
        # x = self.conv2(x)
        # x = self.pool(x).view(-1, x.shape[1])
        # x = self.fc(x)
        # return x

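
# Note: if the commented-out head above is restored, the input image size must
# match the image_size the model was built with, because self.pool uses a
# fixed ih // 32 kernel (ih is reduced by the five stride-2 stages). For
# example, the (512, 512) mobilevit_xs below would then require 512x512
# inputs; in backbone mode any size divisible by 32 works.
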
def mobilevit_xxs():
    dims = [64, 80, 96]
    channels = [16, 16, 24, 24, 48, 48, 64, 64, 80, 80, 320]
    return MobileViT((256, 256), dims, channels, num_classes=1000, expansion=2)


def mobilevit_xs():
    dims = [96, 120, 144]
    channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384]
    return MobileViT((512, 512), dims, channels, num_classes=1000)


def mobilevit_s():
    dims = [144, 192, 240]
    channels = [16, 32, 64, 64, 96, 96, 128, 128, 160, 160, 640]
    return MobileViT((256, 256), dims, channels, num_classes=1000)


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


if __name__ == '__main__':
    img = torch.randn(5, 3, 256, 256)
    vit = mobilevit_xs()
    out = vit(img)
    print(count_parameters(vit))
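    # Sanity check on the returned feature pyramid: with a 256x256 input the
    # four maps should sit at strides 4, 8, 16, and 32.
    for i, feat in enumerate(out):
        print(f'feature {i}: {tuple(feat.shape)}')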