FPN.py

# -*- coding: utf-8 -*-
# @Time : 2020/5/21 13:50
# @Author : zhoujun
import torch
from torch import nn
import torch.nn.functional as F


class ConvBnRelu(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU block."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1,
                 bias=True, padding_mode='zeros', inplace=True):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                              stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias,
                              padding_mode=padding_mode)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=inplace)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

class FPN(nn.Module):
    def __init__(self, in_channels, out_channels=256, **kwargs):
        """
        :param in_channels: channel dimensions of the backbone feature maps (C2..C5)
        :param out_channels: channel dimension of the fused output feature map
        :param kwargs: unused, kept for interface compatibility
        """
        super().__init__()
        inplace = True
        self.conv_out = out_channels
        # Each pyramid level carries out_channels // 4 channels, so that
        # concatenating the four levels restores out_channels.
        out_channels = out_channels // 4
        # Reduce layers: 1x1 convs projecting C2..C5 to a common width
        self.reduce_conv_c2 = ConvBnRelu(in_channels[0], out_channels, kernel_size=1, inplace=inplace)
        self.reduce_conv_c3 = ConvBnRelu(in_channels[1], out_channels, kernel_size=1, inplace=inplace)
        self.reduce_conv_c4 = ConvBnRelu(in_channels[2], out_channels, kernel_size=1, inplace=inplace)
        self.reduce_conv_c5 = ConvBnRelu(in_channels[3], out_channels, kernel_size=1, inplace=inplace)
        # Smooth layers: 3x3 convs applied after each upsample-add
        self.smooth_p4 = ConvBnRelu(out_channels, out_channels, kernel_size=3, padding=1, inplace=inplace)
        self.smooth_p3 = ConvBnRelu(out_channels, out_channels, kernel_size=3, padding=1, inplace=inplace)
        self.smooth_p2 = ConvBnRelu(out_channels, out_channels, kernel_size=3, padding=1, inplace=inplace)
        # Final 3x3 conv applied to the concatenated pyramid
        self.conv = nn.Sequential(
            nn.Conv2d(self.conv_out, self.conv_out, kernel_size=3, padding=1, stride=1),
            nn.BatchNorm2d(self.conv_out),
            nn.ReLU(inplace=inplace)
        )
        self.out_channels = self.conv_out

    def forward(self, x):
        c2, c3, c4, c5 = x
        # Top-down
        p5 = self.reduce_conv_c5(c5)
        p4 = self._upsample_add(p5, self.reduce_conv_c4(c4))
        p4 = self.smooth_p4(p4)
        p3 = self._upsample_add(p4, self.reduce_conv_c3(c3))
        p3 = self.smooth_p3(p3)
        p2 = self._upsample_add(p3, self.reduce_conv_c2(c2))
        p2 = self.smooth_p2(p2)
        x = self._upsample_cat(p2, p3, p4, p5)
        x = self.conv(x)
        return x

    def _upsample_add(self, x, y):
        # F.interpolate defaults to nearest-neighbor interpolation
        return F.interpolate(x, size=y.size()[2:]) + y

    def _upsample_cat(self, p2, p3, p4, p5):
        # Upsample all levels to the P2 (highest) resolution and concatenate
        h, w = p2.size()[2:]
        p3 = F.interpolate(p3, size=(h, w))
        p4 = F.interpolate(p4, size=(h, w))
        p5 = F.interpolate(p5, size=(h, w))
        return torch.cat([p2, p3, p4, p5], dim=1)
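

# A minimal smoke test, not part of the original file. It assumes a
# ResNet-50-style backbone whose C2..C5 feature maps carry
# [256, 512, 1024, 2048] channels; the 640x640 input size implied by the
# dummy tensors below (strides 4, 8, 16, 32) is purely illustrative.
if __name__ == '__main__':
    model = FPN(in_channels=[256, 512, 1024, 2048], out_channels=256)
    c2 = torch.randn(1, 256, 160, 160)
    c3 = torch.randn(1, 512, 80, 80)
    c4 = torch.randn(1, 1024, 40, 40)
    c5 = torch.randn(1, 2048, 20, 20)
    out = model([c2, c3, c4, c5])
    # The fused map keeps the P2 resolution and conv_out channels:
    print(out.shape)  # expected: torch.Size([1, 256, 160, 160])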