RecSequenceDecoder.py 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. # -*- coding: utf-8 -*-
  2. # @Time : 2020/5/15 17:46
  3. # @Author : zhoujun
  4. from __future__ import absolute_import
  5. from __future__ import division
  6. from __future__ import print_function
  7. from collections import OrderedDict
  8. from torch import nn
  9. import torch
  10. class DecoderWithRNN(nn.Module):
  11. def __init__(self, in_channels, **kwargs):
  12. super().__init__()
  13. rnn_hidden_size = kwargs.get('hidden_size', 96)
  14. self.out_channels = rnn_hidden_size * 2
  15. self.layers = 2
  16. self.lstm = nn.LSTM(in_channels, rnn_hidden_size, bidirectional=True, batch_first=True, num_layers=self.layers)
  17. def forward(self, x):
  18. x = self.lstm(x)[0]
  19. return x
  20. class Reshape(nn.Module):
  21. def __init__(self, in_channels, **kwargs):
  22. super().__init__()
  23. self.out_channels = in_channels
  24. def forward(self, x):
  25. B, C, H, W = x.shape
  26. x = x.reshape(B, C, H * W)
  27. x = x.permute((0, 2, 1)) # (NTC)(batch, width, channel)s
  28. return x
  29. class SequenceDecoder(nn.Module):
  30. def __init__(self, in_channels, **kwargs):
  31. super().__init__()
  32. self.reshape = Reshape(in_channels)
  33. self.decoder = DecoderWithRNN(in_channels, **kwargs)
  34. self.out_channels = self.decoder.out_channels
  35. def forward(self, x):
  36. x = self.reshape(x)
  37. x = self.decoder(x)
  38. return x