FlashAttention is a fast, memory-efficient, exact attention algorithm and implementation that reorders computation to minimize reads and writes to GPU high-bandwidth memory (HBM). The official repository, Dao-AILab/flash-attention, ships CUDA and ROCm kernels with PyTorch bindings and is widely adopted across frameworks and production stacks (Dao et al., 2022; Dao, 2024; Shah et al., 2024).
Key features & functionality
The repository provides kernels and a clean PyTorch API for scaled dot-product attention in training and inference. Core entry points include flash_attn.flash_attn_func, flash_attn.flash_attn_qkvpacked_func, and the inference-focused flash_attn_with_kvcache, with support for causal masks, sliding-window (local) attention, ALiBi, rotary embeddings, MQA/GQA, paged KV cache, and softcapping. The module flash_attn/modules/mha.py shows how these primitives compose into a multi-head attention layer. A Triton implementation is available for experimentation, and ROCm backends support both Composable Kernel (MI200/MI300) and Triton.
- Training and inference primitives: Fused kernels for forward/backward, variable-length sequences (see the sketch after this list), packed QKV, and KV cache updates for iterative decoding.
- Broad hardware support: CUDA on Ampere/Ada/Hopper; ROCm 6+ via Composable Kernel and Triton backends; guidance for Docker-based builds.
- Model-ready extras: Sliding-window attention, ALiBi, rotary embeddings, softcapping, paged KV cache (PagedAttention), and compatibility with torch.compile.
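Here is a minimal sketch of the variable-length path referenced above, using flash_attn_varlen_func on two unpadded sequences packed back to back. The shapes follow the documented (total_tokens, nheads, headdim) convention, while the concrete sizes are made up for illustration.

import torch
from flash_attn import flash_attn_varlen_func

# Two sequences of lengths 3 and 5, concatenated with no padding.
# Tensors are (total_tokens, nheads, headdim), fp16/bf16, on a CUDA device.
nheads, headdim = 8, 64
seqlens = [3, 5]
q = torch.randn(sum(seqlens), nheads, headdim, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Cumulative sequence lengths (int32): [0, 3, 8] marks each sequence's boundaries.
cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max(seqlens), max_seqlen_k=max(seqlens),
    causal=True,
)

Packing sequences this way avoids wasting compute and memory on padding tokens, which matters when batch items have very different lengths.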
What problem does it solve?
Standard self-attention computes a matrix of size seq_len x seq_len and applies softmax, making both time and memory scale quadratically with sequence length. At long context (even a few thousand tokens), the intermediate attention matrix becomes the bottleneck: memory use blows up, and data movement between HBM and on-chip SRAM dominates wall-clock time. Prior work attempted approximations to reduce complexity, but many did not translate into real end-to-end speedups due to IO overheads.
The idea behind FlashAttention
FlashAttention is IO-aware: it tiles Q, K, V into blocks, streams them through on-chip SRAM, and fuses operations so that large intermediate matrices are never materialized in global memory. This cuts HBM traffic, turning memory complexity from quadratic to effectively linear in sequence length, while remaining exact (no quality compromise).
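To make that concrete, here is a deliberately simplified, pure-PyTorch sketch of the tiled online-softmax accumulation for a single head. It is a readability aid under assumed block sizes, not the fused CUDA kernel itself, but it shows how partial results are rescaled so the full score matrix never needs to exist at once.

import torch

def tiled_attention(q, k, v, block=128):
    # q, k, v: (seqlen, d) for a single head; float32 for clarity.
    seqlen, d = q.shape
    scale = d ** -0.5
    m = torch.full((seqlen,), float("-inf"))   # running row-wise max
    l = torch.zeros(seqlen)                    # running softmax normalizer
    acc = torch.zeros(seqlen, d)               # running output accumulator
    for start in range(0, seqlen, block):
        kb, vb = k[start:start + block], v[start:start + block]  # one K/V tile
        s = (q @ kb.T) * scale                 # scores against this tile only
        m_new = torch.maximum(m, s.max(dim=-1).values)
        p = torch.exp(s - m_new[:, None])      # tile-local softmax numerator
        correction = torch.exp(m - m_new)      # rescale previously accumulated results
        l = l * correction + p.sum(dim=-1)
        acc = acc * correction[:, None] + p @ vb
        m = m_new
    return acc / l[:, None]

# Matches the naive softmax(QK^T / sqrt(d)) V computation:
q, k, v = (torch.randn(1024, 64) for _ in range(3))
ref = torch.softmax(q @ k.T / 64 ** 0.5, dim=-1) @ v
assert torch.allclose(tiled_attention(q, k, v), ref, atol=1e-4)

In the real kernels these tiles live in SRAM and the loop is fused with the softmax and output writes, which is precisely what cuts HBM traffic.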
Reported gains include 2-4x wall-clock speedups and substantial memory savings: roughly 10x at 2K tokens and 20x at 4K tokens, per the repository's performance notes. The v2 rewrite improves parallelism and work partitioning across thread blocks and warps; the v3 beta targets Hopper (H100/H800) and leverages asynchronous tensor-core instructions and FP8 for even higher throughput (Shah et al., 2024).
A short background on self-attention
Self-attention was popularized by the Transformer architecture in "Attention Is All You Need" (Vaswani et al., 2017). At its core is scaled dot-product attention, which computes weights by comparing each query vector to all keys, then uses those weights to mix the values: softmax(QK^T / sqrt(d)) * V. This lets each token selectively attend to other tokens, enabling long-range dependencies without recurrence.
In practice, Transformers use multi-head attention (MHA), which splits the model's channels into multiple "heads" so different subspaces can attend to different patterns. Masks adjust which tokens can be seen: a causal mask enforces left-to-right dependencies for language modeling, while padding masks ignore non-content positions. Frameworks expose this primitive directly, for example PyTorch's scaled_dot_product_attention, and libraries such as xFormers provide optimized kernels.
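For comparison, the same primitive through PyTorch's built-in API looks like this; is_causal=True applies the left-to-right mask described above, and depending on hardware and dtype PyTorch may dispatch to a FlashAttention-backed kernel.

import torch
import torch.nn.functional as F

# PyTorch's SDPA expects (batch, nheads, seqlen, headdim).
q = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # causal mask

Note that flash_attn's own entry points use a (batch, seqlen, nheads, headdim) layout instead, as shown later in this post.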
The catch is complexity: computing and storing the n-by-n attention matrix for sequence length n yields quadratic time and memory. On GPUs, the limiting factor is often memory bandwidth rather than FLOPs. Materializing the attention matrix amplifies traffic between HBM (global memory) and on-chip SRAM (shared memory/registers), hurting real-world throughput. Many approaches reduce complexity via approximation (e.g., kernel methods or low-rank projections), but FlashAttention takes a different path: keep attention exact and reorganize computation to minimize IO.
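To put the quadratic memory cost in perspective, here is a quick back-of-the-envelope calculation (illustrative numbers, not a benchmark):

# fp16 attention scores of shape (seqlen, seqlen), per head and per batch element.
seqlen, bytes_per_elem = 8192, 2
per_head = seqlen * seqlen * bytes_per_elem
print(per_head / 2**20, "MiB")              # 128.0 MiB for a single head
# With 32 heads and batch size 8, the intermediate alone reaches 32 GiB.
print(per_head * 32 * 8 / 2**30, "GiB")     # 32.0 GiB

All of that would have to be written to and read back from HBM if the score matrix were materialized, which is exactly the traffic FlashAttention avoids.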
Under the hood
Technically, FlashAttention is an algorithm-kernel co-design. It tiles the attention computation and recomputes softmax statistics per tile to avoid writing the full attention matrix to HBM. The implementation mixes Python (bindings and module wiring), C++/CUDA (high-performance kernels), and, for AMD, ROCm CK/Triton backends.
The repository structure highlights this scope: csrc/ for CUDA sources, flash_attn/ for Python modules and Triton code, training/ and examples/ for end-to-end usage, and tests/ for numerical parity checks. In v2, improved work partitioning yields better GPU utilization; in the v3 beta, Hopper-specific features such as WGMMA and TMA enable overlapping GEMM and softmax for higher throughput, while FP8 with Hadamard-based incoherent processing reduces quantization error at low precision (Shah et al., 2024). In code, the public API keeps this machinery behind a few functions:
from flash_attn import flash_attn_func, flash_attn_qkvpacked_func
# Q, K, V shapes: (batch, seqlen, nheads, headdim)
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
# If Q, K, V are packed: (batch, seqlen, 3, nheads, headdim)
out = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=False)
# Inference with KV cache and rotary embeddings
from flash_attn import flash_attn_with_kvcache
out = flash_attn_with_kvcache(q, k_cache, v_cache, k=k_new, v=v_new,
                              rotary_cos=cos, rotary_sin=sin, causal=True)
For a concrete module integration, see flash_attn/modules/mha.py. The README's performance plots cover A100 and H100, showing combined forward+backward speedups over PyTorch standard attention and the memory footprint advantages. Benchmarks assume typical head dims (64/128) and long sequences (up to 16K). The project also includes an optimized GPT model and training scripts that avoid activation checkpointing yet reach high MFU, demonstrating system-level wins beyond a single kernel.
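As a rough picture of what such an integration looks like, here is a minimal, hypothetical self-attention module built on flash_attn_qkvpacked_func. It is a simplified stand-in, not a copy of the repository's MHA class, and the layer names are invented for illustration.

import torch
import torch.nn as nn
from flash_attn import flash_attn_qkvpacked_func

class TinyFlashSelfAttention(nn.Module):
    """Simplified causal self-attention on top of flash_attn (illustrative only)."""

    def __init__(self, dim: int, nheads: int):
        super().__init__()
        assert dim % nheads == 0
        self.nheads = nheads
        self.qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.proj = nn.Linear(dim, dim, bias=False)

    def forward(self, x):  # x: (batch, seqlen, dim), fp16/bf16 on a CUDA device
        b, s, d = x.shape
        qkv = self.qkv(x).view(b, s, 3, self.nheads, d // self.nheads)
        out = flash_attn_qkvpacked_func(qkv, causal=True)   # (b, s, nheads, headdim)
        return self.proj(out.reshape(b, s, d))

# Usage: module weights and inputs must share the same half-precision dtype on GPU.
x = torch.randn(2, 1024, 512, device="cuda", dtype=torch.float16)
layer = TinyFlashSelfAttention(dim=512, nheads=8).to(device="cuda", dtype=torch.float16)
y = layer(x)   # (2, 1024, 512)

The repository's flash_attn/modules/mha.py adds the production details this sketch omits: rotary embeddings, KV caching, MQA/GQA head layouts, and cross-attention.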
Use cases
FlashAttention is widely integrated across the ecosystem: PyTorch's core scaled_dot_product_attention, Hugging Face Transformers, DeepSpeed, Megatron-LM, MosaicML Composer, PaddlePaddle, and more, as tracked in usage.md. Real-world wins include faster BERT training in MLPerf (2022), big throughput boosts in diffusion pipelines (Hugging Face diffusers), and faster or longer-context inference paths in engines like NVIDIA FasterTransformer and Meta AITemplate.
Practically, teams adopt it to: train LLMs with longer context windows without quadratic memory blowup; accelerate decoder-only inference via fused KV-cache updates; and scale multi-GPU training without resorting to heavy activation checkpointing.
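As one simplified example of the decoding pattern, the sketch below feeds one new token per step into a preallocated KV cache via flash_attn_with_kvcache; the cache sizes, shapes, and dummy tensors are illustrative assumptions rather than a recipe from the repository.

import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim, max_len = 2, 8, 64, 4096
dev, dt = "cuda", torch.float16

# Preallocated KV cache: (batch, max_seqlen, nheads, headdim).
k_cache = torch.zeros(batch, max_len, nheads, headdim, device=dev, dtype=dt)
v_cache = torch.zeros_like(k_cache)
cache_seqlens = torch.zeros(batch, dtype=torch.int32, device=dev)  # tokens cached so far

for step in range(16):
    # One new token per sequence: (batch, 1, nheads, headdim).
    q = torch.randn(batch, 1, nheads, headdim, device=dev, dtype=dt)
    k_new, v_new = torch.randn_like(q), torch.randn_like(q)
    # Writes k_new/v_new into the cache at cache_seqlens, then attends over the prefix.
    out = flash_attn_with_kvcache(q, k_cache, v_cache, k=k_new, v=v_new,
                                  cache_seqlens=cache_seqlens, causal=True)
    cache_seqlens += 1  # advance each sequence's cached length

The fused cache update avoids a separate concatenation and copy per step, which is where much of the decode-time benefit comes from.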
Community and contribution
The repository is active, with 100+ contributors and regular releases. Issues and discussions reflect a healthy mix of research, performance tuning, and platform support (CUDA versions, ROCm backends, Triton). Contributions have come from the broader community for features like ALiBi, deterministic backward, paged KV cache, and softcapping.
If you encounter problems, the maintainers encourage opening issues; tests verify parity with PyTorch baselines within tight numerical tolerances, which helps reviewers quickly validate changes. The changelog documents semantic shifts (e.g., causal mask behavior in v2.1) so downstream users can upgrade safely.
Usage and license terms
Installation supports pip (pip install flash-attn --no-build-isolation) or building from source with ninja for fast parallel compilation. CUDA 12+ is recommended (Ampere/Ada/Hopper), with guidance for ROCm 6+ and Docker images on AMD. Windows build support has improved since v2.3.2 but still requires care. Licensing is BSD-3-Clause, a permissive license that allows reuse and modification with minimal obligations: retain copyright and license notices in source and binary redistributions; do not use the authors' names for endorsement without permission; and accept the software "as is" without warranty. See LICENSE and AUTHORS for details.
Impact and what is next
FlashAttention has influenced how the community thinks about attention: IO-awareness, fusion, and SRAM residency now feel standard in high-performance implementations. The project helped push context lengths from 2-4K to 128K and beyond in production LLMs by removing practical bottlenecks, not by approximation.
FlashAttention-3 shows that algorithm-hardware co-design can keep unlocking throughput on new GPUs: overlapping GEMMs with softmax and adopting FP8 with careful error control achieves up to ~75 percent of H100's theoretical peak in FP16 and near 1.2 PFLOPS in FP8 (Shah et al., 2024).
Looking forward, expect deeper integration with frameworks, better Windows support, and generalizations to other accelerators. Related kernels in the Dao AI Lab ecosystem (e.g., causal-conv1d, fast-hadamard-transform) hint at a growing suite of GPU-native building blocks.
About Dao AI Lab
Dao AI Lab is an AI research group led by Prof. Tri Dao, with a track record of impactful open-source releases in GPU systems and deep learning. The group's codebases emphasize clarity, performance, and practical adoption across industry stacks.
Learn more at the lab's GitHub org Dao-AILab and Tri Dao's site tridao.me. Publications for FlashAttention-2 and -3 are available on Tri's site and arXiv, with accessible write-ups that walk through kernel design choices.
Conclusion
FlashAttention turned a core Transformer bottleneck into a GPU-first, IO-aware kernel that the ecosystem could adopt quickly. If you are training or serving LLMs, diffusion models, or long-context transformers, this is a practical speed and memory win with battle-tested kernels and growing hardware coverage. Explore the repo, read the papers and blog posts, and consider contributing improvements or platform support.
Explore the repository and join the community:
https://github.com/Dao-AILab/flash-attention