Using Flash Attention to speed up your training runs
In earlier chapters, we learned about the core Transformer model and its underlying self-attention mechanism, which serves as the basis for most state-of-the-art models across vision, language, and generative use cases today. While Transformer models parallelize well, a naïve implementation doesn't exploit the memory hierarchy of modern GPUs, which pair a small amount of fast on-chip SRAM with a much larger but slower high-bandwidth memory (HBM). The problem is that standard attention materializes the full attention matrix in that slower memory and shuttles it back and forth. As you can imagine, that leaves performance gains on the table.
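To see why, here is a minimal sketch of standard scaled dot-product attention in PyTorch (illustrative only; the function name and shapes are ours, not any particular library's). The intermediate scores and weights tensors are both of shape (seq_len, seq_len), and in a straightforward implementation each one is written out to HBM and read back in:

```python
import math
import torch

def naive_attention(q, k, v):
    # q, k, v: (batch, heads, seq_len, head_dim)
    d = q.size(-1)
    # Full (seq_len x seq_len) score matrix: quadratic in sequence length,
    # and in a naive implementation it round-trips through slow HBM.
    scores = q @ k.transpose(-2, -1) / math.sqrt(d)
    weights = torch.softmax(scores, dim=-1)  # another seq_len x seq_len tensor
    return weights @ v                       # back to (batch, heads, seq_len, head_dim)
```

To pick some illustrative numbers: at a sequence length of 4,096 with 16 heads and a batch of 8, each of those two intermediate tensors holds 8 × 16 × 4,096 × 4,096 values, roughly 4 GB in half precision, which dwarfs the query, key, and value tensors themselves for typical head dimensions.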
A Stanford-led research team realized they could do better and developed FlashAttention, a novel, IO-aware implementation of the attention computation. Simply put, it's an extremely clever way to handle a quadratic nested for-loop: the work is broken into tiles small enough to live in fast on-chip memory, so the full attention matrix never has to be written out. The toy sketch below captures the idea; then let's take a closer look at the paper's own diagram.
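Here is a single-head toy version of that tiled, "online softmax" loop (our own illustration, not the authors' kernel; the block size and function name are arbitrary assumptions). It keeps a running row-wise maximum and normalizer so the softmax can be finished incrementally, and the score matrix only ever exists one tile at a time:

```python
import math
import torch

def tiled_attention(q, k, v, block=128):
    # q, k, v: (seq_len, head_dim) for a single head.
    n, d = q.shape
    scale = 1.0 / math.sqrt(d)
    out = torch.zeros_like(q)
    row_max = torch.full((n, 1), float("-inf"), device=q.device, dtype=q.dtype)
    row_sum = torch.zeros(n, 1, device=q.device, dtype=q.dtype)
    for start in range(0, n, block):            # loop over key/value tiles
        kj = k[start:start + block]
        vj = v[start:start + block]
        s = (q @ kj.T) * scale                  # scores for this tile only
        new_max = torch.maximum(row_max, s.max(dim=-1, keepdim=True).values)
        p = torch.exp(s - new_max)              # unnormalized softmax for the tile
        correction = torch.exp(row_max - new_max)
        row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
        out = out * correction + p @ vj         # rescale the old partial output
        row_max = new_max
    return out / row_sum                        # finish the softmax normalization
```

The real kernel goes further: it tiles the queries as well, fuses every step into a single CUDA kernel so the tiles stay in SRAM, and recomputes the tiles during the backward pass instead of storing them. But the rescaling trick above is the heart of it.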
Figure 9.2 – From FlashAttention by Tri Dao et al., 2022 (1)
This visual from...