DetPseHead.py 1013 B

1234567891011121314151617181920212223242526
  1. from __future__ import absolute_import
  2. from __future__ import division
  3. from __future__ import print_function
  4. from torch import nn
  5. import torch.nn.functional as F
  6. from torchocr.networks.CommonModules import ConvBNACT
  7. class PseHead(nn.Module):
  8. def __init__(self, in_channels, result_num=6, **kwargs):
  9. super(PseHead, self).__init__()
  10. self.H = kwargs.get('H', 640)
  11. self.W = kwargs.get('W', 640)
  12. self.scale = kwargs.get('scale', 1)
  13. self.conv = ConvBNACT(in_channels, in_channels // 4, kernel_size=3, padding=1, stride=1, act='relu')
  14. self.out_conv = nn.Conv2d(in_channels // 4, result_num, kernel_size=1, stride=1)
  15. def forward(self, x):
  16. x = self.conv(x)
  17. x = self.out_conv(x)
  18. if self.train:
  19. x = F.interpolate(x, size=(self.H, self.W), mode='bilinear', align_corners=True)
  20. else:
  21. x = F.interpolate(x, size=(self.H // self.scale, self.W // self.scale), mode='bilinear', align_corners=True)
  22. return x