Understanding Transformer Attention: From Scratch to Flash Attention
The attention mechanism is the computational heart of the transformer architecture. In this post, we'll build understanding from first principles and trace the evolution to Flash Attention.
The Attention Formula
At its core, attention computes a weighted sum of values based on the compatibility between queries and keys:
$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

where $Q \in \mathbb{R}^{N \times d_k}$, $K \in \mathbb{R}^{N \times d_k}$, and $V \in \mathbb{R}^{N \times d_v}$.
The scaling factor $1/\sqrt{d_k}$ prevents the dot products from growing too large, which would push the softmax into regions with extremely small gradients.
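The saturation effect is easy to demonstrate numerically. The following sketch (with illustrative values I've chosen, not from the post) compares the softmax over raw and scaled dot products for a moderately large $d_k$:

```python
import torch

torch.manual_seed(0)
d_k = 512
q = torch.randn(d_k)
k = torch.randn(10, d_k)  # 10 keys

raw = k @ q                 # unscaled dot products, std ≈ sqrt(d_k) ≈ 22.6
scaled = raw / d_k ** 0.5   # scaled dot products, std ≈ 1

# Unscaled scores typically collapse the softmax to (nearly) one-hot,
# so gradients with respect to the non-max entries vanish.
p_raw = torch.softmax(raw, dim=-1)
p_scaled = torch.softmax(scaled, dim=-1)

print(p_raw.max().item())     # typically very close to 1.0
print(p_scaled.max().item())  # noticeably less peaked
```

Scaling by $1/\sqrt{d_k}$ keeps the score distribution's variance near 1 regardless of head dimension, which is why the same softmax temperature works across model sizes.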
The quadratic $O(N^2)$ memory complexity of standard attention is the primary bottleneck for long-context transformers. Flash Attention reduces this to $O(N)$ by never materializing the full attention matrix.
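The arithmetic makes the gap concrete. A back-of-the-envelope sketch, assuming fp16 scores and a single head (my illustrative numbers, not from the post):

```python
# Memory for the full N×N attention matrix vs. Flash Attention's
# per-row softmax statistics, for one head.
N = 65536                # sequence length
bytes_per_score = 2      # fp16

full_matrix = N * N * bytes_per_score        # standard attention
print(f"{full_matrix / 2**30:.1f} GiB")      # 8.0 GiB for a single head

# Flash Attention carries only O(N) state between tiles:
# a running row max and a running row sum, here assumed fp32.
flash_stats = 2 * N * 4
print(f"{flash_stats / 2**10:.0f} KiB")      # 512 KiB
```

At 64K tokens, a single head's score matrix alone exceeds the capacity of most GPUs once you multiply by heads and batch size, while the $O(N)$ statistics stay negligible.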
Multi-Head Attention
Rather than performing a single attention function, we project queries, keys, and values $h$ times with different learned projections:

$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)W^O$$

where $\mathrm{head}_i = \mathrm{Attention}(QW_i^Q, KW_i^K, VW_i^V)$.
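The split-project-concat pattern can be sketched in a few lines of PyTorch. This is a minimal single-batch reconstruction (no masking, no dropout), and `multi_head_attention` with its weight arguments is a naming choice of mine, not an API from the post:

```python
import torch
import torch.nn.functional as F

def multi_head_attention(x, w_q, w_k, w_v, w_o, h):
    """Minimal multi-head self-attention sketch for a single sequence x."""
    N, d_model = x.shape
    d_k = d_model // h
    # Project once, then split the projection into h heads: (h, N, d_k)
    Q = (x @ w_q).view(N, h, d_k).transpose(0, 1)
    K = (x @ w_k).view(N, h, d_k).transpose(0, 1)
    V = (x @ w_v).view(N, h, d_k).transpose(0, 1)
    # Scaled dot-product attention runs independently per head
    scores = Q @ K.transpose(-2, -1) / d_k ** 0.5
    attn = F.softmax(scores, dim=-1)
    # Concatenate heads back to (N, d_model), then apply the output projection
    out = (attn @ V).transpose(0, 1).reshape(N, d_model)
    return out @ w_o
```

Note that the $h$ projections are implemented as one big matmul followed by a reshape, which is how production implementations avoid $h$ separate small matrix multiplies.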
The IO-Awareness Revolution
Flash Attention's key insight is treating attention as an IO-bound operation rather than a compute-bound one. By tiling the computation and keeping intermediate results in SRAM rather than writing them to HBM, we achieve:
- 2-4x speedup over standard PyTorch attention
- 5-20x memory reduction for long sequences
- Exact computation — no approximation needed
import math
import torch

# Standard attention (materializes the full N×N score matrix in HBM)
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
attn = torch.softmax(scores, dim=-1)
output = torch.matmul(attn, V)
# Flash Attention (tiled, IO-aware); requires the flash-attn package
from flash_attn import flash_attn_func

# The fused kernel never materializes the full N×N attention matrix
output = flash_attn_func(Q, K, V, causal=True)
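The tiling trick itself can be reconstructed in plain PyTorch. The sketch below is an educational reimplementation of the online-softmax recurrence (non-causal, single head), not the fused CUDA kernel; real Flash Attention runs each tile inside SRAM, whereas these tiles are just Python-level slices, and `tiled_attention` is a name I've made up:

```python
import torch

def tiled_attention(Q, K, V, block=128):
    """Compute softmax(QK^T / sqrt(d_k)) V one K/V tile at a time,
    keeping only O(N) running statistics between tiles."""
    N, d_k = Q.shape
    scale = d_k ** -0.5
    out = torch.zeros_like(V)
    row_max = torch.full((N, 1), float("-inf"))  # running row-wise max
    row_sum = torch.zeros(N, 1)                  # running softmax denominator
    for j in range(0, N, block):
        Kj, Vj = K[j:j + block], V[j:j + block]
        S = (Q @ Kj.T) * scale                   # (N, block) score tile
        new_max = torch.maximum(row_max, S.max(dim=-1, keepdim=True).values)
        P = torch.exp(S - new_max)               # tile's softmax numerators
        # Rescale previously accumulated state to the new max
        correction = torch.exp(row_max - new_max)
        row_sum = row_sum * correction + P.sum(dim=-1, keepdim=True)
        out = out * correction + P @ Vj
        row_max = new_max
    return out / row_sum                         # normalize once at the end
```

Because the rescaling is exact, the result matches standard attention to floating-point precision, which is the "exact computation, no approximation" property claimed above.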
When benchmarking attention implementations, always measure wall-clock time on the target hardware. FLOPs alone don't capture the memory bandwidth bottleneck that Flash Attention addresses.
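A wall-clock harness along these lines is enough for a first measurement; this is a generic sketch (the `benchmark` helper is mine, not a library function), and on GPU the synchronize calls matter because CUDA kernels are launched asynchronously:

```python
import time
import torch

def benchmark(fn, *args, warmup=3, iters=10):
    """Average wall-clock seconds per call of fn(*args)."""
    for _ in range(warmup):          # warm up caches and autotuners
        fn(*args)
    if torch.cuda.is_available():
        torch.cuda.synchronize()     # drain queued kernels before timing
    t0 = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    if torch.cuda.is_available():
        torch.cuda.synchronize()     # make sure timed kernels finished
    return (time.perf_counter() - t0) / iters
```

Without the synchronize calls you would mostly be timing kernel launches, which is exactly the kind of FLOPs-style mirage the paragraph above warns about.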
Looking Forward
The trajectory is clear: as context windows grow from 4K to 1M+ tokens, hardware-aware algorithm design becomes not optional but essential. Ring Attention, Paged Attention (vLLM), and the emerging class of linear attention variants each represent different points in the approximation-efficiency tradeoff space.
The fundamental question remains: can we find attention mechanisms that are both subquadratic and retain the full representational power of softmax attention?