import keras
from keras import backend as K


class SeqSelfAttention(keras.layers.Layer):

    ATTENTION_TYPE_ADD = 'additive'
    ATTENTION_TYPE_MUL = 'multiplicative'

    def __init__(self,
                 units=32,
                 attention_width=None,
                 attention_type=ATTENTION_TYPE_ADD,
                 return_attention=False,
                 history_only=False,
                 kernel_initializer='glorot_normal',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 use_additive_bias=True,
                 use_attention_bias=True,
                 attention_activation=None,
                 attention_regularizer_weight=0.0,
                 **kwargs):
        """Layer initialization.

        For additive attention, see: https://arxiv.org/pdf/1806.01264.pdf

        A minimal usage sketch is included at the end of this module.

        :param units: The dimension of the vectors used to calculate the attention weights.
        :param attention_width: The width of local attention.
        :param attention_type: 'additive' or 'multiplicative'.
        :param return_attention: Whether to return the attention weights for visualization.
        :param history_only: Whether to only attend to historical (current and earlier) positions.
        :param kernel_initializer: The initializer for weight matrices.
        :param bias_initializer: The initializer for biases.
        :param kernel_regularizer: The regularization for weight matrices.
        :param bias_regularizer: The regularization for biases.
        :param kernel_constraint: The constraint for weight matrices.
        :param bias_constraint: The constraint for biases.
        :param use_additive_bias: Whether to use a bias while calculating the relevance of input
                                  features in additive mode.
        :param use_attention_bias: Whether to use a bias while calculating the attention weights.
        :param attention_activation: The activation used for calculating the attention weights.
        :param attention_regularizer_weight: The weight of the attention regularizer.
        :param kwargs: Parameters for the parent class.
        """
        super(SeqSelfAttention, self).__init__(**kwargs)
        self.supports_masking = True
        self.units = units
        self.attention_width = attention_width
        self.attention_type = attention_type
        self.return_attention = return_attention
        self.history_only = history_only
        if history_only and attention_width is None:
            self.attention_width = int(1e9)

        self.use_additive_bias = use_additive_bias
        self.use_attention_bias = use_attention_bias
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.bias_initializer = keras.initializers.get(bias_initializer)
        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)
        self.bias_regularizer = keras.regularizers.get(bias_regularizer)
        self.kernel_constraint = keras.constraints.get(kernel_constraint)
        self.bias_constraint = keras.constraints.get(bias_constraint)
        self.attention_activation = keras.activations.get(attention_activation)
        self.attention_regularizer_weight = attention_regularizer_weight
        self._backend = keras.backend.backend()

        if attention_type == SeqSelfAttention.ATTENTION_TYPE_ADD:
            self.Wx, self.Wt, self.bh = None, None, None
            self.Wa, self.ba = None, None
        elif attention_type == SeqSelfAttention.ATTENTION_TYPE_MUL:
            self.Wa, self.ba = None, None
        else:
            raise NotImplementedError('No implementation for attention type: ' + attention_type)

    def get_config(self):
        config = {
            'units': self.units,
            'attention_width': self.attention_width,
            'attention_type': self.attention_type,
            'return_attention': self.return_attention,
            'history_only': self.history_only,
            'use_additive_bias': self.use_additive_bias,
            'use_attention_bias': self.use_attention_bias,
            'kernel_initializer': keras.initializers.serialize(self.kernel_initializer),
            'bias_initializer': keras.initializers.serialize(self.bias_initializer),
            'kernel_regularizer': keras.regularizers.serialize(self.kernel_regularizer),
            'bias_regularizer': keras.regularizers.serialize(self.bias_regularizer),
            'kernel_constraint': keras.constraints.serialize(self.kernel_constraint),
            'bias_constraint': keras.constraints.serialize(self.bias_constraint),
            'attention_activation': keras.activations.serialize(self.attention_activation),
            'attention_regularizer_weight': self.attention_regularizer_weight,
        }
        base_config = super(SeqSelfAttention, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def build(self, input_shape):
        if self.attention_type == SeqSelfAttention.ATTENTION_TYPE_ADD:
            self._build_additive_attention(input_shape)
        elif self.attention_type == SeqSelfAttention.ATTENTION_TYPE_MUL:
            self._build_multiplicative_attention(input_shape)
        super(SeqSelfAttention, self).build(input_shape)

    def _build_additive_attention(self, input_shape):
        feature_dim = int(input_shape[2])

        # Wt and Wx project the query and key positions into a shared
        # `units`-dimensional space; Wa maps the combined representation to a
        # scalar emission score.
        self.Wt = self.add_weight(shape=(feature_dim, self.units),
                                  name='{}_Add_Wt'.format(self.name),
                                  initializer=self.kernel_initializer,
                                  regularizer=self.kernel_regularizer,
                                  constraint=self.kernel_constraint)
        self.Wx = self.add_weight(shape=(feature_dim, self.units),
                                  name='{}_Add_Wx'.format(self.name),
                                  initializer=self.kernel_initializer,
                                  regularizer=self.kernel_regularizer,
                                  constraint=self.kernel_constraint)
        if self.use_additive_bias:
            self.bh = self.add_weight(shape=(self.units,),
                                      name='{}_Add_bh'.format(self.name),
                                      initializer=self.bias_initializer,
                                      regularizer=self.bias_regularizer,
                                      constraint=self.bias_constraint)

        self.Wa = self.add_weight(shape=(self.units, 1),
                                  name='{}_Add_Wa'.format(self.name),
                                  initializer=self.kernel_initializer,
                                  regularizer=self.kernel_regularizer,
                                  constraint=self.kernel_constraint)
        if self.use_attention_bias:
            self.ba = self.add_weight(shape=(1,),
                                      name='{}_Add_ba'.format(self.name),
                                      initializer=self.bias_initializer,
                                      regularizer=self.bias_regularizer,
                                      constraint=self.bias_constraint)

    def _build_multiplicative_attention(self, input_shape):
        feature_dim = int(input_shape[2])

        self.Wa = self.add_weight(shape=(feature_dim, feature_dim),
                                  name='{}_Mul_Wa'.format(self.name),
                                  initializer=self.kernel_initializer,
                                  regularizer=self.kernel_regularizer,
                                  constraint=self.kernel_constraint)
        if self.use_attention_bias:
            self.ba = self.add_weight(shape=(1,),
                                      name='{}_Mul_ba'.format(self.name),
                                      initializer=self.bias_initializer,
                                      regularizer=self.bias_regularizer,
                                      constraint=self.bias_constraint)

    def call(self, inputs, mask=None, **kwargs):
        input_len = K.shape(inputs)[1]

        if self.attention_type == SeqSelfAttention.ATTENTION_TYPE_ADD:
            e = self._call_additive_emission(inputs)
        elif self.attention_type == SeqSelfAttention.ATTENTION_TYPE_MUL:
            e = self._call_multiplicative_emission(inputs)

        if self.attention_activation is not None:
            e = self.attention_activation(e)
        if self.attention_width is not None:
            # Mask out emission scores that fall outside a local window of
            # `attention_width` positions around (or, with `history_only`,
            # ending at) the current time step.
            if self.history_only:
                lower = K.arange(0, input_len) - (self.attention_width - 1)
            else:
                lower = K.arange(0, input_len) - self.attention_width // 2
            lower = K.expand_dims(lower, axis=-1)
            upper = lower + self.attention_width
            indices = K.expand_dims(K.arange(0, input_len), axis=0)
            e -= 10000.0 * (1.0 - K.cast(lower <= indices, K.floatx()) * K.cast(indices < upper, K.floatx()))
        if mask is not None:
            mask = K.expand_dims(K.cast(mask, K.floatx()), axis=-1)
            e -= 10000.0 * ((1.0 - mask) * (1.0 - K.permute_dimensions(mask, (0, 2, 1))))

        # a_{t} = \text{softmax}(e_t)
        e = K.exp(e - K.max(e, axis=-1, keepdims=True))
        a = e / K.sum(e, axis=-1, keepdims=True)

        # l_t = \sum_{t'} a_{t, t'} x_{t'}
        v = K.batch_dot(a, inputs)
        if self.attention_regularizer_weight > 0.0:
            self.add_loss(self._attention_regularizer(a))

        if self.return_attention:
            return [v, a]
        return v
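
    # Illustration of the local-window masking in `call` above: with
    # attention_width=3 and history_only=False, position t keeps emission
    # scores for positions t-1, t and t+1; with history_only=True it keeps
    # t-2, t-1 and t. Scores outside the window (and at masked-out positions)
    # are pushed down by 10000.0 so they effectively vanish after the softmax.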

    def _call_additive_emission(self, inputs):
        input_shape = K.shape(inputs)
        batch_size, input_len = input_shape[0], input_shape[1]

        # h_{t, t'} = \tanh(x_t^T W_t + x_{t'}^T W_x + b_h)
        # q: (batch, input_len, 1, units), k: (batch, 1, input_len, units);
        # broadcasting the sum yields (batch, input_len, input_len, units).
        q = K.expand_dims(K.dot(inputs, self.Wt), 2)
        k = K.expand_dims(K.dot(inputs, self.Wx), 1)
        if self.use_additive_bias:
            h = K.tanh(q + k + self.bh)
        else:
            h = K.tanh(q + k)

        # e_{t, t'} = W_a h_{t, t'} + b_a
        if self.use_attention_bias:
            e = K.reshape(K.dot(h, self.Wa) + self.ba, (batch_size, input_len, input_len))
        else:
            e = K.reshape(K.dot(h, self.Wa), (batch_size, input_len, input_len))
        return e

    def _call_multiplicative_emission(self, inputs):
        # e_{t, t'} = x_t^T W_a x_{t'} + b_a
        e = K.batch_dot(K.dot(inputs, self.Wa), K.permute_dimensions(inputs, (0, 2, 1)))
        if self.use_attention_bias:
            e += self.ba[0]
        return e

    def compute_output_shape(self, input_shape):
        output_shape = input_shape
        if self.return_attention:
            attention_shape = (input_shape[0], output_shape[1], input_shape[1])
            return [output_shape, attention_shape]
        return output_shape

    def compute_mask(self, inputs, mask=None):
        if self.return_attention:
            return [mask, None]
        return mask

    def _attention_regularizer(self, attention):
        # Frobenius penalty ||A A^T - I||^2, averaged over the batch, which
        # discourages different positions from attending to the same places.
        batch_size = K.cast(K.shape(attention)[0], K.floatx())
        input_len = K.shape(attention)[-1]
        indices = K.expand_dims(K.arange(0, input_len), axis=0)
        diagonal = K.expand_dims(K.arange(0, input_len), axis=-1)
        eye = K.cast(K.equal(indices, diagonal), K.floatx())
        return self.attention_regularizer_weight * K.sum(K.square(K.batch_dot(
            attention,
            K.permute_dimensions(attention, (0, 2, 1))) - eye)) / batch_size

    @staticmethod
    def get_custom_objects():
        return {'SeqSelfAttention': SeqSelfAttention}
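

# ----------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the layer): a minimal example of
# wiring SeqSelfAttention into a sequence-tagging model with the standalone
# `keras` package. The vocabulary size, layer sizes, number of classes and the
# random training data below are hypothetical placeholders.
# ----------------------------------------------------------------------------
if __name__ == '__main__':
    import numpy as np

    model = keras.models.Sequential()
    model.add(keras.layers.Embedding(input_dim=1000, output_dim=64, mask_zero=True))
    model.add(keras.layers.Bidirectional(keras.layers.LSTM(units=64, return_sequences=True)))
    # Local additive attention over a window of 15 time steps; the output keeps
    # the (batch, time, features) shape, so any sequence head can follow it.
    model.add(SeqSelfAttention(attention_width=15,
                               attention_activation='sigmoid',
                               attention_regularizer_weight=1e-4,
                               name='self_attention'))
    model.add(keras.layers.TimeDistributed(keras.layers.Dense(units=5, activation='softmax')))
    model.compile(optimizer='adam', loss='categorical_crossentropy')
    model.summary()

    # Train briefly on random placeholder data just to exercise the shapes.
    x = np.random.randint(0, 1000, size=(32, 20))
    y = np.random.random(size=(32, 20, 5))
    y /= y.sum(axis=-1, keepdims=True)
    model.fit(x, y, epochs=1, verbose=0)

    # `get_custom_objects` resolves the custom layer when a saved model is
    # restored with `keras.models.load_model`.
    model.save('model_with_attention.h5')
    restored = keras.models.load_model('model_with_attention.h5',
                                       custom_objects=SeqSelfAttention.get_custom_objects())

    # With `return_attention=True` the layer returns `[output, attention]`,
    # where `attention` has shape (batch, time, time) and can be visualized.
    inputs = keras.layers.Input(shape=(None,))
    embedded = keras.layers.Embedding(input_dim=1000, output_dim=64, mask_zero=True)(inputs)
    outputs, attention = SeqSelfAttention(return_attention=True, name='vis_attention')(embedded)
    attention_model = keras.models.Model(inputs, [outputs, attention])
    attention_maps = attention_model.predict(x)[1]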