# RNN.py
  1. # encoding: utf-8
  2. """
  3. @time: 2021/2/9 22:21
  4. @author: Bourne-M
  5. """
  6. from __future__ import absolute_import
  7. from __future__ import division
  8. from __future__ import print_function
  9. from torch import nn
  10. class Im2Seq(nn.Module):
  11. def __init__(self, in_channels, **kwargs):
  12. super().__init__()
  13. self.out_channels = in_channels
  14. def forward(self, x):
  15. B, C, H, W = x.shape
  16. assert H == 1
  17. x = x.reshape(B, C, H * W)
  18. x = x.permute((0, 2, 1))
  19. return x
  20. class EncoderWithRNN(nn.Module):
  21. def __init__(self, in_channels,**kwargs):
  22. super(EncoderWithRNN, self).__init__()
  23. hidden_size = kwargs.get('hidden_size', 256)
  24. self.out_channels = hidden_size * 2
  25. self.lstm = nn.LSTM(in_channels, hidden_size, bidirectional=True, num_layers=2,batch_first=True)
  26. def forward(self, x):
  27. self.lstm.flatten_parameters()
  28. x, _ = self.lstm(x)
  29. return x
  30. class SequenceEncoder(nn.Module):
  31. def __init__(self, in_channels, encoder_type='rnn', **kwargs):
  32. super(SequenceEncoder, self).__init__()
  33. self.encoder_reshape = Im2Seq(in_channels)
  34. self.out_channels = self.encoder_reshape.out_channels
  35. if encoder_type == 'reshape':
  36. self.only_reshape = True
  37. else:
  38. support_encoder_dict = {
  39. 'reshape': Im2Seq,
  40. 'rnn': EncoderWithRNN
  41. }
  42. assert encoder_type in support_encoder_dict, '{} must in {}'.format(
  43. encoder_type, support_encoder_dict.keys())
  44. self.encoder = support_encoder_dict[encoder_type](
  45. self.encoder_reshape.out_channels,**kwargs)
  46. self.out_channels = self.encoder.out_channels
  47. self.only_reshape = False
  48. def forward(self, x):
  49. x = self.encoder_reshape(x)
  50. if not self.only_reshape:
  51. x = self.encoder(x)
  52. return x