@@ -0,0 +1,240 @@
+import keras
+from keras import backend as K
+
+
+class SeqSelfAttention(keras.layers.Layer):
+
+    ATTENTION_TYPE_ADD = 'additive'
+    ATTENTION_TYPE_MUL = 'multiplicative'
+
+    def __init__(self,
+                 units=32,
+                 attention_width=None,
+                 attention_type=ATTENTION_TYPE_ADD,
+                 return_attention=False,
+                 history_only=False,
+                 kernel_initializer='glorot_normal',
+                 bias_initializer='zeros',
+                 kernel_regularizer=None,
+                 bias_regularizer=None,
+                 kernel_constraint=None,
+                 bias_constraint=None,
+                 use_additive_bias=True,
+                 use_attention_bias=True,
+                 attention_activation=None,
+                 attention_regularizer_weight=0.0,
+                 **kwargs):
+        """Layer initialization.
+
+        For additive attention, see: https://arxiv.org/pdf/1806.01264.pdf
+
+        :param units: The dimension of the vectors that are used to calculate the attention weights.
+        :param attention_width: The width of local attention.
+        :param attention_type: 'additive' or 'multiplicative'.
+        :param return_attention: Whether to return the attention weights for visualization.
+        :param history_only: Whether to only attend to historical (preceding) positions.
+        :param kernel_initializer: The initializer for weight matrices.
+        :param bias_initializer: The initializer for biases.
+        :param kernel_regularizer: The regularization for weight matrices.
+        :param bias_regularizer: The regularization for biases.
+        :param kernel_constraint: The constraint for weight matrices.
+        :param bias_constraint: The constraint for biases.
+        :param use_additive_bias: Whether to use bias while calculating the relevance of input features
+                                  in additive mode.
+        :param use_attention_bias: Whether to use bias while calculating the weights of attention.
+        :param attention_activation: The activation used for calculating the weights of attention.
+        :param attention_regularizer_weight: The weight of the attention regularizer.
+        :param kwargs: Parameters for the parent class.
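+
+        A minimal usage sketch (illustrative only; the surrounding model, layer
+        sizes, and vocabulary size are assumptions, not part of this layer):
+
+            model = keras.models.Sequential()
+            model.add(keras.layers.Embedding(input_dim=10000, output_dim=128, mask_zero=True))
+            model.add(keras.layers.Bidirectional(keras.layers.LSTM(64, return_sequences=True)))
+            model.add(SeqSelfAttention(attention_activation='sigmoid'))
+            model.add(keras.layers.LSTM(32))
+            model.add(keras.layers.Dense(1, activation='sigmoid'))
+            model.compile(optimizer='adam', loss='binary_crossentropy')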
+        """
+        super(SeqSelfAttention, self).__init__(**kwargs)
+        self.supports_masking = True
+        self.units = units
+        self.attention_width = attention_width
+        self.attention_type = attention_type
+        self.return_attention = return_attention
+        self.history_only = history_only
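+        # With history_only and no explicit width, fall back to an effectively
+        # unbounded window; the causal constraint alone limits what is attended to.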
+        if history_only and attention_width is None:
+            self.attention_width = int(1e9)
+
+        self.use_additive_bias = use_additive_bias
+        self.use_attention_bias = use_attention_bias
+        self.kernel_initializer = keras.initializers.get(kernel_initializer)
+        self.bias_initializer = keras.initializers.get(bias_initializer)
+        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)
+        self.bias_regularizer = keras.regularizers.get(bias_regularizer)
+        self.kernel_constraint = keras.constraints.get(kernel_constraint)
+        self.bias_constraint = keras.constraints.get(bias_constraint)
+        self.attention_activation = keras.activations.get(attention_activation)
+        self.attention_regularizer_weight = attention_regularizer_weight
+        self._backend = keras.backend.backend()
+
+        if attention_type == SeqSelfAttention.ATTENTION_TYPE_ADD:
+            self.Wx, self.Wt, self.bh = None, None, None
+            self.Wa, self.ba = None, None
+        elif attention_type == SeqSelfAttention.ATTENTION_TYPE_MUL:
+            self.Wa, self.ba = None, None
+        else:
+            raise NotImplementedError('No implementation for attention type: ' + attention_type)
+
+    def get_config(self):
+        config = {
+            'units': self.units,
+            'attention_width': self.attention_width,
+            'attention_type': self.attention_type,
+            'return_attention': self.return_attention,
+            'history_only': self.history_only,
+            'use_additive_bias': self.use_additive_bias,
+            'use_attention_bias': self.use_attention_bias,
+            'kernel_initializer': keras.initializers.serialize(self.kernel_initializer),
+            'bias_initializer': keras.initializers.serialize(self.bias_initializer),
+            'kernel_regularizer': keras.regularizers.serialize(self.kernel_regularizer),
+            'bias_regularizer': keras.regularizers.serialize(self.bias_regularizer),
+            'kernel_constraint': keras.constraints.serialize(self.kernel_constraint),
+            'bias_constraint': keras.constraints.serialize(self.bias_constraint),
+            'attention_activation': keras.activations.serialize(self.attention_activation),
+            'attention_regularizer_weight': self.attention_regularizer_weight,
+        }
+        base_config = super(SeqSelfAttention, self).get_config()
+        return dict(list(base_config.items()) + list(config.items()))
+
+    def build(self, input_shape):
+        if self.attention_type == SeqSelfAttention.ATTENTION_TYPE_ADD:
+            self._build_additive_attention(input_shape)
+        elif self.attention_type == SeqSelfAttention.ATTENTION_TYPE_MUL:
+            self._build_multiplicative_attention(input_shape)
+        super(SeqSelfAttention, self).build(input_shape)
+
+    def _build_additive_attention(self, input_shape):
+        feature_dim = int(input_shape[2])
+
+        self.Wt = self.add_weight(shape=(feature_dim, self.units),
+                                  name='{}_Add_Wt'.format(self.name),
+                                  initializer=self.kernel_initializer,
+                                  regularizer=self.kernel_regularizer,
+                                  constraint=self.kernel_constraint)
+        self.Wx = self.add_weight(shape=(feature_dim, self.units),
+                                  name='{}_Add_Wx'.format(self.name),
+                                  initializer=self.kernel_initializer,
+                                  regularizer=self.kernel_regularizer,
+                                  constraint=self.kernel_constraint)
+        if self.use_additive_bias:
+            self.bh = self.add_weight(shape=(self.units,),
+                                      name='{}_Add_bh'.format(self.name),
+                                      initializer=self.bias_initializer,
+                                      regularizer=self.bias_regularizer,
+                                      constraint=self.bias_constraint)
+
+        self.Wa = self.add_weight(shape=(self.units, 1),
+                                  name='{}_Add_Wa'.format(self.name),
+                                  initializer=self.kernel_initializer,
+                                  regularizer=self.kernel_regularizer,
+                                  constraint=self.kernel_constraint)
+        if self.use_attention_bias:
+            self.ba = self.add_weight(shape=(1,),
+                                      name='{}_Add_ba'.format(self.name),
+                                      initializer=self.bias_initializer,
+                                      regularizer=self.bias_regularizer,
+                                      constraint=self.bias_constraint)
+
+    def _build_multiplicative_attention(self, input_shape):
+        feature_dim = int(input_shape[2])
+
+        self.Wa = self.add_weight(shape=(feature_dim, feature_dim),
+                                  name='{}_Mul_Wa'.format(self.name),
+                                  initializer=self.kernel_initializer,
+                                  regularizer=self.kernel_regularizer,
+                                  constraint=self.kernel_constraint)
+        if self.use_attention_bias:
+            self.ba = self.add_weight(shape=(1,),
+                                      name='{}_Mul_ba'.format(self.name),
+                                      initializer=self.bias_initializer,
+                                      regularizer=self.bias_regularizer,
+                                      constraint=self.bias_constraint)
+
+    def call(self, inputs, mask=None, **kwargs):
+        input_len = K.shape(inputs)[1]
+
+        if self.attention_type == SeqSelfAttention.ATTENTION_TYPE_ADD:
+            e = self._call_additive_emission(inputs)
+        elif self.attention_type == SeqSelfAttention.ATTENTION_TYPE_MUL:
+            e = self._call_multiplicative_emission(inputs)
+
+        if self.attention_activation is not None:
+            e = self.attention_activation(e)
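+        # Local attention: for each query position, only keys inside the window
+        # [lower, lower + attention_width) are kept; everything else receives a
+        # large negative offset so it vanishes after the softmax below.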
+        if self.attention_width is not None:
+            if self.history_only:
+                lower = K.arange(0, input_len) - (self.attention_width - 1)
+            else:
+                lower = K.arange(0, input_len) - self.attention_width // 2
+            lower = K.expand_dims(lower, axis=-1)
+            upper = lower + self.attention_width
+            indices = K.expand_dims(K.arange(0, input_len), axis=0)
+            e -= 10000.0 * (1.0 - K.cast(lower <= indices, K.floatx()) * K.cast(indices < upper, K.floatx()))
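+        # Padded positions from the incoming mask are suppressed the same way.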
+        if mask is not None:
+            mask = K.expand_dims(K.cast(mask, K.floatx()), axis=-1)
+            e -= 10000.0 * ((1.0 - mask) * (1.0 - K.permute_dimensions(mask, (0, 2, 1))))
+
+        # a_{t} = \text{softmax}(e_t)
+        e = K.exp(e - K.max(e, axis=-1, keepdims=True))
+        a = e / K.sum(e, axis=-1, keepdims=True)
+
+        # l_t = \sum_{t'} a_{t, t'} x_{t'}
+        v = K.batch_dot(a, inputs)
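+        # Optionally add the orthogonality penalty on the attention map as a layer loss.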
+        if self.attention_regularizer_weight > 0.0:
+            self.add_loss(self._attention_regularizer(a))
+
+        if self.return_attention:
+            return [v, a]
+        return v
+
+    def _call_additive_emission(self, inputs):
+        input_shape = K.shape(inputs)
+        batch_size, input_len = input_shape[0], input_shape[1]
+
+        # h_{t, t'} = \tanh(x_t^T W_t + x_{t'}^T W_x + b_h)
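+        # q has shape (batch, len, 1, units) and k has shape (batch, 1, len, units),
+        # so broadcasting the sum yields one hidden vector per (t, t') pair.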
+        q = K.expand_dims(K.dot(inputs, self.Wt), 2)
+        k = K.expand_dims(K.dot(inputs, self.Wx), 1)
+        if self.use_additive_bias:
+            h = K.tanh(q + k + self.bh)
+        else:
+            h = K.tanh(q + k)
+
+        # e_{t, t'} = W_a h_{t, t'} + b_a
+        if self.use_attention_bias:
+            e = K.reshape(K.dot(h, self.Wa) + self.ba, (batch_size, input_len, input_len))
+        else:
+            e = K.reshape(K.dot(h, self.Wa), (batch_size, input_len, input_len))
+        return e
+
+    def _call_multiplicative_emission(self, inputs):
+        # e_{t, t'} = x_t^T W_a x_{t'} + b_a
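+        # (batch, len, dim) dotted with (dim, dim), then batch-dotted with the
+        # transposed inputs (batch, dim, len), giving emissions of shape (batch, len, len).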
+        e = K.batch_dot(K.dot(inputs, self.Wa), K.permute_dimensions(inputs, (0, 2, 1)))
+        if self.use_attention_bias:
+            e += self.ba[0]
+        return e
+
+    def compute_output_shape(self, input_shape):
+        output_shape = input_shape
+        if self.return_attention:
+            attention_shape = (input_shape[0], output_shape[1], input_shape[1])
+            return [output_shape, attention_shape]
+        return output_shape
+
+    def compute_mask(self, inputs, mask=None):
+        if self.return_attention:
+            return [mask, None]
+        return mask
+
+    def _attention_regularizer(self, attention):
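+        # Penalizes the squared Frobenius distance between A A^T and the identity,
+        # averaged over the batch, pushing attention rows towards orthogonality.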
+        batch_size = K.cast(K.shape(attention)[0], K.floatx())
+        input_len = K.shape(attention)[-1]
+        indices = K.expand_dims(K.arange(0, input_len), axis=0)
+        diagonal = K.expand_dims(K.arange(0, input_len), axis=-1)
+        eye = K.cast(K.equal(indices, diagonal), K.floatx())
+        return self.attention_regularizer_weight * K.sum(K.square(K.batch_dot(
+            attention,
+            K.permute_dimensions(attention, (0, 2, 1))) - eye)) / batch_size
+
+    @staticmethod
+    def get_custom_objects():
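+        # Convenience mapping to pass as `custom_objects` when loading a saved model.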
+        return {'SeqSelfAttention': SeqSelfAttention}