How to add an attention mechanism in Keras?

I'm currently using this code, which I got from a discussion on GitHub. Here's the code of the attention mechanism:

_input = Input(shape=[max_length], dtype='int32')

# get the embedding layer
embedded = Embedding(
    input_dim=vocab_size,
    output_dim=embedding_size,
    input_length=max_length,
    trainable=False,
    mask_zero=False
)(_input)

activations = LSTM(units, return_sequences=True)(embedded)

# compute importance for each step
attention = Dense(1, activation='tanh')(activations)
attention = Flatten()(attention)
attention = Activation('softmax')(attention)
attention = RepeatVector(units)(attention)
attention = Permute([2, 1])(attention)


sent_representation = merge([activations, attention], mode='mul')
sent_representation = Lambda(lambda xin: K.sum(xin, axis=-2), output_shape=(units,))(sent_representation)

probabilities = Dense(3, activation='softmax')(sent_representation)

Is this the correct way to do it? I was sort of expecting a TimeDistributed layer, since the attention mechanism is applied at every time step of the RNN. I need someone to confirm that this code is a correct implementation of an attention mechanism. Thank you.


Solution 1:

If you want attention along the time dimension, then this part of your code seems correct to me:

activations = LSTM(units, return_sequences=True)(embedded)

# compute importance for each step
attention = Dense(1, activation='tanh')(activations)
attention = Flatten()(attention)
attention = Activation('softmax')(attention)
attention = RepeatVector(units)(attention)
attention = Permute([2, 1])(attention)

sent_representation = merge([activations, attention], mode='mul')

You've worked out the attention vector of shape (batch_size, max_length):

attention = Activation('softmax')(attention)
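To see what the two layers after the softmax are for, here is a small NumPy sketch. NumPy only stands in for Keras here, and the sizes and weight values are made up for illustration. RepeatVector(units) copies the weight vector units times, and Permute([2, 1]) swaps the last two axes, so the result has the same shape as activations and can be multiplied with it element-wise.

```python
import numpy as np

max_length, units = 3, 2  # illustrative sizes, not taken from the question

# Attention weights after the softmax: shape (batch_size, max_length), rows sum to 1.
attention = np.array([[0.1, 0.2, 0.7],
                      [0.5, 0.25, 0.25]])

# RepeatVector(units): (batch_size, max_length) -> (batch_size, units, max_length)
repeated = np.repeat(attention[:, np.newaxis, :], units, axis=1)

# Permute([2, 1]): swap the two non-batch axes -> (batch_size, max_length, units)
permuted = repeated.transpose(0, 2, 1)

print(repeated.shape)  # (2, 2, 3)
print(permuted.shape)  # (2, 3, 2)
```

Every column of permuted along the units axis holds the same weights, so each hidden state gets scaled by a single number per time step.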

I've never seen this code before, so I can't say whether it is actually correct:

K.sum(xin, axis=-2)
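The shapes, at least, can be checked outside Keras. Below is a minimal NumPy sketch, assuming the tensor passed to the Lambda has shape (batch_size, max_length, units); the sizes and random values are placeholders. Under that assumption axis=-2 is the time axis, so the sum collapses the time steps into one weighted sum of hidden states per sample.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, max_length, units = 2, 4, 3  # illustrative sizes

# Stand-in for the LSTM output with return_sequences=True.
activations = rng.normal(size=(batch_size, max_length, units))

# Stand-in for the attention weights: non-negative, each row sums to 1.
weights = rng.random(size=(batch_size, max_length))
weights /= weights.sum(axis=1, keepdims=True)

# Equivalent of the 'mul' merge: scale every hidden state by its weight.
weighted = activations * weights[:, :, np.newaxis]   # (batch_size, max_length, units)

# Equivalent of K.sum(xin, axis=-2): sum over the time axis.
sent_representation = weighted.sum(axis=-2)          # (batch_size, units)

print(sent_representation.shape)  # (2, 3)
```

The result matches output_shape=(units,) in the Lambda, which is consistent with the line being a weighted sum over time steps.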

Further reading:

  • https://github.com/philipperemy/keras-visualize-activations

  • https://github.com/philipperemy/keras-attention-mechanism

Solution 2:

An attention mechanism pays attention to different parts of the sentence:

activations = LSTM(units, return_sequences=True)(embedded)

And it determines the contribution of each hidden state of that sentence by

  1. Computing the aggregation of each hidden state: attention = Dense(1, activation='tanh')(activations)
  2. Assigning weights to the different states: attention = Activation('softmax')(attention)
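As a rough illustration of these two steps, here is a NumPy sketch (not Keras; the sizes and the kernel and bias values are made-up stand-ins for what Dense(1) would learn). In Keras 2, Dense applied to a 3D input operates on the last axis, so the same weights score every time step, which is likely why no separate TimeDistributed wrapper appears in the code.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, max_length, units = 2, 4, 3  # illustrative sizes

# Stand-in for the LSTM output with return_sequences=True.
activations = rng.normal(size=(batch_size, max_length, units))

# Hypothetical Dense(1) parameters: one kernel and one bias shared by all time steps.
kernel = rng.normal(size=(units, 1))
bias = np.zeros(1)

# Step 1: Dense(1, activation='tanh') gives one score per time step.
scores = np.tanh(activations @ kernel + bias)     # (batch_size, max_length, 1)
scores = scores.reshape(batch_size, max_length)   # what Flatten() does here

# Step 2: softmax across the time axis turns the scores into weights.
exp_scores = np.exp(scores - scores.max(axis=1, keepdims=True))
weights = exp_scores / exp_scores.sum(axis=1, keepdims=True)

print(weights.shape)  # (2, 4)
```

Each row of weights sums to 1, so the weights can be read as the relative contribution of each hidden state.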

And finally, it pays attention to the different states:

sent_representation = merge([activations, attention], mode='mul')

I don't quite understand this part: sent_representation = Lambda(lambda xin: K.sum(xin, axis=-2), output_shape=(units,))(sent_representation)

To understand more, you can refer to this and this; this one also gives a good implementation. See if you can understand more on your own.