DB_ASF.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. # -*- coding: utf-8 -*-
  2. """
  3. @time: 2021/2/8 21:28
  4. @author: Bourne-M
  5. """
  6. import torch
  7. from torch import nn
  8. import torch.nn.functional as F
  9. from torchocr.networks.CommonModules import ScaleFeatureSelection
  10. import numpy as np
  11. def weights_init(m):
  12. import torch.nn.init as init
  13. if isinstance(m, nn.Conv2d):
  14. init.kaiming_normal_(m.weight.data)
  15. if m.bias is not None:
  16. init.normal_(m.bias.data)
  17. elif isinstance(m, nn.ConvTranspose2d):
  18. init.kaiming_normal_(m.weight.data)
  19. if m.bias is not None:
  20. init.normal_(m.bias.data)
  21. elif isinstance(m, nn.BatchNorm2d):
  22. init.normal_(m.weight.data, mean=1, std=0.02)
  23. init.constant_(m.bias.data, 0)
  24. def weights_init(m):
  25. import torch.nn.init as init
  26. if isinstance(m, nn.Conv2d):
  27. init.kaiming_normal_(m.weight.data)
  28. if m.bias is not None:
  29. init.normal_(m.bias.data)
  30. elif isinstance(m, nn.ConvTranspose2d):
  31. init.kaiming_normal_(m.weight.data)
  32. if m.bias is not None:
  33. init.normal_(m.bias.data)
  34. elif isinstance(m, nn.BatchNorm2d):
  35. init.normal_(m.weight.data, mean=1, std=0.02)
  36. init.constant_(m.bias.data, 0)
  37. class DB_Asf(nn.Module):
  38. def __init__(self, in_channels, out_channels=256, **kwargs):
  39. """
  40. :param in_channels: 基础网络输出的维度
  41. :param kwargs:
  42. """
  43. super().__init__()
  44. inplace = True
  45. self.out_channels = out_channels
  46. # reduce layers
  47. self.in2_conv = nn.Conv2d(in_channels[0], self.out_channels, kernel_size=1, bias=False)
  48. self.in3_conv = nn.Conv2d(in_channels[1], self.out_channels, kernel_size=1, bias=False)
  49. self.in4_conv = nn.Conv2d(in_channels[2], self.out_channels, kernel_size=1, bias=False)
  50. self.in5_conv = nn.Conv2d(in_channels[3], self.out_channels, kernel_size=1, bias=False)
  51. # Smooth layers
  52. self.p5_conv = nn.Sequential(
  53. nn.Conv2d(self.out_channels, self.out_channels // 4, kernel_size=3, padding=1, bias=False),
  54. nn.Upsample(scale_factor=8, mode='nearest'))
  55. self.p4_conv = nn.Sequential(
  56. nn.Conv2d(self.out_channels, self.out_channels // 4, kernel_size=3, padding=1, bias=False),
  57. nn.Upsample(scale_factor=4, mode='nearest'))
  58. self.p3_conv = nn.Sequential(
  59. nn.Conv2d(self.out_channels, self.out_channels // 4, kernel_size=3, padding=1, bias=False),
  60. nn.Upsample(scale_factor=2, mode='nearest'))
  61. self.p2_conv = nn.Conv2d(self.out_channels, self.out_channels // 4, kernel_size=3, padding=1, bias=False)
  62. self.concat_attention = ScaleFeatureSelection(out_channels, out_channels // 4,
  63. attention_type='scale_channel_spatial')
  64. self.in2_conv.apply(weights_init)
  65. self.in3_conv.apply(weights_init)
  66. self.in4_conv.apply(weights_init)
  67. self.in5_conv.apply(weights_init)
  68. self.p5_conv.apply(weights_init)
  69. self.p4_conv.apply(weights_init)
  70. self.p3_conv.apply(weights_init)
  71. self.p2_conv.apply(weights_init)
  72. def weights_init(self, m):
  73. classname = m.__class__.__name__
  74. if classname.find('Conv') != -1:
  75. nn.init.kaiming_normal_(m.weight.data)
  76. elif classname.find('BatchNorm') != -1:
  77. m.weight.data.fill_(1.)
  78. m.bias.data.fill_(1e-4)
  79. def _upsample_add(self, x, y):
  80. return F.interpolate(x, scale_factor=2) + y
  81. def _upsample_cat(self, p2, p3, p4, p5):
  82. p3 = F.interpolate(p3, scale_factor=2)
  83. p4 = F.interpolate(p4, scale_factor=4)
  84. p5 = F.interpolate(p5, scale_factor=8)
  85. return torch.cat([p5, p4, p3, p2], dim=1)
  86. def forward(self, x):
  87. c2, c3, c4, c5 = x
  88. in5 = self.in5_conv(c5)
  89. in4 = self.in4_conv(c4)
  90. in3 = self.in3_conv(c3)
  91. in2 = self.in2_conv(c2)
  92. out4 = self._upsample_add(in5, in4)
  93. out3 = self._upsample_add(out4, in3)
  94. out2 = self._upsample_add(out3, in2)
  95. p5 = self.p5_conv(in5)
  96. p4 = self.p4_conv(out4)
  97. p3 = self.p3_conv(out3)
  98. p2 = self.p2_conv(out2)
  99. fuse = torch.cat((p5, p4, p3, p2), 1)
  100. fuse = self.concat_attention(fuse, [p5, p4, p3, p2])
  101. return fuse