123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118 |
- # -*- coding: utf-8 -*-
- """
- @time: 2021/2/8 21:28
- @author: Bourne-M
- """
- import torch
- from torch import nn
- import torch.nn.functional as F
- from torchocr.networks.CommonModules import ScaleFeatureSelection
- import numpy as np
- def weights_init(m):
- import torch.nn.init as init
- if isinstance(m, nn.Conv2d):
- init.kaiming_normal_(m.weight.data)
- if m.bias is not None:
- init.normal_(m.bias.data)
- elif isinstance(m, nn.ConvTranspose2d):
- init.kaiming_normal_(m.weight.data)
- if m.bias is not None:
- init.normal_(m.bias.data)
- elif isinstance(m, nn.BatchNorm2d):
- init.normal_(m.weight.data, mean=1, std=0.02)
- init.constant_(m.bias.data, 0)
- def weights_init(m):
- import torch.nn.init as init
- if isinstance(m, nn.Conv2d):
- init.kaiming_normal_(m.weight.data)
- if m.bias is not None:
- init.normal_(m.bias.data)
- elif isinstance(m, nn.ConvTranspose2d):
- init.kaiming_normal_(m.weight.data)
- if m.bias is not None:
- init.normal_(m.bias.data)
- elif isinstance(m, nn.BatchNorm2d):
- init.normal_(m.weight.data, mean=1, std=0.02)
- init.constant_(m.bias.data, 0)
- class DB_Asf(nn.Module):
- def __init__(self, in_channels, out_channels=256, **kwargs):
- """
- :param in_channels: 基础网络输出的维度
- :param kwargs:
- """
- super().__init__()
- inplace = True
- self.out_channels = out_channels
- # reduce layers
- self.in2_conv = nn.Conv2d(in_channels[0], self.out_channels, kernel_size=1, bias=False)
- self.in3_conv = nn.Conv2d(in_channels[1], self.out_channels, kernel_size=1, bias=False)
- self.in4_conv = nn.Conv2d(in_channels[2], self.out_channels, kernel_size=1, bias=False)
- self.in5_conv = nn.Conv2d(in_channels[3], self.out_channels, kernel_size=1, bias=False)
- # Smooth layers
- self.p5_conv = nn.Sequential(
- nn.Conv2d(self.out_channels, self.out_channels // 4, kernel_size=3, padding=1, bias=False),
- nn.Upsample(scale_factor=8, mode='nearest'))
- self.p4_conv = nn.Sequential(
- nn.Conv2d(self.out_channels, self.out_channels // 4, kernel_size=3, padding=1, bias=False),
- nn.Upsample(scale_factor=4, mode='nearest'))
- self.p3_conv = nn.Sequential(
- nn.Conv2d(self.out_channels, self.out_channels // 4, kernel_size=3, padding=1, bias=False),
- nn.Upsample(scale_factor=2, mode='nearest'))
- self.p2_conv = nn.Conv2d(self.out_channels, self.out_channels // 4, kernel_size=3, padding=1, bias=False)
- self.concat_attention = ScaleFeatureSelection(out_channels, out_channels // 4,
- attention_type='scale_channel_spatial')
- self.in2_conv.apply(weights_init)
- self.in3_conv.apply(weights_init)
- self.in4_conv.apply(weights_init)
- self.in5_conv.apply(weights_init)
- self.p5_conv.apply(weights_init)
- self.p4_conv.apply(weights_init)
- self.p3_conv.apply(weights_init)
- self.p2_conv.apply(weights_init)
- def weights_init(self, m):
- classname = m.__class__.__name__
- if classname.find('Conv') != -1:
- nn.init.kaiming_normal_(m.weight.data)
- elif classname.find('BatchNorm') != -1:
- m.weight.data.fill_(1.)
- m.bias.data.fill_(1e-4)
- def _upsample_add(self, x, y):
- return F.interpolate(x, scale_factor=2) + y
- def _upsample_cat(self, p2, p3, p4, p5):
- p3 = F.interpolate(p3, scale_factor=2)
- p4 = F.interpolate(p4, scale_factor=4)
- p5 = F.interpolate(p5, scale_factor=8)
- return torch.cat([p5, p4, p3, p2], dim=1)
- def forward(self, x):
- c2, c3, c4, c5 = x
- in5 = self.in5_conv(c5)
- in4 = self.in4_conv(c4)
- in3 = self.in3_conv(c3)
- in2 = self.in2_conv(c2)
- out4 = self._upsample_add(in5, in4)
- out3 = self._upsample_add(out4, in3)
- out2 = self._upsample_add(out3, in2)
- p5 = self.p5_conv(in5)
- p4 = self.p4_conv(out4)
- p3 = self.p3_conv(out3)
- p2 = self.p2_conv(out2)
- fuse = torch.cat((p5, p4, p3, p2), 1)
- fuse = self.concat_attention(fuse, [p5, p4, p3, p2])
- return fuse
|