# pse_fpn.py (2.2 KB)
# -*- coding: utf-8 -*-
"""
@time: 2021/2/8 21:28
@author: Bourne-M
"""
  6. import torch
  7. from torch import nn
  8. import torch.nn.functional as F
  9. from torchocr.networks.CommonModules import ConvBNACT
  10. class PSEFpn(nn.Module):
  11. def __init__(self, in_channels, out_channels=256, inplace=True, **kwargs):
  12. super().__init__()
  13. self.out_channels = out_channels * 4
  14. self.toplayer = ConvBNACT(in_channels=in_channels[3], out_channels=out_channels, kernel_size=1, stride=1, padding=0, act='relu')
  15. self.latlayer1 = ConvBNACT(in_channels=in_channels[2], out_channels=out_channels, kernel_size=1, stride=1, padding=0, act='relu')
  16. self.latlayer2 = ConvBNACT(in_channels=in_channels[1], out_channels=out_channels, kernel_size=1, stride=1, padding=0, act='relu')
  17. self.latlayer3 = ConvBNACT(in_channels=in_channels[0], out_channels=out_channels, kernel_size=1, stride=1, padding=0, act='relu')
  18. # Smooth layers
  19. self.smooth1 = ConvBNACT(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, act='relu')
  20. self.smooth2 = ConvBNACT(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, act='relu')
  21. self.smooth3 = ConvBNACT(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, act='relu')
  22. def forward(self, x):
  23. c2, c3, c4, c5 = x
  24. p5 = self.toplayer(c5)
  25. p4 = self._upsample_add(p5, self.latlayer1(c4))
  26. p4 = self.smooth1(p4)
  27. p3 = self._upsample_add(p4, self.latlayer2(c3))
  28. p3 = self.smooth2(p3)
  29. p2 = self._upsample_add(p3, self.latlayer3(c2))
  30. p2 = self.smooth3(p2)
  31. x = self._upsample_cat(p2, p3, p4, p5)
  32. return x
  33. def _upsample_add(self, x, y):
  34. return F.interpolate(x, size=y.size()[2:], mode='bilinear', align_corners=False) + y
  35. def _upsample_cat(self, p2, p3, p4, p5):
  36. h, w = p2.size()[2:]
  37. p3 = F.interpolate(p3, size=(h, w), mode='bilinear', align_corners=False)
  38. p4 = F.interpolate(p4, size=(h, w), mode='bilinear', align_corners=False)
  39. p5 = F.interpolate(p5, size=(h, w), mode='bilinear', align_corners=False)
  40. return torch.cat([p2, p3, p4, p5], dim=1)