Visualizing attention patterns
Remember that we specifically defined a model called attention_visualizer to generate attention matrices? With the model trained, we can now inspect these attention patterns by feeding data to it. As a reminder, here's how the model was defined:
attention_visualizer = tf.keras.models.Model(
    inputs=[encoder.inputs, decoder_input],
    outputs=[attn_weights, decoder_out]
)
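For context, here is a minimal sketch of how tensors such as attn_weights and decoder_out can be exposed at model-building time. It uses a Luong-style tf.keras.layers.Attention layer with return_attention_scores=True; all names and sizes below are illustrative, not necessarily this chapter's exact architecture:

import tensorflow as tf

# Illustrative sizes only; the actual model's vocabularies/units may differ.
src_vocab, tgt_vocab, units = 5000, 5000, 128

# Encoder: embeds the source sequence and returns per-timestep outputs.
encoder_input = tf.keras.Input(shape=(None,), name="encoder_input")
enc_emb = tf.keras.layers.Embedding(src_vocab, units)(encoder_input)
enc_out, enc_state = tf.keras.layers.GRU(
    units, return_sequences=True, return_state=True)(enc_emb)
encoder = tf.keras.models.Model(encoder_input, enc_out)

# Decoder: attends over the encoder outputs; return_attention_scores=True
# makes the Attention layer also emit the weight matrix we visualize later.
decoder_input = tf.keras.Input(shape=(None,), name="decoder_input")
dec_emb = tf.keras.layers.Embedding(tgt_vocab, units)(decoder_input)
dec_out = tf.keras.layers.GRU(units, return_sequences=True)(
    dec_emb, initial_state=enc_state)
context, attn_weights = tf.keras.layers.Attention()(
    [dec_out, enc_out], return_attention_scores=True)
decoder_out = tf.keras.layers.Dense(tgt_vocab, activation="softmax")(
    tf.keras.layers.Concatenate()([dec_out, context]))

With tensors like these in scope, the attention_visualizer definition above works as written, since encoder.inputs is simply [encoder_input].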
We'll also define a function that retrieves the attention matrix for a few randomly sampled test inputs, along with the label data we need to annotate the visualization:
import numpy as np

def get_attention_matrix_for_sampled_data(
        attention_model, target_lookup_layer, test_xy, n_samples=5):
    test_x, test_y = test_xy
    # Sample a handful of random test examples to visualize.
    rand_ids = np.random.randint(0, len(test_xy[0]), size=(n_samples,))
    results = []
    for rid in rand_ids:
        en_input = test_x[rid:rid+1]
        # Decoder input: the target sequence shifted by one (teacher forcing).
        de_input = test_y[rid:rid+1, :-1]
        # The visualizer model returns the attention matrix alongside the
        # predictions, matching the two outputs defined above.
        attn_weights, predictions = attention_model.predict([en_input, de_input])
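        # --- The rest of the loop below is a minimal sketch, not verbatim
        # code: it assumes target_lookup_layer is a Keras StringLookup- or
        # TextVectorization-style layer exposing get_vocabulary(), used to
        # map predicted word IDs back to words for labeling the plot axes.
        predicted_word_ids = np.argmax(predictions, axis=-1).ravel()
        predicted_words = [
            target_lookup_layer.get_vocabulary()[wid]
            for wid in predicted_word_ids
        ]
        # attn_weights has shape (1, target_len, source_len); drop the batch dim.
        results.append((attn_weights[0], en_input.ravel(), predicted_words))
    return results

With the sampled attention matrices in hand, a heatmap is the most direct way to read them: each row shows how strongly one predicted word attended to each source position. Below is a minimal plotting sketch using matplotlib; plot_attention_matrix and its argument names are illustrative, and mapping source word IDs back to words would require the source-side lookup layer, which this excerpt doesn't show.

import matplotlib.pyplot as plt

def plot_attention_matrix(attn_matrix, source_tokens, target_tokens):
    """Render one attention matrix as a heatmap.

    attn_matrix: array of shape (target_len, source_len); row t shows how
    much the t-th predicted word attended to each source position.
    """
    fig, ax = plt.subplots(figsize=(8, 8))
    im = ax.imshow(attn_matrix, cmap="viridis")
    ax.set_xticks(range(len(source_tokens)))
    ax.set_xticklabels([str(tok) for tok in source_tokens], rotation=90)
    ax.set_yticks(range(len(target_tokens)))
    ax.set_yticklabels(target_tokens)
    ax.set_xlabel("Source sequence")
    ax.set_ylabel("Predicted sequence")
    fig.colorbar(im, ax=ax)
    plt.show()

For example, plot_attention_matrix(*results[0]) would render the heatmap for the first sampled example.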