Introduction
FlashAttention revolutionized transformer training and inference by reducing attention memory complexity from O(N²) to O(N) while maintaining exact computation. FlashAttention-2 pushed this further with improved parallelism and warp partitioning. Now, FlashAttention-3 brings optimizations specifically designed for NVIDIA's Hopper architecture, achieving 75% FLOP utilization on H100 GPUs, a 1.5-2x improvement over FlashAttention-2.
This article explores the three key innovations behind FlashAttention-3: warp specialization, interleaved computation, and low-precision execution.
The Evolution of FlashAttention
FlashAttention (V1): Tiling and Recomputation
The original FlashAttention introduced two key ideas:
- Tiling: Process attention in small blocks that fit in SRAM
- Recomputation: Recalculate gradients during backward pass instead of storing all attention matrices
# Simplified FlashAttention V1 algorithm (PyTorch sketch of the tiled loop;
# the real kernel fuses all of this into a single SRAM-resident pass)
def flash_attention_v1(Q, K, V, softmax_scale, BLOCK_M=128, BLOCK_N=128):
    """
    Standard FlashAttention with tiling and online softmax.
    Shapes: [batch, heads, seqlen, head_dim].
    """
    B, H, N_CTX, D = Q.shape
    O = torch.zeros_like(Q)
    # Running softmax statistics: row max m and normalizer l
    l = torch.zeros((B, H, N_CTX, 1), device=Q.device)
    m = torch.full((B, H, N_CTX, 1), -float('inf'), device=Q.device)
    # Process in tiles (blocks along the sequence dimension)
    for j in range(0, N_CTX, BLOCK_M):
        q_block = Q[:, :, j:j+BLOCK_M]
        for i in range(0, N_CTX, BLOCK_N):
            k_block = K[:, :, i:i+BLOCK_N]
            v_block = V[:, :, i:i+BLOCK_N]
            # Attention scores for this tile
            qk = torch.matmul(q_block, k_block.transpose(-2, -1)) * softmax_scale
            # Online softmax: take the new running max, rescale the old state
            m_prev = m[:, :, j:j+BLOCK_M]
            m_new = torch.maximum(m_prev, qk.max(-1, keepdim=True).values)
            p = torch.exp(qk - m_new)
            scale = torch.exp(m_prev - m_new)
            O[:, :, j:j+BLOCK_M] = O[:, :, j:j+BLOCK_M] * scale + p @ v_block
            l[:, :, j:j+BLOCK_M] = l[:, :, j:j+BLOCK_M] * scale + p.sum(-1, keepdim=True)
            m[:, :, j:j+BLOCK_M] = m_new
    # Final normalization by the softmax denominator
    return O / l
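The recomputation idea can be seen in miniature with plain Python: the forward pass keeps only the per-row statistics (m, l), and the backward pass rebuilds the softmax probabilities exactly from the scores plus those statistics, instead of storing the full N×N matrix. A toy-sized sketch:

```python
import math

# Forward pass: store only the row max m and the normalizer l,
# not the softmax probabilities themselves.
scores = [1.0, 2.0, 3.0]                      # one attention row (toy size)
m = max(scores)
l = sum(math.exp(s - m) for s in scores)
# Backward pass: rebuild the softmax row exactly from (scores, m, l)
p = [math.exp(s - m) / l for s in scores]
print(sum(p))  # rows of a softmax sum to 1
```

The trade is a modest amount of extra FLOPs in the backward pass for an O(N²) reduction in stored activations, which is favorable because attention is memory-bound.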
FlashAttention-2: Improved Parallelism
FlashAttention-2 made several improvements:
- Better thread block partitioning across sequence length
- Reorganized loops for improved parallelism
- Optimized warp-level reductions
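The effect of the improved partitioning can be seen with back-of-the-envelope grid arithmetic (the concrete numbers below are illustrative): FlashAttention-1 launched one thread block per (batch, head) pair, while FlashAttention-2 also splits the query dimension across thread blocks, which matters for small batches on a GPU with over a hundred SMs.

```python
# Illustrative grid sizes: FA1 parallelized over (batch, head) only;
# FA2 also splits the query blocks across thread blocks.
batch, heads, seqlen, BLOCK_M = 2, 16, 8192, 128
fa1_blocks = batch * heads                        # 32 thread blocks
fa2_blocks = batch * heads * (seqlen // BLOCK_M)  # 2048 thread blocks
print(fa1_blocks, fa2_blocks)
```

With only 32 blocks, most of an H100's 132 SMs would sit idle; 2048 blocks keep the whole GPU occupied.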
FlashAttention-3: Hopper-Specific Optimizations
Architecture Background: NVIDIA Hopper
Hopper GPU architecture introduced key features that FlashAttention-3 exploits:
| Feature | Description | Benefit for Attention |
|---|---|---|
| Tensor Core Matrix Multiply | Specialized matrix ops | 4x throughput for FP16/BF16 matmul |
| TMA (Tensor Memory Accelerator) | Async memory transfers | Overlap computation with data movement |
| Thread Block Clusters | Cooperative groups | Better warp coordination |
| WGMMA (Warp Group MMA) | Warp-level matrix ops | Fine-grained parallelism |
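The benefit TMA provides is classic software pipelining: issue the (asynchronous) load of tile i+1 before computing on tile i, so memory latency hides behind compute. A minimal serial sketch of the scheduling idea (the `load`/`compute` callables are placeholders; on hardware the load returns immediately):

```python
# Software pipelining, the idea behind TMA overlap: prefetch the next tile
# before computing on the current one (the load is async on hardware).
def pipelined(tiles, load, compute):
    buf = load(tiles[0])
    out = []
    for i in range(len(tiles)):
        nxt = load(tiles[i + 1]) if i + 1 < len(tiles) else None  # prefetch
        out.append(compute(buf))   # overlaps with the prefetch on hardware
        buf = nxt
    return out

print(pipelined([1, 2, 3], load=lambda t: t * 10, compute=lambda b: b + 1))
```

In the real kernel the "buffers" are double-buffered shared-memory tiles and the handoff is synchronized with async-copy barriers rather than Python control flow.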
Innovation 1: Warp Specialization
Traditional GPU kernels suffer from idle resources when:
- Warps wait for memory loads
- Computation cannot overlap with data transfer
Warp specialization addresses this by dedicating specific warps to specific tasks:
# Conceptual warp specialization in FlashAttention-3
# (pseudocode: the real kernel is CUDA C++/CUTLASS, and the helper
#  functions below are illustrative, not an actual Python API)
class WarpSpecializedAttention:
    """
    FlashAttention-3 assigns warp groups to different tasks: producer
    warps issue async loads while consumer warps compute.
    """
    WARP_GROUPS = {
        'LOAD_KV': [0, 1],      # producer warps: TMA-load K, V into shared memory
        'COMPUTE_QK': [2, 3],   # consumer warps: Q @ K^T via WGMMA
        'SOFTMAX': [4],         # online softmax on finished QK^T tiles
        'COMPUTE_PV': [5, 6],   # P @ V via WGMMA
        'STORE_RESULT': [7],    # write O back to global memory
    }

    def forward_kernel(Q, K, V, O, B, H, N, D):
        # Every warp runs the same kernel but branches on its role;
        # groups synchronize through shared-memory barriers and overlap work.
        warp_id = get_warp_id()              # illustrative helper
        if warp_id in WARP_GROUPS['LOAD_KV']:
            tma_async_load_kv(K, V)          # async copy, overlapped with compute
        elif warp_id in WARP_GROUPS['COMPUTE_QK']:
            qk_tile = wgmma(Q_tile, K_tile)  # tensor-core matmul on shared tiles
        # ... remaining warp groups handle softmax, P @ V, and the store
Innovation 2: Interleaved Computation
Instead of computing entire QK^T then softmax then PV sequentially, FlashAttention-3 interleaves operations:
def interleaved_attention_fused(Q, K, V, BLOCK_N=64):
    """
    Interleaved matmul and softmax computation (serial sketch).
    Key innovation: on Hopper, the softmax of one tile overlaps with the
    WGMMA of the next tile; this loop shows only the arithmetic of the
    online update. Shapes: [batch, heads, seqlen, head_dim].
    """
    B, H, N, D = Q.shape
    O = torch.zeros_like(Q)
    # Online softmax statistics
    m_i = torch.full((B, H, N, 1), -float('inf'), device=Q.device)  # max so far
    l_i = torch.zeros((B, H, N, 1), device=Q.device)                # sum of exp
    for i in range(0, N, BLOCK_N):
        # In the kernel these loads are async TMA prefetches
        k_blk = K[:, :, i:i+BLOCK_N]
        v_blk = V[:, :, i:i+BLOCK_N]
        # Partial QK^T for this K block
        qk = Q @ k_blk.transpose(-2, -1)
        # Begin softmax on this tile (overlapped with the next matmul on HW)
        m_new = torch.maximum(m_i, qk.max(-1, keepdim=True).values)
        scale_prev = torch.exp(m_i - m_new)    # rescales the previous output
        p_curr = torch.exp(qk - m_new)         # current tile's exp values
        # Accumulate into the running output and statistics
        O = O * scale_prev + p_curr @ v_blk
        l_i = l_i * scale_prev + p_curr.sum(-1, keepdim=True)
        m_i = m_new
    # Final scaling by the softmax denominator
    return O / l_i
Innovation 3: FP8 Low-Precision Computation
Hopper supports FP8 tensor cores with different formats:
# FP8 format options in Hopper
class FP8Formats:
    """
    Hopper tensor cores support two FP8 formats
    """
    E4M3 = "e4m3"  # 4 exponent bits, 3 mantissa bits: more precision
    E5M2 = "e5m2"  # 5 exponent bits, 2 mantissa bits: wider range
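The trade-off between the two formats falls directly out of their bit layouts. E5M2 follows IEEE conventions (bias 15, top exponent reserved for inf/NaN), while the E4M3 "fn" variant used on Hopper reclaims its top exponent, keeping only one NaN encoding:

```python
# Largest finite values of the two FP8 formats, from their bit layouts.
# E5M2: max = 1.11_2 * 2^15 (top exponent reserved for inf/NaN).
# E4M3 (fn): max = 1.110_2 * 2^8 (only mantissa 111 at the top exponent is NaN).
e5m2_max = (1 + 3 / 4) * 2 ** 15   # 57344.0
e4m3_max = (1 + 6 / 8) * 2 ** 8    # 448.0
print(e4m3_max, e5m2_max)
```

So E4M3 spends its bits on precision (8 mantissa steps per binade, range ±448), which suits forward activations, while E5M2's wider ±57344 range suits quantities with large dynamic range such as gradients.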
import torch
import torch.nn.functional as F

def flash_attention_fp8(Q, K, V, scale):
    """
    FlashAttention-3 with FP8 computation (sketch). Inputs use E4M3;
    the tensor cores accumulate the matmuls in FP32, and softmax runs
    in BF16. matmul_fp8 stands in for a scaled FP8 GEMM.
    """
    # Convert to FP8 (requires PyTorch >= 2.1 for float8 dtypes)
    Q_fp8 = Q.to(dtype=torch.float8_e4m3fn)
    K_fp8 = K.to(dtype=torch.float8_e4m3fn)
    V_fp8 = V.to(dtype=torch.float8_e4m3fn)
    # FP8 matmul for QK^T: tensor cores give double the FP16 throughput
    QK = matmul_fp8(Q_fp8, K_fp8.transpose(-2, -1))
    # Softmax in BF16 (the exp and normalization need more precision)
    S = F.softmax(QK.to(torch.bfloat16) * scale, dim=-1)
    # Convert back to FP8 for the attention-scores-times-V matmul
    S_fp8 = S.to(torch.float8_e4m3fn)
    O_fp8 = matmul_fp8(S_fp8, V_fp8)
    # Final output in BF16
    return O_fp8.to(torch.bfloat16)
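Casting to FP8 only works if values fit the format's narrow range, so a scale factor is applied before the cast and divided out afterward. A minimal per-tensor version is sketched below; FlashAttention-3 itself uses finer-grained block scaling plus incoherent processing to tame outliers, so treat this as the idea rather than the implementation:

```python
# Per-tensor scaling before an FP8 cast: choose a scale so the tensor's
# absolute maximum lands at the top of the E4M3 range.
E4M3_MAX = 448.0

def fp8_scale(abs_max):
    """Multiplier applied before the FP8 cast (its inverse after)."""
    return E4M3_MAX / abs_max

scale = fp8_scale(abs_max=14.0)   # e.g. max |Q| observed is 14
print(scale)                      # 32.0: multiply before cast, divide after
```

Without such scaling, small attention logits would collapse to zero and large ones would saturate, which is why FP8 attention needs calibrated scales per tensor (or per block) rather than a raw dtype cast.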
Performance Analysis
FLOP Utilization
FlashAttention-3 achieves remarkable GPU utilization:
| Version | H100 FLOP Utilization | Relative Speed |
|---|---|---|
| PyTorch Standard | 12% | 1x |
| FlashAttention | 30% | 2.5x |
| FlashAttention-2 | 50% | 4x |
| FlashAttention-3 | 75% | 6x |
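These percentages can be sanity-checked against absolute throughput, assuming H100 SXM's dense BF16 peak of roughly 989 TFLOPS (a commonly quoted figure; the exact peak depends on the SKU and clocks):

```python
# Back-of-the-envelope: utilization percentage times the assumed peak.
PEAK_TFLOPS = 989  # approximate H100 SXM dense BF16 peak
for name, util in [('FlashAttention-2', 0.50), ('FlashAttention-3', 0.75)]:
    print(f'{name}: ~{PEAK_TFLOPS * util:.0f} TFLOPS')
```

75% utilization works out to roughly 740 TFLOPS, consistent with the headline FP16 numbers reported for FlashAttention-3.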
Memory Bandwidth
# Memory access comparison
memory_access = {
    'Standard Attention': 'O(N² × d) reads/writes',
    'FlashAttention': 'O(N × d) reads, O(N × d) writes',
    'FlashAttention-3': 'O(N × d) reads, O(N × d) writes + async prefetch'
}
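Plugging in concrete numbers makes the gap vivid. The counts below are idealized (they match the O-notation above, ignoring constants like the extra passes a real kernel makes), for a single attention head in fp16:

```python
# Idealized HBM traffic for one head: N = 16384 tokens, d = 128, fp16.
N, d, B = 16384, 128, 2             # B = bytes per element
standard = 2 * N * N * B            # write + re-read the N x N score matrix
flash = 4 * N * d * B               # Q, K, V read once + O written once
print(standard / 1e9, flash / 1e6)  # ~1.07 GB vs ~16.8 MB per head
```

The score matrix alone dominates traffic by about two orders of magnitude at this length, which is why attention is memory-bound and why tiling into SRAM pays off so dramatically.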
Latency Improvements
| Sequence Length | FA2 Latency | FA3 Latency | Improvement |
|---|---|---|---|
| 1K tokens | 2.1ms | 1.4ms | 1.5x |
| 4K tokens | 8.5ms | 5.2ms | 1.6x |
| 16K tokens | 45ms | 26ms | 1.7x |
| 64K tokens | 280ms | 150ms | 1.9x |
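Recomputing the ratios from the table shows the speedup growing with sequence length, which fits the design: longer sequences give the async prefetch and warp specialization more work to overlap.

```python
# Speedup ratios implied by the latency table above (ms).
fa2 = {'1K': 2.1, '4K': 8.5, '16K': 45.0, '64K': 280.0}
fa3 = {'1K': 1.4, '4K': 5.2, '16K': 26.0, '64K': 150.0}
ratios = {k: round(fa2[k] / fa3[k], 1) for k in fa2}
print(ratios)
```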
Implementation
Using FlashAttention-3
# Install with pip
# pip install flash-attn --prefer-binary
import torch
from flash_attn import flash_attn_func

def example_usage():
    # Input tensors: [batch, seqlen, nheads, headdim]
    q = torch.randn(2, 512, 16, 64, dtype=torch.float16, device='cuda')
    k = torch.randn(2, 512, 16, 64, dtype=torch.float16, device='cuda')
    v = torch.randn(2, 512, 16, 64, dtype=torch.float16, device='cuda')
    # The Hopper-specific optimizations are applied automatically
    output = flash_attn_func(
        q, k, v,
        softmax_scale=None,  # defaults to 1/sqrt(headdim)
        causal=True,         # lower-triangular (causal) mask
    )
    return output
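For reference, `softmax_scale=None` resolves to the standard attention temperature of 1/sqrt(headdim):

```python
import math

# The default that softmax_scale=None resolves to in flash_attn.
headdim = 64
softmax_scale = 1.0 / math.sqrt(headdim)
print(softmax_scale)  # 0.125
```

Pass an explicit value only if your model deviates from this convention (e.g. query-scaled variants).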
# Check the GPU's compute capability (FP8 requires 9.0, i.e. Hopper)
def check_fp8_support():
    import subprocess
    result = subprocess.run(
        ['nvidia-smi', '--query-gpu=name,compute_cap', '--format=csv,noheader'],
        capture_output=True, text=True
    )
    print(result.stdout)
Custom Kernel with FlashAttention-3 Principles
# Structural pseudocode of an FA3-style fused kernel. The real kernel is
# CUDA C++/CUTLASS; this sketch only shows how the pieces fit together.
def flash_attention_fused_kernel(
    Q, K, V, O,
    softmax_scale,
    cu_seqlens_q,
    cu_seqlens_kv
):
    """
    Simplified FA3-style fused kernel
    """
    # Shared-memory budget per thread block
    Q_SZ = BLOCK_M * D         # one Q tile
    KV_SZ = BLOCK_N * D * 2    # double-buffered K and V tiles
    # Grid mapping: one thread block per (batch, head, sequence chunk)
    batch_id = blockIdx.x
    head_id = blockIdx.y
    chunk_id = blockIdx.z
    # Warp-specialized, interleaved main loop:
    # 1. Producer warps TMA-load the next K, V tiles asynchronously
    # 2. Consumer warps compute partial QK^T while the load is in flight
    # 3. Online softmax on the finished tile
    # 4. P @ V accumulation, then repeat with the next K, V tiles
Advanced Optimizations
Flash-Decoding Integration
def flash_decoding_with_fa3(
    query: torch.Tensor,        # [batch, heads, 1, head_dim] decode query
    key_cache: torch.Tensor,    # [batch, heads, seqlen, head_dim]
    value_cache: torch.Tensor,
    softmax_scale: float,
    CHUNK: int = 1024,
):
    """
    Combine FlashAttention-3 with Flash-Decoding for long contexts (sketch).
    Flash-Decoding splits the KV cache along the sequence dimension so that
    chunks can be processed by different thread blocks in parallel, then
    merges the partial outputs using each chunk's log-sum-exp statistic.
    """
    N = key_cache.shape[2]
    partial_out, partial_lse = [], []
    for i in range(0, N, CHUNK):   # chunks run in parallel on the GPU
        k = key_cache[:, :, i:i+CHUNK]
        v = value_cache[:, :, i:i+CHUNK]
        s = (query @ k.transpose(-2, -1)) * softmax_scale
        lse = torch.logsumexp(s, dim=-1, keepdim=True)
        partial_out.append(torch.exp(s - lse) @ v)   # chunk attention output
        partial_lse.append(lse)
    # Merge: weight each chunk by exp(lse_chunk - lse_total)
    lse_total = torch.logsumexp(torch.cat(partial_lse, dim=-1), dim=-1, keepdim=True)
    return sum(o * torch.exp(l - lse_total)
               for o, l in zip(partial_out, partial_lse))
FP8 Training
def fa3_training_step(model, batch, optimizer):
    """
    Training step with FlashAttention-3 (sketch; the model and layer
    methods are illustrative). PyTorch autocast does not accept FP8
    dtypes, so mixed precision runs in BF16 here, while the FA3 kernel
    can use FP8 tensor cores internally for the attention matmuls.
    """
    with torch.autocast('cuda', dtype=torch.bfloat16):
        hidden_states = model.embed(batch['input_ids'])
        for layer in model.layers:
            q, k, v = layer.compute_qkv(hidden_states)
            # FA3 fused attention, causal for language modeling
            attn_output = flash_attn_func(q, k, v, causal=True)
            hidden_states = layer.mlp(attn_output)
        loss = model.loss(hidden_states, batch['labels'])
    # Backward pass also uses FA3 (attention is recomputed, not stored)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
Comparison with Alternatives
| Feature | FlashAttention-3 | Standard Attention | xFormers |
|---|---|---|---|
| Memory | O(N) | O(N²) | O(N) |
| Precision | Exact | Exact | Exact |
| Hardware | H100 optimized | Universal | Universal |
| FLOP Utilization | 75% | 12% | 35% |
| FP8 Support | Native | No | No |
| Sequence Length | 64K+ | Limited by memory | Limited |
Best Practices
When to Use FlashAttention-3
- Training: Always recommended for transformer training
- Long Sequences: Critical for sequences > 4K tokens
- H100 Hardware: Maximizes hardware utilization
- Memory Constrained: Avoids materializing the O(N²) attention matrix, cutting activation memory dramatically for long sequences
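The memory point is easy to quantify. With illustrative numbers (one layer, 64K context, 32 heads, fp16), the materialized score matrix alone would exceed an H100's 80 GB of HBM:

```python
# The fp16 attention score matrix that standard attention would materialize
# for a single layer at 64K context (batch 1, 32 heads).
batch, heads, N = 1, 32, 65536
score_matrix_bytes = batch * heads * N * N * 2
print(f'{score_matrix_bytes / 1e9:.0f} GB')  # 275 GB
```

FlashAttention never stores this matrix, keeping per-layer attention memory at O(N × d) instead.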
Configuration Tips
# Optimal settings for different scenarios
configs = {
'short_context': {
'block_size': 128,
'softmax_scale': None, # auto
},
'long_context': {
'block_size': 64, # smaller for more granularity
'return_softmax': False,
},
'memory_efficient': {
'dropout_p': 0.0, # No dropout for memory
'alibi_slopes': None, # No ALiBi if not needed
}
}
Conclusion
FlashAttention-3 represents the culmination of years of optimization work, specifically tailored for modern GPU architectures. By combining:
- Warp specialization for overlapped computation
- Interleaved operations for reduced latency
- FP8 precision for 4x throughput
It achieves 75% FLOP utilization, nearly reaching the theoretical maximum for attention computation. As GPU architectures continue to evolve, these principles will shape the future of transformer optimization.
The technique is now widely adopted in training frameworks (PyTorch, Megatron) and inference engines (vLLM, TensorRT-LLM), making modern large language models practical to train and deploy.
Resources
- FlashAttention-3 Paper
- FlashAttention Official GitHub
- NVIDIA Hopper Architecture Whitepaper
- xFormers Library