ELI5: FlashAttention. Step by step explanation of how one of… | by Aleksa Gordić | Medium


Highlights

“…methods reduce the compute requirements to linear or near-linear in sequence length, many of them do not display wall-clock speedup against standard attention and have not gained wide adoption. One main reason is that they focus on FLOP reduction (which may not correlate with wall-clock speed) and tend to ignore overheads from memory access (IO).” ⤴️

It doesn’t matter if you can compute at exaFLOPS speeds if there is no data to be processed. ⤴️

Depending on this ratio between computation and memory accesses, operations can be classified as either:

  • compute-bound (example: matrix multiplication)

  • OR memory-bound (examples: elementwise ops such as activation, dropout, and masking; reduction ops such as softmax, layer norm, sum, etc.)

Note on the terminology: this ratio is commonly measured by the arithmetic intensity, which is the number of arithmetic operations per byte of memory access. ⤴️
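
(Not from the article, just a rough illustration of the definition above.) Writing arithmetic intensity as FLOPs per byte of memory traffic makes the two categories easy to see, assuming fp16 operands (2 bytes per element) and square n×n matrices:

```latex
% Arithmetic intensity (AI): arithmetic operations per byte of memory traffic.
\mathrm{AI} = \frac{\#\,\text{FLOPs}}{\#\,\text{bytes accessed}}

% n-by-n matmul: ~2n^3 FLOPs over ~3 matrices of n^2 fp16 elements.
% AI grows with n, so for large n the op is compute-bound.
\mathrm{AI}_{\text{matmul}} \approx \frac{2n^3}{3 \cdot 2n^2} = \frac{n}{3}

% Elementwise op on an n-by-n matrix: ~n^2 FLOPs, yet it still reads and
% writes all n^2 elements, so AI stays a small constant: memory-bound.
\mathrm{AI}_{\text{elementwise}} \approx \frac{n^2}{2 \cdot 2n^2} = \frac{1}{4}
```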

attention is (on current AI accelerators) memory-bound.

Why?

Because it “mostly consists of elementwise ops”, or, more accurately, because the arithmetic intensity of attention is not very high. ⤴️

Masking, softmax & dropout are the ops that take the bulk of the time, not matrix multiplication (even though the bulk of the FLOPs is in matmul). ⤴️
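
A rough back-of-envelope sketch (my numbers, not the article's) makes this concrete. It assumes a single head with N = 1024, d = 64, fp16, and A100-ish ballpark figures: ~312 TFLOPS of matmul throughput and the ~1.5 TB/s HBM bandwidth mentioned a bit further down:

```python
# Rough estimate only; all constants are ballpark assumptions, not measurements.
N, d, BYTES = 1024, 64, 2            # sequence length, head dim, fp16 bytes

matmul_flops = 2 * 2 * N * N * d     # Q @ K^T and P @ V
elementwise_flops = 5 * N * N        # mask + softmax + dropout, very roughly

# In the unfused implementation each elementwise pass reads and writes the
# full N x N matrix from/to HBM (~3 round trips over N^2 fp16 elements).
hbm_bytes = 3 * 2 * N * N * BYTES

matmul_time = matmul_flops / 312e12  # seconds at ~312 TFLOPS (tensor cores)
hbm_time = hbm_bytes / 1.5e12        # seconds at ~1.5 TB/s HBM bandwidth

print(f"matmul: {matmul_time*1e6:.2f} us, HBM traffic: {hbm_time*1e6:.2f} us")
# Even though almost all of the FLOPs are in the matmuls, the HBM traffic for
# the "cheap" elementwise ops takes roughly 10x longer than the matmul compute.
```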

Being “IO-aware” in practice boils down to exploiting the fact that SRAM is so much faster than HBM (“high bandwidth memory” — unfortunate name) by making sure to reduce the communication between the two. ⤴️

An A100 GPU has 40–80 GB of high-bandwidth memory (HBM, the thing that gives you lovely CUDA OOMs) with a bandwidth of 1.5–2.0 TB/s, and 192 KB of on-chip SRAM on each of its 108 streaming multiprocessors, with bandwidth estimated at around 19 TB/s.

Similar ratios still hold for H100 and other accelerators. ⤴️

the standard implementation shows the utmost disrespect for the way HW operates. It’s basically treating HBM load/store ops as 0 cost (it’s not “IO-aware”). ⤴️

The lowest hanging fruit is to remove redundant HBM reads/writes.

Why write S back to HBM only to (re)load it again in order to compute the softmax? Let’s keep it in SRAM instead, perform all of the intermediate steps, and only then write the final result back to HBM.

This is what compilers folks refer to as “kernel fusion”, one of the most important low-level optimizations in deep learning: ⤴️

you load from the HBM only once, execute the fused op, and only then write the results back. By doing this you reduce the communication overhead. ⤴️
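
To make the unfused baseline concrete, here is a minimal PyTorch-style sketch (my illustration, not the paper's code). Assume Q, K, V are tensors of shape (N, d) living in HBM, and that each line is roughly one CUDA kernel doing its own HBM round trip:

```python
import math
import torch
import torch.nn.functional as F

def standard_attention(Q, K, V, p_drop=0.1):
    d = Q.shape[-1]
    # S is the full N x N scores matrix: computed, then written back to HBM.
    S = (Q @ K.T) / math.sqrt(d)
    # Masking: reads N x N from HBM, writes N x N back.
    causal = torch.triu(torch.ones_like(S, dtype=torch.bool), diagonal=1)
    S = S.masked_fill(causal, float("-inf"))
    # Softmax: another N x N read + write.
    P = F.softmax(S, dim=-1)
    # Dropout: yet another N x N read + write.
    P = F.dropout(P, p=p_drop, training=True)
    # Only now do we shrink back down to N x d.
    return P @ V
```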

FlashAttention basically boils down to 2 main ideas:

  1. Tiling (used during both forward & backward passes) — basically chunking the NxN softmax/scores matrix into blocks.

  2. Recomputation (used in the backward pass only — if you’re familiar with activation/gradient checkpointing, this will be trivial to understand; see the sketch right below) ⤴️
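
Since the article leans on the activation/gradient checkpointing analogy, here is a tiny PyTorch illustration of that idea (the `expensive_block` module is just a made-up stand-in):

```python
import torch
from torch.utils.checkpoint import checkpoint

# A made-up block whose intermediate activations we'd rather not store.
expensive_block = torch.nn.Sequential(
    torch.nn.Linear(512, 512), torch.nn.GELU(), torch.nn.Linear(512, 512)
)

x = torch.randn(8, 512, requires_grad=True)

# With checkpointing, the activations inside `expensive_block` are NOT kept
# for the backward pass; they are recomputed from `x` when backward runs.
# FlashAttention's recomputation is the same trade: never store the N x N
# attention matrix, rebuild its blocks from Q, K, V (plus the softmax
# statistics) during the backward pass.
y = checkpoint(expensive_block, x, use_reentrant=False)
y.sum().backward()
```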

The main hurdle in getting the tiling approach to work is softmax. In particular, the fact that softmax couples all of the score columns together. ⤴️

That’s the issue.

To compute how much a particular i-th token from the input sequence pays attention to other tokens in the sequence you’d need to have all of those scores readily available (denoted here by z_j) in SRAM.

But let me remind you: SRAM is severely limited in its capacity. You can’t just load the whole thing. N (sequence length) can be 1,000 or even 100,000 tokens, so the N×N scores matrix explodes fairly quickly. ⤴️
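
For reference (standard definitions, not copied from the article), the coupling is visible directly in the softmax formula, and the numerically stable form already uses the two per-block statistics that show up below:

```latex
% Softmax of the i-th score: the denominator sums over ALL N scores z_j,
% which is exactly what couples the columns together.
\mathrm{softmax}(z)_i = \frac{e^{z_i}}{\sum_{j=1}^{N} e^{z_j}}

% Numerically stable form, using m(z) = \max_j z_j and
% l(z) = \sum_j e^{z_j - m(z)} -- the same m and l tracked per block.
\mathrm{softmax}(z)_i = \frac{e^{z_i - m(z)}}{l(z)}
```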

we can actually chop the softmax computation down into smaller blocks and still end up with precisely the same result. ⤴️

the trick is that we can combine those per-block partial softmax numbers in a smart way such that the final result is actually correct. ⤴️

in order to compute the softmax for the scores belonging to the first 2 blocks (of size B), you have to keep track of 2 statistics for each of the blocks: m(x) (maximum score) and l(x) (sum of exp scores).

And then you can seamlessly fuse them together using the normalizing coefficients. ⤴️
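
Here is a minimal numpy sketch of that merging trick (my illustration, assuming a single row of scores split into two blocks of size B = 50): each block only contributes its m(x) and l(x), and the two statistics are rescaled to a common maximum before being combined.

```python
import numpy as np

def block_stats(x):
    m = x.max()                    # m(x): block maximum (numerical stability)
    l = np.exp(x - m).sum()        # l(x): sum of exponentials, shifted by m
    return m, l

def merge(m1, l1, m2, l2):
    m = max(m1, m2)
    # Rescale each partial sum to the shared maximum before adding them up.
    l = np.exp(m1 - m) * l1 + np.exp(m2 - m) * l2
    return m, l

z = np.random.randn(100)           # scores for one query row
m, l = merge(*block_stats(z[:50]), *block_stats(z[50:]))

blockwise = np.exp(z - m) / l      # softmax assembled from per-block stats
reference = np.exp(z - z.max()) / np.exp(z - z.max()).sum()
assert np.allclose(blockwise, reference)   # exactly the same result
```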

Step 0: HBM’s capacity is measured in GBs (e.g. RTX 3090 has 24 GBs of VRAM/HBM, A100 has 40–80 GB, etc.) so allocating Q, K, and V is not an issue.

Step 1: Let’s compute the row/column block sizes. Why ceil(M/4d)? Because query, key, and value vectors are d-dimensional, and we also need to combine them into the output d-dimensional vector. So this size basically allows us to max out SRAM capacity with q, k, v, and o vectors.

Toy example: assume M = 1000, d = 5. In this example, the block size is ceil(1000 / (4 · 5)) = 50. So we would load blocks of 50 q, k, v, o vectors at a time, to make sure we’re reducing the number of reads/writes between HBM and SRAM. ⤴️
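
As a quick sanity check of the toy example (sketch only; M is the SRAM size measured in number of elements):

```python
import math

def block_size(M, d):
    # Room for one block each of q, k, v and o vectors, all d-dimensional,
    # hence the factor of 4 in the denominator.
    return math.ceil(M / (4 * d))

print(block_size(1000, 5))   # -> 50, matching the toy example above
# If I recall the paper's Algorithm 1 correctly, it additionally caps the
# row block size at d, i.e. B_r = min(ceil(M / (4 * d)), d).
```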