
Understanding Transformer Attention: From Scratch to Flash Attention

12 min read · machine-learning · transformers · systems

The attention mechanism is the computational heart of the transformer architecture. In this post, we'll build understanding from first principles and trace the evolution to Flash Attention.

The Attention Formula

At its core, attention computes a weighted sum of values based on the compatibility between queries and keys:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

where $Q \in \mathbb{R}^{n \times d_k}$, $K \in \mathbb{R}^{m \times d_k}$, and $V \in \mathbb{R}^{m \times d_v}$.

The scaling factor $\frac{1}{\sqrt{d_k}}$ prevents the dot products from growing too large, which would push the softmax into regions with extremely small gradients.
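This effect is easy to verify numerically. A quick sketch (illustrative, not from the post): for queries and keys with i.i.d. unit-variance components, the dot product $q \cdot k$ has variance $d_k$, so dividing by $\sqrt{d_k}$ restores unit variance before the softmax.

```python
import random

random.seed(0)
d_k, trials = 256, 2000

# Sample raw dot products of random unit-variance vectors.
dots = []
for _ in range(trials):
    q = [random.gauss(0, 1) for _ in range(d_k)]
    k = [random.gauss(0, 1) for _ in range(d_k)]
    dots.append(sum(a * b for a, b in zip(q, k)))

var_raw = sum(d * d for d in dots) / trials

# Raw variance grows with d_k; after scaling by 1/sqrt(d_k) it is ~1.
print(round(var_raw, 1))        # close to d_k = 256
print(round(var_raw / d_k, 2))  # close to 1.0
```

With unit-variance logits, the softmax stays in a regime where gradients flow; with variance $d_k$, a few logits dominate and the softmax saturates.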

The quadratic memory complexity $O(n^2)$ of standard attention is the primary bottleneck for long-context transformers. Flash Attention reduces this to $O(n)$ by never materializing the full attention matrix.
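A back-of-envelope calculation makes the bottleneck concrete. This sketch (my numbers, assuming fp16 scores and a single head) computes the memory needed just to hold the $n \times n$ score matrix:

```python
BYTES_FP16 = 2  # bytes per fp16 entry

def score_matrix_bytes(n: int) -> int:
    """Memory to materialize one n x n attention score matrix in fp16."""
    return n * n * BYTES_FP16

for n in (4_096, 32_768, 131_072):
    gib = score_matrix_bytes(n) / 2**30
    print(f"n = {n:>7}: {gib:8.2f} GiB per head")
# At n = 32,768 the matrix alone is 2 GiB per head, before gradients.
```

Multiply by heads, layers, and batch size, and materializing scores quickly dwarfs the memory used by the model weights themselves.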

Multi-Head Attention

Rather than performing a single attention function, we project queries, keys, and values $h$ times with different learned projections:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)W^O$$

where $\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$.
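The two formulas above translate almost line for line into code. A minimal NumPy sketch (all names and shapes here are my own choices for illustration, not a production implementation):

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    """Scaled dot-product attention: softmax(QK^T / sqrt(d_k)) V."""
    d_k = Q.shape[-1]
    scores = Q @ K.swapaxes(-2, -1) / np.sqrt(d_k)
    return softmax(scores) @ V

def multi_head(X, W_q, W_k, W_v, W_o, h):
    """W_q/W_k/W_v: lists of h per-head projections; W_o: (h*d_v, d_model)."""
    heads = [attention(X @ W_q[i], X @ W_k[i], X @ W_v[i]) for i in range(h)]
    return np.concatenate(heads, axis=-1) @ W_o  # Concat(...) W^O

rng = np.random.default_rng(0)
n, d_model, h = 6, 32, 4
d_k = d_model // h  # the usual convention: split d_model across heads
X = rng.standard_normal((n, d_model))
W_q = [rng.standard_normal((d_model, d_k)) for _ in range(h)]
W_k = [rng.standard_normal((d_model, d_k)) for _ in range(h)]
W_v = [rng.standard_normal((d_model, d_k)) for _ in range(h)]
W_o = rng.standard_normal((h * d_k, d_model))

out = multi_head(X, W_q, W_k, W_v, W_o, h)
print(out.shape)  # (6, 32)
```

Note the convention $d_k = d_{model}/h$: each head works in a lower-dimensional subspace, so the total cost is comparable to single-head attention at full dimension.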

The IO-Awareness Revolution

Flash Attention's key insight is treating attention as an IO-bound operation rather than a compute-bound one. By tiling the computation and keeping intermediate results in SRAM rather than writing them to HBM, we achieve:

  • 2-4x speedup over standard PyTorch attention
  • 5-20x memory reduction for long sequences
  • Exact computation — no approximation needed

```python
import math
import torch
from flash_attn import flash_attn_func

# Standard attention (materializes the full N×N score matrix)
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
attn = torch.softmax(scores, dim=-1)
output = torch.matmul(attn, V)

# Flash Attention (tiled, IO-aware): the kernel never materializes
# the full matrix. flash-attn expects fp16/bf16 tensors of shape
# (batch, seqlen, nheads, headdim).
output = flash_attn_func(Q, K, V, causal=True)
```
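The trick that makes tiling possible is the *online softmax*: the softmax over a full row of scores can be accumulated block by block, keeping only a running max, a running normalizer, and an output accumulator per query. A NumPy reference sketch of this idea (illustrative only; the real kernel fuses this into CUDA with SRAM-resident tiles):

```python
import numpy as np

def flash_attention_reference(Q, K, V, block=2):
    """Single-head tiled attention via the online softmax (not the CUDA kernel)."""
    n, d_k = Q.shape
    scale = 1.0 / np.sqrt(d_k)
    m = np.full(n, -np.inf)            # running row-wise max of scores
    l = np.zeros(n)                    # running softmax denominator
    acc = np.zeros((n, V.shape[-1]))   # unnormalized output accumulator
    for j in range(0, K.shape[0], block):
        S = Q @ K[j:j + block].T * scale        # scores for this K/V block only
        m_new = np.maximum(m, S.max(axis=-1))
        corr = np.exp(m - m_new)                # rescale previously seen blocks
        P = np.exp(S - m_new[:, None])
        l = l * corr + P.sum(axis=-1)
        acc = acc * corr[:, None] + P @ V[j:j + block]
        m = m_new
    return acc / l[:, None]            # normalize once at the end

# Check against the naive formula on small random inputs.
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((8, 4)) for _ in range(3))
S = Q @ K.T / np.sqrt(4)
P = np.exp(S - S.max(-1, keepdims=True))
ref = (P / P.sum(-1, keepdims=True)) @ V
print(np.allclose(flash_attention_reference(Q, K, V), ref))  # True
```

Peak extra memory per query row is a constant (one max, one normalizer, one output row), which is where the $O(n)$ memory bound and the "exact, no approximation" property both come from.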

When benchmarking attention implementations, always measure wall-clock time on the target hardware. FLOPs alone don't capture the memory bandwidth bottleneck that Flash Attention addresses.
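A minimal wall-clock harness along these lines might look as follows (a sketch with my own helper names; on GPU you would additionally need to synchronize the device, e.g. `torch.cuda.synchronize()`, before reading the clock, or asynchronous kernel launches make the numbers meaningless):

```python
import time
import numpy as np

def bench(fn, *args, warmup=3, iters=10):
    """Average wall-clock seconds per call, after warmup iterations."""
    for _ in range(warmup):
        fn(*args)
    t0 = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    return (time.perf_counter() - t0) / iters

def naive_attention(Q, K, V):
    S = Q @ K.T / np.sqrt(Q.shape[-1])
    P = np.exp(S - S.max(-1, keepdims=True))
    return (P / P.sum(-1, keepdims=True)) @ V

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((512, 64)) for _ in range(3))
t = bench(naive_attention, Q, K, V)
print(f"naive attention: {t * 1e3:.2f} ms / call")
```

Warmup matters for the same reason wall-clock does: caches, JIT compilation, and kernel autotuning all distort the first few calls.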

Looking Forward

The trajectory is clear: as context windows grow from 4K to 1M+ tokens, hardware-aware algorithm design becomes not optional but essential. Ring Attention, Paged Attention (vLLM), and the emerging class of linear attention variants each represent different points in the approximation-efficiency tradeoff space.

The fundamental question remains: can we find attention mechanisms that are both subquadratic and retain the full representational power of softmax attention?