# self_attention.py

import keras
from keras import backend as K
from keras.layers import Layer


class SeqSelfAttention(keras.layers.Layer):

    ATTENTION_TYPE_ADD = 'additive'
    ATTENTION_TYPE_MUL = 'multiplicative'
    def __init__(self,
                 units=32,
                 attention_width=None,
                 attention_type=ATTENTION_TYPE_ADD,
                 return_attention=False,
                 history_only=False,
                 kernel_initializer='glorot_normal',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 use_additive_bias=True,
                 use_attention_bias=True,
                 attention_activation=None,
                 attention_regularizer_weight=0.0,
                 **kwargs):
  24. """Layer initialization.
  25. For additive attention, see: https://arxiv.org/pdf/1806.01264.pdf
  26. :param units: The dimension of the vectors that used to calculate the attention weights.
  27. :param attention_width: The width of local attention.
  28. :param attention_type: 'additive' or 'multiplicative'.
  29. :param return_attention: Whether to return the attention weights for visualization.
  30. :param history_only: Only use historical pieces of data.
  31. :param kernel_initializer: The initializer for weight matrices.
  32. :param bias_initializer: The initializer for biases.
  33. :param kernel_regularizer: The regularization for weight matrices.
  34. :param bias_regularizer: The regularization for biases.
  35. :param kernel_constraint: The constraint for weight matrices.
  36. :param bias_constraint: The constraint for biases.
  37. :param use_additive_bias: Whether to use bias while calculating the relevance of inputs features
  38. in additive mode.
  39. :param use_attention_bias: Whether to use bias while calculating the weights of attention.
  40. :param attention_activation: The activation used for calculating the weights of attention.
  41. :param attention_regularizer_weight: The weights of attention regularizer.
  42. :param kwargs: Parameters for parent class.
  43. """
        super(SeqSelfAttention, self).__init__(**kwargs)
        self.supports_masking = True
        self.units = units
        self.attention_width = attention_width
        self.attention_type = attention_type
        self.return_attention = return_attention
        self.history_only = history_only
        if history_only and attention_width is None:
            self.attention_width = int(1e9)
        self.use_additive_bias = use_additive_bias
        self.use_attention_bias = use_attention_bias
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.bias_initializer = keras.initializers.get(bias_initializer)
        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)
        self.bias_regularizer = keras.regularizers.get(bias_regularizer)
        self.kernel_constraint = keras.constraints.get(kernel_constraint)
        self.bias_constraint = keras.constraints.get(bias_constraint)
        self.attention_activation = keras.activations.get(attention_activation)
        self.attention_regularizer_weight = attention_regularizer_weight
        self._backend = keras.backend.backend()

        if attention_type == SeqSelfAttention.ATTENTION_TYPE_ADD:
            self.Wx, self.Wt, self.bh = None, None, None
            self.Wa, self.ba = None, None
        elif attention_type == SeqSelfAttention.ATTENTION_TYPE_MUL:
            self.Wa, self.ba = None, None
        else:
            raise NotImplementedError('No implementation for attention type : ' + attention_type)
    def get_config(self):
        config = {
            'units': self.units,
            'attention_width': self.attention_width,
            'attention_type': self.attention_type,
            'return_attention': self.return_attention,
            'history_only': self.history_only,
            'use_additive_bias': self.use_additive_bias,
            'use_attention_bias': self.use_attention_bias,
            'kernel_initializer': keras.initializers.serialize(self.kernel_initializer),
            'bias_initializer': keras.initializers.serialize(self.bias_initializer),
            'kernel_regularizer': keras.regularizers.serialize(self.kernel_regularizer),
            'bias_regularizer': keras.regularizers.serialize(self.bias_regularizer),
            'kernel_constraint': keras.constraints.serialize(self.kernel_constraint),
            'bias_constraint': keras.constraints.serialize(self.bias_constraint),
            'attention_activation': keras.activations.serialize(self.attention_activation),
            'attention_regularizer_weight': self.attention_regularizer_weight,
        }
        base_config = super(SeqSelfAttention, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
    def build(self, input_shape):
        if self.attention_type == SeqSelfAttention.ATTENTION_TYPE_ADD:
            self._build_additive_attention(input_shape)
        elif self.attention_type == SeqSelfAttention.ATTENTION_TYPE_MUL:
            self._build_multiplicative_attention(input_shape)
        super(SeqSelfAttention, self).build(input_shape)
    def _build_additive_attention(self, input_shape):
        feature_dim = int(input_shape[2])

        self.Wt = self.add_weight(shape=(feature_dim, self.units),
                                  name='{}_Add_Wt'.format(self.name),
                                  initializer=self.kernel_initializer,
                                  regularizer=self.kernel_regularizer,
                                  constraint=self.kernel_constraint)
        self.Wx = self.add_weight(shape=(feature_dim, self.units),
                                  name='{}_Add_Wx'.format(self.name),
                                  initializer=self.kernel_initializer,
                                  regularizer=self.kernel_regularizer,
                                  constraint=self.kernel_constraint)
        if self.use_additive_bias:
            self.bh = self.add_weight(shape=(self.units,),
                                      name='{}_Add_bh'.format(self.name),
                                      initializer=self.bias_initializer,
                                      regularizer=self.bias_regularizer,
                                      constraint=self.bias_constraint)

        self.Wa = self.add_weight(shape=(self.units, 1),
                                  name='{}_Add_Wa'.format(self.name),
                                  initializer=self.kernel_initializer,
                                  regularizer=self.kernel_regularizer,
                                  constraint=self.kernel_constraint)
        if self.use_attention_bias:
            self.ba = self.add_weight(shape=(1,),
                                      name='{}_Add_ba'.format(self.name),
                                      initializer=self.bias_initializer,
                                      regularizer=self.bias_regularizer,
                                      constraint=self.bias_constraint)
    def _build_multiplicative_attention(self, input_shape):
        feature_dim = int(input_shape[2])

        self.Wa = self.add_weight(shape=(feature_dim, feature_dim),
                                  name='{}_Mul_Wa'.format(self.name),
                                  initializer=self.kernel_initializer,
                                  regularizer=self.kernel_regularizer,
                                  constraint=self.kernel_constraint)
        if self.use_attention_bias:
            self.ba = self.add_weight(shape=(1,),
                                      name='{}_Mul_ba'.format(self.name),
                                      initializer=self.bias_initializer,
                                      regularizer=self.bias_regularizer,
                                      constraint=self.bias_constraint)
    def call(self, inputs, mask=None, **kwargs):
        input_len = K.shape(inputs)[1]

        if self.attention_type == SeqSelfAttention.ATTENTION_TYPE_ADD:
            e = self._call_additive_emission(inputs)
        elif self.attention_type == SeqSelfAttention.ATTENTION_TYPE_MUL:
            e = self._call_multiplicative_emission(inputs)

        if self.attention_activation is not None:
            e = self.attention_activation(e)
        if self.attention_width is not None:
            # Restrict attention to a local window around (or, with history_only, before) each step.
            if self.history_only:
                lower = K.arange(0, input_len) - (self.attention_width - 1)
            else:
                lower = K.arange(0, input_len) - self.attention_width // 2
            lower = K.expand_dims(lower, axis=-1)
            upper = lower + self.attention_width
            indices = K.expand_dims(K.arange(0, input_len), axis=0)
            e -= 10000.0 * (1.0 - K.cast(lower <= indices, K.floatx()) * K.cast(indices < upper, K.floatx()))
        if mask is not None:
            # Add a large negative bias where both the query and key time steps are masked.
            mask = K.expand_dims(K.cast(mask, K.floatx()), axis=-1)
            e -= 10000.0 * ((1.0 - mask) * (1.0 - K.permute_dimensions(mask, (0, 2, 1))))

        # a_{t} = \text{softmax}(e_t)
        e = K.exp(e - K.max(e, axis=-1, keepdims=True))
        a = e / K.sum(e, axis=-1, keepdims=True)

        # l_t = \sum_{t'} a_{t, t'} x_{t'}
        v = K.batch_dot(a, inputs)
        if self.attention_regularizer_weight > 0.0:
            self.add_loss(self._attention_regularizer(a))

        if self.return_attention:
            return [v, a]
        return v
    def _call_additive_emission(self, inputs):
        input_shape = K.shape(inputs)
        batch_size, input_len = input_shape[0], input_shape[1]

        # h_{t, t'} = \tanh(x_t^T W_t + x_{t'}^T W_x + b_h)
        q = K.expand_dims(K.dot(inputs, self.Wt), 2)
        k = K.expand_dims(K.dot(inputs, self.Wx), 1)
        if self.use_additive_bias:
            h = K.tanh(q + k + self.bh)
        else:
            h = K.tanh(q + k)

        # e_{t, t'} = W_a h_{t, t'} + b_a
        if self.use_attention_bias:
            e = K.reshape(K.dot(h, self.Wa) + self.ba, (batch_size, input_len, input_len))
        else:
            e = K.reshape(K.dot(h, self.Wa), (batch_size, input_len, input_len))
        return e
    def _call_multiplicative_emission(self, inputs):
        # e_{t, t'} = x_t^T W_a x_{t'} + b_a
        e = K.batch_dot(K.dot(inputs, self.Wa), K.permute_dimensions(inputs, (0, 2, 1)))
        if self.use_attention_bias:
            e += self.ba[0]
        return e
    def compute_output_shape(self, input_shape):
        output_shape = input_shape
        if self.return_attention:
            attention_shape = (input_shape[0], output_shape[1], input_shape[1])
            return [output_shape, attention_shape]
        return output_shape
    def compute_mask(self, inputs, mask=None):
        if self.return_attention:
            return [mask, None]
        return mask
    def _attention_regularizer(self, attention):
        batch_size = K.cast(K.shape(attention)[0], K.floatx())
        input_len = K.shape(attention)[-1]
        indices = K.expand_dims(K.arange(0, input_len), axis=0)
        diagonal = K.expand_dims(K.arange(0, input_len), axis=-1)
        eye = K.cast(K.equal(indices, diagonal), K.floatx())
        return self.attention_regularizer_weight * K.sum(K.square(K.batch_dot(
            attention,
            K.permute_dimensions(attention, (0, 2, 1))) - eye)) / batch_size
    @staticmethod
    def get_custom_objects():
        return {'SeqSelfAttention': SeqSelfAttention}
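
# Usage sketch (an assumption added for illustration, not part of the original layer):
# a helper showing how SeqSelfAttention is typically stacked on top of a recurrent
# encoder. The function name and all sizes below (vocab_size, embed_dim, lstm_units)
# are hypothetical placeholders.
def build_example_model(vocab_size=10000, embed_dim=128, lstm_units=64):
    """Toy sequence classifier wired with SeqSelfAttention (illustrative only)."""
    model = keras.models.Sequential()
    model.add(keras.layers.Embedding(input_dim=vocab_size, output_dim=embed_dim))
    model.add(keras.layers.Bidirectional(keras.layers.LSTM(lstm_units, return_sequences=True)))
    # Additive self-attention over the encoder outputs. With return_attention=True the
    # layer returns [output, attention], which requires the functional API instead of
    # Sequential.
    model.add(SeqSelfAttention(units=32, attention_activation='sigmoid'))
    model.add(keras.layers.LSTM(lstm_units))
    model.add(keras.layers.Dense(1, activation='sigmoid'))
    model.compile(optimizer='adam', loss='binary_crossentropy')
    return model

# When reloading a saved model that contains this layer, register it explicitly:
#   keras.models.load_model(path, custom_objects=SeqSelfAttention.get_custom_objects())
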
class MySelfAttention(Layer):
    """Scaled dot-product self-attention with learned query/key/value projections."""

    def __init__(self, output_dim, **kwargs):
        self.output_dim = output_dim
        super(MySelfAttention, self).__init__(**kwargs)

    def build(self, input_shape):
        # input_shape = (batch_size, time_steps, input_dim)
        self.W_Q = self.add_weight(name='W_Q',
                                   shape=(input_shape[2], self.output_dim),
                                   initializer='uniform',
                                   trainable=True)
        self.W_K = self.add_weight(name='W_K',
                                   shape=(input_shape[2], self.output_dim),
                                   initializer='uniform',
                                   trainable=True)
        self.W_V = self.add_weight(name='W_V',
                                   shape=(input_shape[2], self.output_dim),
                                   initializer='uniform',
                                   trainable=True)
        super(MySelfAttention, self).build(input_shape)
    def call(self, x, mask=None, **kwargs):
        _Q = K.dot(x, self.W_Q)
        _K = K.dot(x, self.W_K)
        _V = K.dot(x, self.W_V)
        # Use batch_dot with permute_dimensions in place of a plain transpose.
        _Z = K.batch_dot(_Q, K.permute_dimensions(_K, [0, 2, 1]))
        # Scale by sqrt(d_k) before the softmax, as in scaled dot-product attention.
        _Z = _Z / (self.output_dim ** 0.5)
        _Z = K.softmax(_Z)
        _Z = K.batch_dot(_Z, _V)
        return _Z
    def compute_output_shape(self, input_shape):
        return input_shape[0], input_shape[1], self.output_dim
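
# A small shape check (added for illustration, not part of the original source):
# pushes random data through MySelfAttention with the functional API. All shapes
# and sizes here are arbitrary assumptions.
if __name__ == '__main__':
    import numpy as np

    inp = keras.layers.Input(shape=(10, 16))   # (time_steps, input_dim)
    out = MySelfAttention(output_dim=8)(inp)   # projects each step to 8 dimensions
    model = keras.models.Model(inputs=inp, outputs=out)
    # Expected output shape: (2, 10, 8)
    print(model.predict(np.random.rand(2, 10, 16)).shape)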