TensorFlow: How to extract attention_scores for graphing?

If you have a MultiHeadAttention layer in Keras, it can return attention scores like so:

    x, attention_scores = MultiHeadAttention(1, 10, 10)(x, x, return_attention_scores=True)

How do you extract the attention scores from the network graph? I would like to graph them.


Option 1: If you want to plot the attention scores during training, you can create a Callback and pass your test data to it. It can be triggered, for example, at the end of every epoch. Here is an example that uses 2 attention heads and plots their scores after every epoch:

    import tensorflow as tf
    import seaborn as sb
    import matplotlib.pyplot as plt

    class CustomCallback(tf.keras.callbacks.Callback):
        def __init__(self, data):
            super().__init__()
            self.data = data

        def on_epoch_end(self, epoch, logs=None):
            test_targets, test_sources = self.data
            # Run a single test sample through the attention layer defined below.
            _, attention_scores = attention_layer(test_targets[:1], test_sources[:1],
                                                  return_attention_scores=True)
            # One heatmap per attention head, plus a shared colorbar.
            fig, axs = plt.subplots(ncols=3, gridspec_kw=dict(width_ratios=[5, 5, 0.2]))
            sb.heatmap(attention_scores[0, 0, :, :], annot=True, cbar=False, ax=axs[0])
            sb.heatmap(attention_scores[0, 1, :, :], annot=True, yticklabels=False, cbar=False, ax=axs[1])
            fig.colorbar(axs[1].collections[0], cax=axs[2])
            plt.show()

    layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
    target = tf.keras.layers.Input(shape=[8, 16])
    source = tf.keras.layers.Input(shape=[4, 16])
    output_tensor, weights = layer(target, source,
                                   return_attention_scores=True)
    output = tf.keras.layers.Flatten()(output_tensor)
    output = tf.keras.layers.Dense(1, activation='sigmoid')(output)

    model = tf.keras.Model([target, source], output)
    model.compile(optimizer='adam', loss=tf.keras.losses.BinaryCrossentropy())

    attention_layer = model.layers[2]  # the MultiHeadAttention layer (layers 0 and 1 are the Inputs)
    samples = 5
    train_targets = tf.random.normal((samples, 8, 16))
    train_sources = tf.random.normal((samples, 4, 16))
    test_targets = tf.random.normal((samples, 8, 16))
    test_sources = tf.random.normal((samples, 4, 16))
    y = tf.random.uniform((samples,), maxval=2, dtype=tf.int32)

    model.fit([train_targets, train_sources], y, batch_size=2, epochs=2,
              callbacks=[CustomCallback([test_targets, test_sources])])
    Epoch 1/2
    1/3 [=========>....................] - ETA: 2s - loss: 0.7142

[attention heatmaps for the two heads after epoch 1]

    3/3 [==============================] - 3s 649ms/step - loss: 0.6992
    Epoch 2/2
    1/3 [=========>....................] - ETA: 0s - loss: 0.7265

[attention heatmaps for the two heads after epoch 2]

    3/3 [==============================] - 1s 650ms/step - loss: 0.6863
    <keras.callbacks.History at 0x7fcc839dc590>
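
A side note on pulling the layer out of the network graph: instead of relying on its position with model.layers[2], you can give the attention layer a name when you build the model and fetch it with model.get_layer. A minimal sketch of that variant (the name "mha" is arbitrary and not part of the example above):

    import tensorflow as tf

    # Name the attention layer so it can be looked up by name rather than by index.
    layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2, name="mha")
    target = tf.keras.layers.Input(shape=[8, 16])
    source = tf.keras.layers.Input(shape=[4, 16])
    output_tensor, weights = layer(target, source, return_attention_scores=True)
    output = tf.keras.layers.Dense(1, activation='sigmoid')(tf.keras.layers.Flatten()(output_tensor))
    model = tf.keras.Model([target, source], output)

    attention_layer = model.get_layer("mha")  # same layer object, fetched by name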

Option 2: If you only want to plot the attention scores after training, you can pass some data to the model's attention layer and plot the scores:

    import tensorflow as tf
    import seaborn as sb
    import matplotlib.pyplot as plt

    layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
    target = tf.keras.layers.Input(shape=[8, 16])
    source = tf.keras.layers.Input(shape=[4, 16])
    output_tensor, weights = layer(target, source,
                                   return_attention_scores=True)
    output = tf.keras.layers.Flatten()(output_tensor)
    output = tf.keras.layers.Dense(1, activation='sigmoid')(output)

    model = tf.keras.Model([target, source], output)
    model.compile(optimizer='adam', loss=tf.keras.losses.BinaryCrossentropy())

    samples = 5
    train_targets = tf.random.normal((samples, 8, 16))
    train_sources = tf.random.normal((samples, 4, 16))
    test_targets = tf.random.normal((samples, 8, 16))
    test_sources = tf.random.normal((samples, 4, 16))
    y = tf.random.uniform((samples,), maxval=2, dtype=tf.int32)

    model.fit([train_targets, train_sources], y, batch_size=2, epochs=2)

    attention_layer = model.layers[2]  # the MultiHeadAttention layer (layers 0 and 1 are the Inputs)

    # Run a single test sample through the trained attention layer and plot one heatmap per head.
    _, attention_scores = attention_layer(test_targets[:1], test_sources[:1],
                                          return_attention_scores=True)
    fig, axs = plt.subplots(ncols=3, gridspec_kw=dict(width_ratios=[5, 5, 0.2]))
    sb.heatmap(attention_scores[0, 0, :, :], annot=True, cbar=False, ax=axs[0])
    sb.heatmap(attention_scores[0, 1, :, :], annot=True, yticklabels=False, cbar=False, ax=axs[1])
    fig.colorbar(axs[1].collections[0], cax=axs[2])
    plt.show()
    Epoch 1/2
    3/3 [==============================] - 1s 7ms/step - loss: 0.6727
    Epoch 2/2
    3/3 [==============================] - 0s 6ms/step - loss: 0.6503

[attention heatmaps for the two heads after training]
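
If you prefer to read the scores straight out of the functional graph instead of calling the layer eagerly, you can also build a second Model that reuses the symbolic weights tensor as an output; predict then returns the attention scores for a whole batch. A minimal sketch, assuming the target, source and weights tensors from the Option 2 code above:

    # The sub-model is built from the same graph tensors, so it shares its weights with `model`.
    score_model = tf.keras.Model([target, source], weights)

    scores = score_model.predict([test_targets, test_sources])
    print(scores.shape)  # (samples, num_heads, target_len, source_len) -> (5, 2, 8, 4)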