1234567891011121314151617181920212223242526 |
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- from torch import nn
- import torch.nn.functional as F
- from torchocr.networks.CommonModules import ConvBNACT
class PseHead(nn.Module):
    """Prediction head: 3x3 conv block, 1x1 conv, then bilinear resize.

    The input is reduced to ``in_channels // 4`` channels by a ``ConvBNACT``
    block, projected to ``result_num`` channels by a 1x1 convolution, and
    finally resized with bilinear interpolation to a fixed spatial size.

    Args:
        in_channels: Number of channels of the incoming feature map.
        result_num: Number of output channels produced by the head.
        **kwargs: Optional settings read by this head:
            ``H`` (default 640) and ``W`` (default 640) give the output
            height/width used in training mode; ``scale`` (default 1) is
            the integer divisor applied to ``H`` and ``W`` in eval mode.
    """

    def __init__(self, in_channels, result_num=6, **kwargs):
        super(PseHead, self).__init__()
        self.H = kwargs.get('H', 640)
        self.W = kwargs.get('W', 640)
        self.scale = kwargs.get('scale', 1)
        # NOTE(review): ConvBNACT is defined in torchocr.networks.CommonModules;
        # presumably conv + batch-norm + activation -- confirm there.
        self.conv = ConvBNACT(in_channels, in_channels // 4, kernel_size=3, padding=1, stride=1, act='relu')
        self.out_conv = nn.Conv2d(in_channels // 4, result_num, kernel_size=1, stride=1)

    def forward(self, x):
        """Run the head on ``x`` and resize the result.

        Returns a tensor with ``result_num`` channels whose spatial size is
        ``(H, W)`` in training mode and ``(H // scale, W // scale)`` in
        eval mode.
        """
        x = self.conv(x)
        x = self.out_conv(x)
        # Use the boolean ``training`` flag here. ``self.train`` is the bound
        # nn.Module.train method and is always truthy, which made the eval
        # branch unreachable.
        if self.training:
            out_size = (self.H, self.W)
        else:
            out_size = (self.H // self.scale, self.W // self.scale)
        return F.interpolate(x, size=out_size, mode='bilinear', align_corners=True)
|