# -*- coding:utf-8 -*-
# @author :adolf
import torch.nn.functional as F
from torch import nn, Tensor
"""
out_channels=96时,和现有的fpn相比,这个fpn精度差不多,但是模型尺寸会大500k
"""
  8. class FeaturePyramidNetwork(nn.Module):
  9. def __init__(self, in_channels, out_channels=256):
  10. super(FeaturePyramidNetwork, self).__init__()
  11. self.inner_blocks = nn.ModuleList()
  12. self.layer_blocks = nn.ModuleList()
  13. self.out_channels = out_channels
  14. for in_channels in in_channels:
  15. if in_channels == 0:
  16. raise ValueError("in_channels=0 is currently not supported")
  17. inner_block_module = nn.Conv2d(in_channels, out_channels, 1)
  18. layer_block_module = nn.Conv2d(out_channels, out_channels, 3, padding=1)
  19. self.inner_blocks.append(inner_block_module)
  20. self.layer_blocks.append(layer_block_module)
  21. # initialize parameters now to avoid modifying the initialization of top_blocks
  22. for m in self.children():
  23. if isinstance(m, nn.Conv2d):
  24. nn.init.kaiming_uniform_(m.weight, a=1)
  25. nn.init.constant_(m.bias, 0)
  26. def get_result_from_inner_blocks(self, x, idx):
  27. num_blocks = 0
  28. for m in self.inner_blocks:
  29. num_blocks += 1
  30. if idx < 0:
  31. idx += num_blocks
  32. i = 0
  33. out = x
  34. for module in self.inner_blocks:
  35. if i == idx:
  36. out = module(x)
  37. i += 1
  38. return out
  39. def get_result_from_layer_blocks(self, x, idx):
  40. num_blocks = 0
  41. for m in self.layer_blocks:
  42. num_blocks += 1
  43. if idx < 0:
  44. idx += num_blocks
  45. i = 0
  46. out = x
  47. for module in self.layer_blocks:
  48. if i == idx:
  49. out = module(x)
  50. i += 1
  51. return out
  52. def forward(self, x):
  53. # unpack OrderedDict into two lists for easier handling
  54. # names = list(x.keys())
  55. # x = list(x.values())
  56. last_inner = self.get_result_from_inner_blocks(x[-1], -1)
  57. results = []
  58. results.append(self.get_result_from_layer_blocks(last_inner, -1))
  59. for idx in range(len(x) - 2, -1, -1):
  60. inner_lateral = self.get_result_from_inner_blocks(x[idx], idx)
  61. feat_shape = inner_lateral.shape[-2:]
  62. inner_top_down = F.interpolate(last_inner, size=feat_shape, mode="nearest")
  63. last_inner = inner_lateral + inner_top_down
  64. results.insert(0, self.get_result_from_layer_blocks(last_inner, idx))
  65. # make it back an OrderedDict
  66. # out = OrderedDict([(k, v) for k, v in zip(names, results)])
  67. out = results[0]
  68. return out
  69. class ExtraFPNBlock(nn.Module):
  70. def forward(self, results, x, names):
  71. pass
  72. class LastLevelMaxPool(ExtraFPNBlock):
  73. def forward(self, x, y, names):
  74. names.append("pool")
  75. x.append(F.max_pool2d(x[-1], 1, 2, 0))
  76. return x, names
  77. class LastLevelP6P7(ExtraFPNBlock):
  78. def __init__(self, in_channels, out_channels):
  79. super(LastLevelP6P7, self).__init__()
  80. self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1)
  81. self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1)
  82. for module in [self.p6, self.p7]:
  83. nn.init.kaiming_uniform_(module.weight, a=1)
  84. nn.init.constant_(module.bias, 0)
  85. self.use_P5 = in_channels == out_channels
  86. def forward(self, p, c, names):
  87. p5, c5 = p[-1], c[-1]
  88. x = p5 if self.use_P5 else c5
  89. p6 = self.p6(x)
  90. p7 = self.p7(F.relu(p6))
  91. p.extend([p6, p7])
  92. names.extend(["p6", "p7"])
  93. return p, names