Efficient Transformers
Created: =dateformat(this.file.ctime,"dd MMM yyyy, hh:mm a") | Modified: =dateformat(this.file.mtime,"dd MMM yyyy, hh:mm a")
Tags: knowledge
Overview
Related fields
Introduction

- Sparsifying Attention / Sparse Attention / Sparse Transformer
- Methods that constrain attention to a sparse pattern. The most primitive example is "windowed" (local) attention, which is conceptually similar to convolution. The most successful sparse-pattern method is Big Bird, which combines windowed, global and random attention (a minimal mask-construction sketch follows below).

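A minimal sketch of how such a sparse mask could be built and applied, assuming PyTorch. The window size, number of global tokens and random-key count are illustrative placeholders, and a real implementation (e.g. Big Bird) uses block-sparse kernels instead of a dense N×N boolean mask:

```python
import torch

def sparse_attention_mask(seq_len: int, window: int = 3, n_global: int = 2, n_random: int = 2):
    """Boolean mask combining local (windowed), global and random attention,
    in the spirit of Big Bird. True = query may attend to key."""
    idx = torch.arange(seq_len)
    # Local window: each token attends to neighbours within `window` positions
    mask = (idx[:, None] - idx[None, :]).abs() <= window
    # Global tokens: the first n_global tokens attend to, and are attended by, everyone
    mask[:n_global, :] = True
    mask[:, :n_global] = True
    # Random attention: each query also attends to a few randomly chosen keys
    rand_keys = torch.randint(0, seq_len, (seq_len, n_random))
    mask.scatter_(1, rand_keys, True)
    return mask

def masked_attention(q, k, v, mask):
    """Scaled dot-product attention with disallowed positions set to -inf before softmax."""
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v

# Toy usage: 16 tokens, 8-dim single head
q, k, v = (torch.randn(16, 8) for _ in range(3))
out = masked_attention(q, k, v, sparse_attention_mask(16))
```

Note this sketch only restricts *which* positions interact; the memory savings of sparse transformers come from never computing the masked-out blocks at all.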
FlashAttention
- ELI5: FlashAttention. Step by step explanation | by Aleksa Gordić | Medium
- GitHub - Dao-AILab/flash-attention: Fast and memory-efficient exact attention
- [2205.14135] FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

- The inefficiency in the attention operation lies in the elementwise steps (masking, softmax, dropout) rather than in the matmuls
- This means the attention operation is memory-bound, not compute-bound (see Optimising GPU code, and Analysing LLM Inference Costs)
- SRAM is the fastest level of GPU memory, but its capacity is very limited compared to HBM
- Below DRAM in the hierarchy sit things like SSDs (higher capacity but even slower) and HDDs / object storage such as AWS S3
- FlashAttention: IO-aware (accounts for RW between levels of GPU memory) exact attention algorithm that uses tiling to reduce the number of memory reads/writes between GPU high bandwidth memory (HBM) and GPU on-chip SRAM
- FlashAttention uses tiling to avoid materialising the large N×N attention matrix on (relatively) slow GPU HBM.
- In the outer loop, FlashAttention iterates over blocks of the K and V matrices and loads them into fast on-chip SRAM.
- For each K/V block, it then iterates over blocks of the Q matrix in an inner loop, loading them into SRAM and writing the output of the attention computation back to HBM.
- Key Benefits:
- Memory: vanilla attention materialises the full N×N score matrix, i.e. O(N²) extra memory; FlashAttention keeps only block-sized tiles plus per-row running statistics, i.e. O(N) extra memory. Compute is still O(N²) FLOPs, but with far fewer HBM reads/writes.
- Exact: not an approximation of the attention mechanism; it produces the same output as standard attention (up to numerical precision).
- Main Ideas
- Tiling (both forward and backward passes): chunk the N×N softmax/scores matrix into blocks (see the sketch after this list)
- the softmax is computed incrementally (online softmax), with partial results rescaled as each new block is processed, converging to the exact final result
- Recomputation (backward pass only): the attention matrix is not stored; it is recomputed block-by-block in the backward pass from the saved softmax normalisation statistics
- Limitation
- Requires writing a new CUDA kernel for each new attention variant, in a considerably lower-level language than PyTorch, which takes significant engineering effort
- Implementations may also not be transferable across GPU architectures
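A pedagogical sketch of the tiling + online-softmax idea in plain PyTorch, assuming single-head, non-causal attention. It loops only over K/V blocks (the real kernel also tiles Q and runs fused in SRAM), so it illustrates the algorithm, not the performance:

```python
import torch

def tiled_attention(q, k, v, block_size: int = 64):
    """Tiled attention with an online (streaming) softmax. K/V are processed
    in blocks so the full N x N score matrix is never materialised;
    FlashAttention does the same inside a fused CUDA kernel with each
    block staged in on-chip SRAM."""
    n, d = q.shape
    scale = d ** -0.5

    out = torch.zeros_like(q)                      # running (unnormalised) output
    row_max = torch.full((n, 1), float("-inf"))    # running max score per query row
    row_sum = torch.zeros(n, 1)                    # running softmax denominator

    for start in range(0, n, block_size):
        k_blk = k[start:start + block_size]        # "load a K block into SRAM"
        v_blk = v[start:start + block_size]        # "load a V block into SRAM"

        scores = (q @ k_blk.T) * scale             # (n, block) partial score tile

        new_max = torch.maximum(row_max, scores.max(dim=-1, keepdim=True).values)
        correction = torch.exp(row_max - new_max)  # rescale previously accumulated results
        p = torch.exp(scores - new_max)

        row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
        out = out * correction + p @ v_blk
        row_max = new_max

    return out / row_sum                           # final normalisation -> exact softmax

# Sanity check against vanilla (materialised) attention
q, k, v = (torch.randn(256, 32) for _ in range(3))
ref = torch.softmax(q @ k.T / 32 ** 0.5, dim=-1) @ v
print(torch.allclose(tiled_attention(q, k, v), ref, atol=1e-4))  # should print True
```

In practice one calls the fused kernel instead, e.g. `flash_attn_func` from the Dao-AILab/flash-attention repo, or `torch.nn.functional.scaled_dot_product_attention` in PyTorch 2.x, which can dispatch to a FlashAttention backend.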
Theoretical References
Papers
- Efficient Transformers: A Survey
- DeepSpeed - many relevant papers collected there
- [2205.14135] FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
Articles
- Democratizing on-device generative AI with sub-10 billion parameter models | Qualcomm
- AI on the Edge: The latest on-device AI insights and trends | Qualcomm
- Decoding Transformers on Edge Devices - Axelera AI
Courses
Code References
Methods
Tools, Frameworks
- DeepSpeed - Microsoft Research
- An open-source deep learning optimization library for PyTorch (a minimal ZeRO usage sketch follows after this list)
- DeepSpeed-Training
- ZeRO, 3D-Parallelism, DeepSpeed-MoE, ZeRO-Infinity
- DeepSpeed-Inference, [Blog]
- Parallelism technology: tensor, pipeline, expert and ZeRO-parallelism
- custom inference kernels, communication optimizations and heterogeneous memory technologies
- DeepSpeed-Compression, [Blog]
- ZeroQuant, XTC
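A minimal, hedged sketch of wrapping a PyTorch model with DeepSpeed's ZeRO stage-2 partitioning; the config keys shown are standard DeepSpeed options, but the batch size, learning rate and model are placeholders, and the full config schema is in the DeepSpeed docs:

```python
# Minimal sketch: a PyTorch model wrapped by the DeepSpeed engine with ZeRO stage 2
# (optimizer states + gradients partitioned across data-parallel ranks).
import torch
import deepspeed

model = torch.nn.Linear(1024, 1024)  # placeholder model

ds_config = {
    "train_batch_size": 32,                                  # placeholder value
    "fp16": {"enabled": True},                               # mixed-precision training
    "zero_optimization": {"stage": 2},                       # ZeRO stage 2 partitioning
    "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
}

# deepspeed.initialize returns an engine that owns backward() and step()
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)
```

Training then calls `engine.backward(loss)` and `engine.step()` in place of the usual PyTorch optimizer loop, and the `deepspeed` launcher handles distributed initialisation across ranks.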