rnn.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from paddle import nn, ParamAttr

from ppocr.modeling.heads.rec_ctc_head import get_para_bias_attr


class Im2Seq(nn.Layer):
    """Collapse a (B, C, 1, W) feature map into a (B, W, C) sequence."""

    def __init__(self, in_channels, **kwargs):
        super().__init__()
        self.out_channels = in_channels

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == 1
        x = x.squeeze(axis=2)
        x = x.transpose([0, 2, 1])  # (NTC)(batch, width, channels)
        return x


class EncoderWithRNN(nn.Layer):
    """Two-layer bidirectional LSTM encoder; output width is 2 * hidden_size."""

    def __init__(self, in_channels, hidden_size):
        super(EncoderWithRNN, self).__init__()
        self.out_channels = hidden_size * 2
        self.lstm = nn.LSTM(
            in_channels, hidden_size, direction='bidirectional', num_layers=2
            # , weight_hh_attr=ParamAttr(trainable=False)
            # , weight_ih_attr=ParamAttr(trainable=False)
        )

    def forward(self, x):
        x, _ = self.lstm(x)
        return x


class EncoderWithFC(nn.Layer):
    """Single linear projection encoder: in_channels -> hidden_size."""

    def __init__(self, in_channels, hidden_size):
        super(EncoderWithFC, self).__init__()
        self.out_channels = hidden_size
        weight_attr, bias_attr = get_para_bias_attr(
            l2_decay=0.00001, k=in_channels, name='reduce_encoder_fea')
        self.fc = nn.Linear(
            in_channels,
            hidden_size,
            weight_attr=weight_attr,
            bias_attr=bias_attr,
            name='reduce_encoder_fea')

    def forward(self, x):
        x = self.fc(x)
        return x


class SequenceEncoder(nn.Layer):
    """Reshape the backbone feature map, then optionally encode it with an FC or RNN."""

    def __init__(self, in_channels, encoder_type, hidden_size=48, **kwargs):
        super(SequenceEncoder, self).__init__()
        self.encoder_reshape = Im2Seq(in_channels)
        self.out_channels = self.encoder_reshape.out_channels
        if encoder_type == 'reshape':
            self.only_reshape = True
        else:
            support_encoder_dict = {
                'reshape': Im2Seq,
                'fc': EncoderWithFC,
                'rnn': EncoderWithRNN
            }
            assert encoder_type in support_encoder_dict, '{} must be in {}'.format(
                encoder_type, support_encoder_dict.keys())

            self.encoder = support_encoder_dict[encoder_type](
                self.encoder_reshape.out_channels, hidden_size)
            self.out_channels = self.encoder.out_channels
            self.only_reshape = False

    def forward(self, x):
        x = self.encoder_reshape(x)
        if not self.only_reshape:
            x = self.encoder(x)
        return x
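

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal smoke test showing
# how SequenceEncoder consumes a height-1 backbone feature map. The shapes
# below (batch=8, channels=512, width=80) and encoder_type='rnn' are
# illustrative assumptions; with hidden_size=48 the bidirectional LSTM
# doubles the output width, so out_channels = 96.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import paddle

    encoder = SequenceEncoder(in_channels=512, encoder_type='rnn',
                              hidden_size=48)
    feats = paddle.rand([8, 512, 1, 80])  # (batch, channels, height=1, width)
    seq = encoder(feats)  # Im2Seq -> (8, 80, 512), then LSTM -> (8, 80, 96)
    print(seq.shape)      # [8, 80, 96]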