# self_attention.py

import keras
from keras import backend as K


class SeqSelfAttention(keras.layers.Layer):

    ATTENTION_TYPE_ADD = 'additive'
    ATTENTION_TYPE_MUL = 'multiplicative'

    def __init__(self,
                 units=32,
                 attention_width=None,
                 attention_type=ATTENTION_TYPE_ADD,
                 return_attention=False,
                 history_only=False,
                 kernel_initializer='glorot_normal',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 use_additive_bias=True,
                 use_attention_bias=True,
                 attention_activation=None,
                 attention_regularizer_weight=0.0,
                 **kwargs):
  23. """Layer initialization.
  24. For additive attention, see: https://arxiv.org/pdf/1806.01264.pdf
  25. :param units: The dimension of the vectors that used to calculate the attention weights.
  26. :param attention_width: The width of local attention.
  27. :param attention_type: 'additive' or 'multiplicative'.
  28. :param return_attention: Whether to return the attention weights for visualization.
  29. :param history_only: Only use historical pieces of data.
  30. :param kernel_initializer: The initializer for weight matrices.
  31. :param bias_initializer: The initializer for biases.
  32. :param kernel_regularizer: The regularization for weight matrices.
  33. :param bias_regularizer: The regularization for biases.
  34. :param kernel_constraint: The constraint for weight matrices.
  35. :param bias_constraint: The constraint for biases.
  36. :param use_additive_bias: Whether to use bias while calculating the relevance of inputs features
  37. in additive mode.
  38. :param use_attention_bias: Whether to use bias while calculating the weights of attention.
  39. :param attention_activation: The activation used for calculating the weights of attention.
  40. :param attention_regularizer_weight: The weights of attention regularizer.
  41. :param kwargs: Parameters for parent class.
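
        Example (an illustrative sketch only; the surrounding layers, shapes and
        hyper-parameters below are arbitrary choices, not requirements of this class)::

            model = keras.models.Sequential()
            model.add(keras.layers.Embedding(input_dim=10000, output_dim=128, mask_zero=True))
            model.add(keras.layers.Bidirectional(keras.layers.LSTM(64, return_sequences=True)))
            model.add(SeqSelfAttention(units=32, attention_activation='sigmoid'))
            model.add(keras.layers.Dense(units=5, activation='softmax'))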
  42. """
        super(SeqSelfAttention, self).__init__(**kwargs)
        self.supports_masking = True
        self.units = units
        self.attention_width = attention_width
        self.attention_type = attention_type
        self.return_attention = return_attention
        self.history_only = history_only
        if history_only and attention_width is None:
            self.attention_width = int(1e9)
        self.use_additive_bias = use_additive_bias
        self.use_attention_bias = use_attention_bias
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.bias_initializer = keras.initializers.get(bias_initializer)
        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)
        self.bias_regularizer = keras.regularizers.get(bias_regularizer)
        self.kernel_constraint = keras.constraints.get(kernel_constraint)
        self.bias_constraint = keras.constraints.get(bias_constraint)
        self.attention_activation = keras.activations.get(attention_activation)
        self.attention_regularizer_weight = attention_regularizer_weight
        self._backend = keras.backend.backend()

        if attention_type == SeqSelfAttention.ATTENTION_TYPE_ADD:
            self.Wx, self.Wt, self.bh = None, None, None
            self.Wa, self.ba = None, None
        elif attention_type == SeqSelfAttention.ATTENTION_TYPE_MUL:
            self.Wa, self.ba = None, None
        else:
            raise NotImplementedError('No implementation for attention type : ' + attention_type)

    def get_config(self):
        config = {
            'units': self.units,
            'attention_width': self.attention_width,
            'attention_type': self.attention_type,
            'return_attention': self.return_attention,
            'history_only': self.history_only,
            'use_additive_bias': self.use_additive_bias,
            'use_attention_bias': self.use_attention_bias,
            'kernel_initializer': keras.initializers.serialize(self.kernel_initializer),
            'bias_initializer': keras.initializers.serialize(self.bias_initializer),
            'kernel_regularizer': keras.regularizers.serialize(self.kernel_regularizer),
            'bias_regularizer': keras.regularizers.serialize(self.bias_regularizer),
            'kernel_constraint': keras.constraints.serialize(self.kernel_constraint),
            'bias_constraint': keras.constraints.serialize(self.bias_constraint),
            'attention_activation': keras.activations.serialize(self.attention_activation),
            'attention_regularizer_weight': self.attention_regularizer_weight,
        }
        base_config = super(SeqSelfAttention, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def build(self, input_shape):
        if self.attention_type == SeqSelfAttention.ATTENTION_TYPE_ADD:
            self._build_additive_attention(input_shape)
        elif self.attention_type == SeqSelfAttention.ATTENTION_TYPE_MUL:
            self._build_multiplicative_attention(input_shape)
        super(SeqSelfAttention, self).build(input_shape)

    def _build_additive_attention(self, input_shape):
        feature_dim = int(input_shape[2])

        self.Wt = self.add_weight(shape=(feature_dim, self.units),
                                  name='{}_Add_Wt'.format(self.name),
                                  initializer=self.kernel_initializer,
                                  regularizer=self.kernel_regularizer,
                                  constraint=self.kernel_constraint)
        self.Wx = self.add_weight(shape=(feature_dim, self.units),
                                  name='{}_Add_Wx'.format(self.name),
                                  initializer=self.kernel_initializer,
                                  regularizer=self.kernel_regularizer,
                                  constraint=self.kernel_constraint)
        if self.use_additive_bias:
            self.bh = self.add_weight(shape=(self.units,),
                                      name='{}_Add_bh'.format(self.name),
                                      initializer=self.bias_initializer,
                                      regularizer=self.bias_regularizer,
                                      constraint=self.bias_constraint)

        self.Wa = self.add_weight(shape=(self.units, 1),
                                  name='{}_Add_Wa'.format(self.name),
                                  initializer=self.kernel_initializer,
                                  regularizer=self.kernel_regularizer,
                                  constraint=self.kernel_constraint)
        if self.use_attention_bias:
            self.ba = self.add_weight(shape=(1,),
                                      name='{}_Add_ba'.format(self.name),
                                      initializer=self.bias_initializer,
                                      regularizer=self.bias_regularizer,
                                      constraint=self.bias_constraint)

    def _build_multiplicative_attention(self, input_shape):
        feature_dim = int(input_shape[2])

        self.Wa = self.add_weight(shape=(feature_dim, feature_dim),
                                  name='{}_Mul_Wa'.format(self.name),
                                  initializer=self.kernel_initializer,
                                  regularizer=self.kernel_regularizer,
                                  constraint=self.kernel_constraint)
        if self.use_attention_bias:
            self.ba = self.add_weight(shape=(1,),
                                      name='{}_Mul_ba'.format(self.name),
                                      initializer=self.bias_initializer,
                                      regularizer=self.bias_regularizer,
                                      constraint=self.bias_constraint)

    def call(self, inputs, mask=None, **kwargs):
        input_len = K.shape(inputs)[1]

        if self.attention_type == SeqSelfAttention.ATTENTION_TYPE_ADD:
            e = self._call_additive_emission(inputs)
        elif self.attention_type == SeqSelfAttention.ATTENTION_TYPE_MUL:
            e = self._call_multiplicative_emission(inputs)

        if self.attention_activation is not None:
            e = self.attention_activation(e)
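
        # Restrict attention to a local window when `attention_width` is set:
        # positions outside [lower, upper) receive a large negative offset so
        # their weights vanish after the softmax below.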
        if self.attention_width is not None:
            if self.history_only:
                lower = K.arange(0, input_len) - (self.attention_width - 1)
            else:
                lower = K.arange(0, input_len) - self.attention_width // 2
            lower = K.expand_dims(lower, axis=-1)
            upper = lower + self.attention_width
            indices = K.expand_dims(K.arange(0, input_len), axis=0)
            e -= 10000.0 * (1.0 - K.cast(lower <= indices, K.floatx()) * K.cast(indices < upper, K.floatx()))
        if mask is not None:
            mask = K.expand_dims(K.cast(mask, K.floatx()), axis=-1)
            e -= 10000.0 * ((1.0 - mask) * (1.0 - K.permute_dimensions(mask, (0, 2, 1))))

        # a_{t} = \text{softmax}(e_t)
        e = K.exp(e - K.max(e, axis=-1, keepdims=True))
        a = e / K.sum(e, axis=-1, keepdims=True)

        # l_t = \sum_{t'} a_{t, t'} x_{t'}
        v = K.batch_dot(a, inputs)

        if self.attention_regularizer_weight > 0.0:
            self.add_loss(self._attention_regularizer(a))

        if self.return_attention:
            return [v, a]
        return v

    def _call_additive_emission(self, inputs):
        input_shape = K.shape(inputs)
        batch_size, input_len = input_shape[0], input_shape[1]

        # h_{t, t'} = \tanh(x_t^T W_t + x_{t'}^T W_x + b_h)
        q = K.expand_dims(K.dot(inputs, self.Wt), 2)
        k = K.expand_dims(K.dot(inputs, self.Wx), 1)
        if self.use_additive_bias:
            h = K.tanh(q + k + self.bh)
        else:
            h = K.tanh(q + k)

        # e_{t, t'} = W_a h_{t, t'} + b_a
        if self.use_attention_bias:
            e = K.reshape(K.dot(h, self.Wa) + self.ba, (batch_size, input_len, input_len))
        else:
            e = K.reshape(K.dot(h, self.Wa), (batch_size, input_len, input_len))
        return e

    def _call_multiplicative_emission(self, inputs):
        # e_{t, t'} = x_t^T W_a x_{t'} + b_a
        e = K.batch_dot(K.dot(inputs, self.Wa), K.permute_dimensions(inputs, (0, 2, 1)))
        if self.use_attention_bias:
            e += self.ba[0]
        return e

    def compute_output_shape(self, input_shape):
        output_shape = input_shape
        if self.return_attention:
            attention_shape = (input_shape[0], output_shape[1], input_shape[1])
            return [output_shape, attention_shape]
        return output_shape

    def compute_mask(self, inputs, mask=None):
        if self.return_attention:
            return [mask, None]
        return mask

    def _attention_regularizer(self, attention):
        # Penalize ||A * A^T - I||^2 averaged over the batch, which pushes the
        # attention map A towards (near-)orthonormal rows.
        batch_size = K.cast(K.shape(attention)[0], K.floatx())
        input_len = K.shape(attention)[-1]
        indices = K.expand_dims(K.arange(0, input_len), axis=0)
        diagonal = K.expand_dims(K.arange(0, input_len), axis=-1)
        eye = K.cast(K.equal(indices, diagonal), K.floatx())
        return self.attention_regularizer_weight * K.sum(K.square(K.batch_dot(
            attention,
            K.permute_dimensions(attention, (0, 2, 1))) - eye)) / batch_size

    @staticmethod
    def get_custom_objects():
        return {'SeqSelfAttention': SeqSelfAttention}
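

# The block below is an illustrative sketch only (not part of the original
# module): it shows one way the layer could be wired into a model and how
# `get_custom_objects()` can be passed to `keras.models.load_model` when a
# saved model is restored.  Layer sizes, names and the file path are arbitrary
# assumptions made for the demonstration.
if __name__ == '__main__':
    inputs = keras.layers.Input(shape=(None,))
    embedded = keras.layers.Embedding(input_dim=10000, output_dim=128, mask_zero=True)(inputs)
    encoded = keras.layers.Bidirectional(keras.layers.LSTM(64, return_sequences=True))(embedded)
    attended = SeqSelfAttention(units=32, attention_activation='sigmoid', name='self_attention')(encoded)
    outputs = keras.layers.Dense(units=5, activation='softmax')(attended)

    model = keras.models.Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer='adam', loss='categorical_crossentropy')
    model.summary()

    # Custom layers must be registered when reloading a saved model.
    model.save('self_attention_demo.h5')
    restored = keras.models.load_model('self_attention_demo.h5',
                                       custom_objects=SeqSelfAttention.get_custom_objects())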