Efficient Transformers


Created: =dateformat(this.file.ctime,"dd MMM yyyy, hh:mm a") | Modified: =dateformat(this.file.mtime,"dd MMM yyyy, hh:mm a") Tags: knowledge


Overview

Introduction

  • Sparsifying Attention / Sparse Attention / Sparse Transformer
    • Methods that constrain and sparsify attention so each token attends to only a subset of positions. The most primitive example is “windowed” (local) attention, which is conceptually similar to convolution: each token attends only to a fixed-size neighbourhood (see the sketch below). The most successful sparse-attention method is Big Bird, which combines several of these attention patterns (windowed, global, and random attention).
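A minimal PyTorch sketch of the windowed-attention pattern described above (the function name windowed_attention and the window parameter are illustrative, not from any particular library). Note that it still materialises the full N×N score matrix, so it only demonstrates the masking pattern, not the memory savings of a real sparse implementation.

```python
import torch

def windowed_attention(q, k, v, window: int):
    """Single-head attention where each query attends only to keys within
    +/- `window` positions, mimicking a 1-D convolution's receptive field."""
    n, d = q.shape
    scores = (q @ k.T) / d ** 0.5                        # (n, n) attention logits
    idx = torch.arange(n)
    local = (idx[:, None] - idx[None, :]).abs() <= window
    scores = scores.masked_fill(~local, float("-inf"))   # mask out non-local positions
    return torch.softmax(scores, dim=-1) @ v             # (n, d) outputs

# toy usage: 8 tokens, 16-dim head, window of +/- 2 positions
q = k = v = torch.randn(8, 16)
out = windowed_attention(q, k, v, window=2)
```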

FlashAttention

  • ELI5: FlashAttention. Step by step explanation | by Aleksa Gordić | Medium
  • GitHub - Dao-AILab/flash-attention: Fast and memory-efficient exact attention
  • [2205.14135] FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
  • The inefficiency of the attention operation lies in the memory-bound elementwise steps (masking, softmax, dropout) rather than in the matmuls
  • SRAM is the fastest type of GPU memory, but only a small amount of it is available (it sits above HBM/DRAM in the memory hierarchy)
    • below DRAM would be things like SSD (higher capacity but even slower) and HDD / AWS S3
  • FlashAttention: IO-aware (accounts for RW between levels of GPU memory) exact attention algorithm that uses tiling to reduce the number of memory reads/writes between GPU high bandwidth memory (HBM) and GPU on-chip SRAM
    • FlashAttention uses tiling to prevent materialization of the large N×N attention matrix (dotted box) on (relatively) slow GPU HBM.
    • In the outer loop (red arrows), FlashAttention loops through blocks of the K and V matrices and loads them into fast on-chip SRAM.
    • For each such block, FlashAttention loops over blocks of the Q matrix (blue arrows), loading them into SRAM and writing the output of the attention computation back to HBM.
  • Key Benefits:
    • Memory: vanilla attention materialises the N×N attention matrix, so its extra memory is quadratic in sequence length, O(N²), while FlashAttention's extra memory is linear, O(N). Compute is still O(N²) FLOPs, but the much smaller number of HBM reads/writes is what makes it faster in practice (e.g. at N = 8192 the fp16 attention matrix alone is 8192² × 2 bytes ≈ 128 MiB per head per batch element).
    • Exact: not an approximation of the attention mechanism; it produces the same output as standard attention.
  • Main Ideas
    • Tiling (both forward and backward passes): chunking the N×N softmax/scores matrix into blocks
      • the softmax is computed incrementally block by block (an “online” softmax that carries running max and normalisation statistics), converging to the exact final result; see the sketch after this list
    • Recomputation (backward pass only): attention blocks are recomputed on-chip from Q, K, V during the backward pass instead of being stored in HBM, trading extra FLOPs for fewer memory accesses
  • Limitations
    • Every new attention variant requires writing a new CUDA kernel in a considerably lower-level language than PyTorch, which takes significant engineering effort
    • Implementations may also not be transferable across GPU architectures
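A minimal single-head PyTorch sketch of the tiling / online-softmax idea (forward pass only). This is an illustration under simplifying assumptions, not the FlashAttention implementation: for brevity it tiles only over K/V blocks and keeps all of Q in memory, and it runs as plain PyTorch rather than a fused CUDA kernel operating in SRAM. The function name flash_like_attention and the block size are made up for the example.

```python
import torch

def flash_like_attention(q, k, v, block: int = 64):
    """Exact attention computed block by block over K/V, carrying a running
    row-wise max (m) and softmax denominator (l) so the full N x N score
    matrix is never materialised."""
    n, d = q.shape
    scale = d ** -0.5
    o = torch.zeros_like(q)                    # running (unnormalised) output
    m = torch.full((n, 1), float("-inf"))      # running row-wise max of the logits
    l = torch.zeros(n, 1)                      # running softmax denominator
    for start in range(0, n, block):
        kb, vb = k[start:start + block], v[start:start + block]
        s = (q @ kb.T) * scale                               # (n, block) logits for this block
        m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
        p = torch.exp(s - m_new)                             # block softmax numerator
        correction = torch.exp(m - m_new)                    # rescale earlier partial results
        l = l * correction + p.sum(dim=-1, keepdim=True)
        o = o * correction + p @ vb
        m = m_new
    return o / l                                             # exact softmax(QK^T / sqrt(d)) @ V

# sanity check against the naive (materialised) implementation
q, k, v = (torch.randn(128, 32) for _ in range(3))
ref = torch.softmax(q @ k.T / 32 ** 0.5, dim=-1) @ v
assert torch.allclose(flash_like_attention(q, k, v), ref, atol=1e-4)
```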

Theoretical References

Papers

Articles

Courses


Code References

Methods

Tools, Frameworks