
FlashAttention-3: Next-Generation Transformer Optimization

Introduction

FlashAttention revolutionized transformer training and inference by reducing attention memory complexity from O(N²) to O(N) in the sequence length N while maintaining exact computation. FlashAttention-2 pushed this further with improved parallelism and warp partitioning. FlashAttention-3 adds optimizations designed specifically for NVIDIA's Hopper architecture, reaching 75% FLOP utilization on H100 GPUs, a 1.5-2x improvement over FlashAttention-2.

This article explores the three key innovations behind FlashAttention-3: warp specialization, interleaved computation, and low-precision execution.

The Evolution of FlashAttention

FlashAttention (V1): Tiling and Recomputation

The original FlashAttention introduced two key ideas:

  1. Tiling: Process attention in small blocks that fit in on-chip SRAM
  2. Recomputation: Recompute attention blocks during the backward pass instead of storing the full attention matrix

# Simplified FlashAttention V1 algorithm (reference sketch in plain PyTorch)
import torch

def flash_attention_v1(Q, K, V, softmax_scale, BLOCK_M=64, BLOCK_N=64):
    """
    Standard FlashAttention with tiling and online softmax.
    Q, K, V: [batch, heads, seqlen, head_dim]
    """
    BATCH, H, N_CTX, D = Q.shape

    # Initialize output and per-row softmax statistics (accumulated in fp32)
    O = torch.zeros(Q.shape, dtype=torch.float32, device=Q.device)
    l = torch.zeros((BATCH, H, N_CTX), device=Q.device)                # running sum of exp
    m = torch.full((BATCH, H, N_CTX), -float('inf'), device=Q.device)  # running row max

    # Process in tiles (blocks along sequence dimension)
    for j in range(0, N_CTX, BLOCK_M):
        # Load Q block
        q_block = Q[:, :, j:j+BLOCK_M].float()

        for i in range(0, N_CTX, BLOCK_N):
            # Load K, V blocks
            k_block = K[:, :, i:i+BLOCK_N].float()
            v_block = V[:, :, i:i+BLOCK_N].float()

            # Compute attention scores for this tile
            qk = torch.matmul(q_block, k_block.transpose(-2, -1)) * softmax_scale

            # Online softmax: update the running max for these query rows
            m_old = m[:, :, j:j+BLOCK_M]
            m_new = torch.maximum(m_old, qk.max(-1).values)
            alpha = torch.exp(m_old - m_new)         # rescales results from earlier tiles
            p = torch.exp(qk - m_new.unsqueeze(-1))  # unnormalized probabilities

            # Scale and accumulate output and statistics
            O[:, :, j:j+BLOCK_M] = O[:, :, j:j+BLOCK_M] * alpha.unsqueeze(-1) + p @ v_block
            l[:, :, j:j+BLOCK_M] = l[:, :, j:j+BLOCK_M] * alpha + p.sum(-1)
            m[:, :, j:j+BLOCK_M] = m_new

    # Normalize once at the end
    O = O / l.unsqueeze(-1)
    return O.to(Q.dtype)
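
The online softmax update used in the loop above can be checked in isolation. The following is a minimal sketch in plain Python on a single row of scores; the helper name `online_softmax_stats` and the sample values are illustrative assumptions. It processes the row block by block and compares the running max and running sum against a one-shot computation.

```python
import math

def online_softmax_stats(scores, block_size):
    """Return (row max, sum of exp(score - max)) computed one block at a time."""
    m = -math.inf  # running max
    l = 0.0        # running sum of exponentials, expressed relative to m
    for start in range(0, len(scores), block_size):
        block = scores[start:start + block_size]
        m_new = max(m, max(block))
        # Rescale the old sum to the new max, then add this block's terms
        l = l * math.exp(m - m_new) + sum(math.exp(s - m_new) for s in block)
        m = m_new
    return m, l

scores = [0.5, 2.0, -1.0, 3.5, 0.0, 1.25, -2.0]
m, l = online_softmax_stats(scores, block_size=3)

# One-shot reference over the whole row
m_ref = max(scores)
l_ref = sum(math.exp(s - m_ref) for s in scores)
print(m == m_ref, abs(l - l_ref) < 1e-12)
```

Because the statistics match regardless of block size, the tiled kernel never needs the full row of scores in memory at once.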

FlashAttention-2: Improved Parallelism

FlashAttention-2 made several improvements:

  • Better thread block partitioning across sequence length
  • Reorganized loops for improved parallelism
  • Optimized warp-level reductions
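
To make the parallelism point concrete, the sketch below counts independent thread blocks under two partitioning schemes. It assumes, following the FlashAttention-2 paper, that the original kernel parallelized over batch and heads only, while FlashAttention-2 also parallelizes over blocks of query rows. The function name, the block size of 128, and the SM count of 132 (H100 SXM) are illustrative assumptions.

```python
import math

def num_thread_blocks(batch, heads, seqlen, block_m=None):
    """Independent thread blocks launched for one attention call.

    block_m=None models parallelism over batch and heads only;
    otherwise every block of block_m query rows gets its own thread block.
    """
    q_blocks = 1 if block_m is None else math.ceil(seqlen / block_m)
    return batch * heads * q_blocks

# Long-sequence workloads often run with a small batch
batch, heads, seqlen = 1, 16, 16384
num_sms = 132  # assumed SM count (H100 SXM)

v1_blocks = num_thread_blocks(batch, heads, seqlen)
v2_blocks = num_thread_blocks(batch, heads, seqlen, block_m=128)
print(v1_blocks, v2_blocks)              # 16 2048
print(v1_blocks < num_sms <= v2_blocks)  # True
```

With only 16 thread blocks, most SMs would sit idle; splitting along the sequence dimension yields enough work to occupy the whole GPU.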

FlashAttention-3: Hopper-Specific Optimizations

Architecture Background: NVIDIA Hopper

Hopper GPU architecture introduced key features that FlashAttention-3 exploits:

| Feature | Description | Benefit for Attention |
| --- | --- | --- |
| Tensor Core Matrix Multiply | Specialized matrix ops | 4x throughput for FP16/BF16 matmul |
| TMA (Tensor Memory Accelerator) | Async memory transfers | Overlap computation with data movement |
| Thread Block Clusters | Cooperative groups | Better warp coordination |
| WGMMA (Warp Group MMA) | Asynchronous warpgroup-level matrix ops | Fine-grained parallelism |

Innovation 1: Warp Specialization

Traditional GPU kernels suffer from idle resources when:

  • Warps wait for memory loads
  • Computation cannot overlap with data transfer

Warp specialization addresses this by dedicating specific warps to specific tasks:

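# Conceptual warp specialization in FlashAttention-3
# (illustrative pseudocode in Python syntax: the warp layout and the task
#  names below are a teaching sketch, not a real CUDA or Numba API)
class WarpSpecializedAttention:
    """
    FlashAttention-3 uses different warp groups for different tasks.
    In the real kernel the split is between producer warpgroups (TMA loads)
    and consumer warpgroups (WGMMA and softmax).
    """

    # Role -> warps dedicated to it
    WARP_GROUPS = {
        'LOAD_KV': (0, 1),        # Warps 0-1: Load K, V from global to shared
        'COMPUTE_QK': (2, 3),     # Warps 2-3: Compute Q @ K^T
        'SOFTMAX': (4,),          # Warp 4: Softmax computation
        'LOAD_V_OUTPUT': (5, 6),  # Warps 5-6: Load V, compute final output
        'STORE_RESULT': (7,),     # Warp 7: Store to global memory
    }

    @classmethod
    def role_of(cls, warp_id):
        """Return the task a given warp is dedicated to."""
        for role, warps in cls.WARP_GROUPS.items():
            if warp_id in warps:
                return role
        raise ValueError(f"no role assigned to warp {warp_id}")

    @classmethod
    def forward_kernel(cls, warp_id):
        # Each warp group runs its own code path; groups synchronize through
        # barriers, which lets their work overlap
        role = cls.role_of(warp_id)
        if role == 'LOAD_KV':
            # Producer: issue TMA async loads for the next K, V tiles
            return 'tma_async_load_kv'
        elif role == 'COMPUTE_QK':
            # Consumer: compute QK^T while the next K, V loads are in flight
            return 'compute_qk_matmul'
        # ... other warp groups follow the same pattern
        return role.lower()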
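
The producer/consumer handoff behind warp specialization can be modeled on the CPU. The sketch below is only an analogy under stated assumptions: Python threads stand in for warp groups, a bounded queue stands in for a multi-stage shared-memory buffer, and `run_pipeline` is a hypothetical name. The real kernel relies on TMA and hardware barriers instead.

```python
import queue
import threading

def run_pipeline(tiles, num_stages=2):
    """Producer/consumer split: one thread loads tiles, another computes on them."""
    buffer = queue.Queue(maxsize=num_stages)  # stands in for a shared-memory ring buffer
    results = []

    def producer():
        for tile in tiles:
            buffer.put(list(tile))  # "load": blocks while every stage is occupied
        buffer.put(None)            # sentinel: no more tiles

    def consumer():
        while True:
            tile = buffer.get()     # blocks until a tile has arrived
            if tile is None:
                break
            results.append(sum(x * x for x in tile))  # "compute" on the loaded tile

    threads = [threading.Thread(target=producer), threading.Thread(target=consumer)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results

tiles = [[1, 2], [3, 4], [5, 6]]
print(run_pipeline(tiles))  # [5, 25, 61]
```

The bounded buffer is the key design choice: the producer can run ahead by at most `num_stages` tiles, so loads overlap with compute without unbounded memory use.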

Innovation 2: Interleaved Computation

Instead of running QK^T, softmax, and PV strictly in sequence, FlashAttention-3 interleaves them so that the softmax of one block overlaps with the matrix multiplications of another:

import torch

def interleaved_attention_fused(Q, K, V, softmax_scale=None, BLOCK_N=64):
    """
    Block-by-block matmul and softmax for one Q tile.
    Q: [batch, heads, BLOCK_M, d]; K, V: [batch, heads, N, d]
    Key idea: softmax statistics are updated per K/V block, so on the GPU the
    softmax of one block can run while the matmuls of the next block execute.
    This PyTorch version runs sequentially and only shows the data flow.
    """
    N = K.shape[-2]
    if softmax_scale is None:
        softmax_scale = Q.shape[-1] ** -0.5
    Qf = Q.float()

    # Initialize online softmax statistics and the output accumulator
    m_i = torch.full(Q.shape[:-1], -float('inf'), device=Q.device)  # max so far
    l_i = torch.zeros(Q.shape[:-1], device=Q.device)                # sum of exp
    O = torch.zeros(Q.shape, dtype=torch.float32, device=Q.device)

    for i in range(0, N, BLOCK_N):
        # Next K,V block (prefetched asynchronously on real hardware)
        k_next = K[:, :, i:i+BLOCK_N].float()
        v_next = V[:, :, i:i+BLOCK_N].float()

        # Compute partial QK^T for the current Q tile
        qk_partial = (Qf @ k_next.transpose(-2, -1)) * softmax_scale

        # Update the running max incrementally
        m_new = torch.maximum(m_i, qk_partial.max(-1).values)

        # Scale factor for previously accumulated results
        scale_prev = torch.exp(m_i - m_new)

        # Unnormalized probabilities for the current block
        scale_curr = torch.exp(qk_partial - m_new.unsqueeze(-1))

        # Accumulate into output and update statistics
        O = O * scale_prev.unsqueeze(-1) + scale_curr @ v_next
        l_i = l_i * scale_prev + scale_curr.sum(-1)
        m_i = m_new

    # Final scaling
    return (O / l_i.unsqueeze(-1)).to(Q.dtype)
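
A toy timing model shows why the overlap pays off. It treats the matmul and softmax work of each block as two pipeline stages that run on different hardware units; the function names and the timings are illustrative assumptions, not measurements, and the model ignores the dependency of the PV matmul on the softmax result.

```python
def sequential_time(num_blocks, t_gemm, t_softmax):
    """Every block runs its matmuls and its softmax back to back."""
    return num_blocks * (t_gemm + t_softmax)

def overlapped_time(num_blocks, t_gemm, t_softmax):
    """Two-stage pipeline: softmax of block i runs during the matmuls of block i+1."""
    if num_blocks == 0:
        return 0.0
    return t_gemm + t_softmax + (num_blocks - 1) * max(t_gemm, t_softmax)

n, t_gemm, t_softmax = 32, 1.0, 0.5  # arbitrary time units
seq = sequential_time(n, t_gemm, t_softmax)
ovl = overlapped_time(n, t_gemm, t_softmax)
print(seq, ovl, round(seq / ovl, 2))  # 48.0 32.5 1.48
```

In this model the slower stage sets the steady-state rate, so hiding the softmax behind the matmuls matters most when exponentials are expensive relative to Tensor Core work.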

Innovation 3: FP8 Low-Precision Computation

Hopper supports FP8 tensor cores with different formats:

import torch
import torch.nn.functional as F

# FP8 format options in Hopper
class FP8Formats:
    """
    Hopper supports multiple FP8 formats
    """
    E4M3 = "e4m3"  # 4 exponent bits, 3 mantissa bits - more precision
    E5M2 = "e5m2"  # 5 exponent bits, 2 mantissa bits - wider range

def matmul_fp8(a_fp8, b_fp8):
    """Emulated FP8 matmul: FP8 operands, FP32 accumulation (stand-in for the Tensor Core op)."""
    return a_fp8.float() @ b_fp8.float()

def flash_attention_fp8(Q, K, V):
    """
    Numerics sketch of attention with FP8 operands (not the fused kernel).
    Uses E4M3 for the matmul operands; accumulation and softmax stay in FP32.
    """
    scale = Q.shape[-1] ** -0.5

    # Convert to FP8
    Q_fp8 = Q.to(dtype=torch.float8_e4m3fn)
    K_fp8 = K.to(dtype=torch.float8_e4m3fn)
    V_fp8 = V.to(dtype=torch.float8_e4m3fn)

    # FP8 matmul for QK^T
    # FP8 Tensor Cores have about 2x the peak throughput of FP16/BF16 on H100
    QK = matmul_fp8(Q_fp8, K_fp8.transpose(-2, -1))

    # Softmax in higher precision
    S = F.softmax(QK * scale, dim=-1)

    # Convert back to FP8 for the final matmul
    S_fp8 = S.to(torch.float8_e4m3fn)

    # FP8 matmul for attention probabilities times V
    O = matmul_fp8(S_fp8, V_fp8)

    # Final output in BF16
    return O.to(torch.bfloat16)
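
The precision cost of E4M3 can be explored without a GPU. The sketch below emulates E4M3 rounding in plain Python, assuming the common E4M3 variant with a maximum finite value of 448 and a minimum normal exponent of -6. The helper names and the per-tensor scaling scheme are illustrative assumptions and do not describe how FlashAttention-3 quantizes internally.

```python
import math

E4M3_MAX = 448.0  # largest finite value in the assumed E4M3 format

def quantize_e4m3(x):
    """Round x to the nearest E4M3-representable value, saturating at +/-448."""
    if x == 0.0:
        return 0.0
    sign = -1.0 if x < 0 else 1.0
    a = min(abs(x), E4M3_MAX)
    # a = frac * 2**exp with 0.5 <= frac < 1, so the binade exponent is exp - 1
    _, exp = math.frexp(a)
    e = max(exp - 1, -6)   # smaller exponents fall into the subnormal range
    step = 2.0 ** (e - 3)  # 3 mantissa bits give 8 steps per binade
    return sign * min(round(a / step) * step, E4M3_MAX)

def quantize_scaled(values):
    """Per-tensor scaling: map the largest magnitude onto E4M3_MAX before rounding."""
    amax = max(abs(v) for v in values)
    scale = E4M3_MAX / amax
    return [quantize_e4m3(v * scale) / scale for v in values]

values = [1000.0, 3.0, -0.3]
print([quantize_e4m3(v) for v in values])  # [448.0, 3.0, -0.3125]: 1000.0 saturates
print(quantize_scaled(values))             # the large value survives after scaling
```

Scaling trades saturation for rounding error on the smaller entries, which is why outliers in Q, K, or V are the main accuracy concern for FP8 attention.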

Performance Analysis

FLOP Utilization

FlashAttention-3 achieves substantially higher GPU utilization than earlier implementations:

| Version | H100 FLOP Utilization | Relative Speed |
| --- | --- | --- |
| PyTorch Standard | 12% | 1x |
| FlashAttention | 30% | 2.5x |
| FlashAttention-2 | 50% | 4x |
| FlashAttention-3 | 75% | 6x |
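
Utilization figures like these come from dividing achieved FLOP/s by the hardware peak. The sketch below uses the usual forward-pass count of 4 · batch · heads · N² · d for the two matmuls; the peak of 989 TFLOP/s (dense FP16/BF16 on H100 SXM), the workload shape, and the function names are assumptions for illustration.

```python
def attention_forward_flops(batch, heads, seqlen, head_dim, causal=False):
    """FLOPs of the two matmuls (QK^T and PV) in one attention forward pass."""
    flops = 4 * batch * heads * seqlen * seqlen * head_dim
    return flops // 2 if causal else flops  # causal masking skips roughly half the work

def flop_utilization(flops, seconds, peak_tflops):
    """Fraction of peak throughput actually achieved."""
    return flops / seconds / (peak_tflops * 1e12)

H100_PEAK_TFLOPS = 989.0  # assumed dense FP16/BF16 peak for H100 SXM

flops = attention_forward_flops(batch=4, heads=32, seqlen=8192, head_dim=128)
# Runtime the kernel would need in order to reach 75% utilization
t_75 = flops / (0.75 * H100_PEAK_TFLOPS * 1e12)
print(f"{flops / 1e12:.2f} TFLOP, {t_75 * 1e3:.2f} ms at 75% utilization")
```

Running the same calculation with a measured kernel time in place of `t_75` gives the utilization of a real run.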

Memory Bandwidth

# Memory access comparison (N = sequence length, d = head dimension, M = on-chip SRAM size)
memory_access = {
    'Standard Attention': 'Θ(N × d + N²) HBM reads/writes',
    'FlashAttention': 'Θ(N² × d² / M) HBM reads/writes',
    'FlashAttention-3': 'Θ(N² × d² / M) HBM reads/writes + async prefetch'
}
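
The memory saving is easiest to see in bytes. The sketch below compares the size of one materialized N × N score matrix (FP16, per batch element and head) with the per-row softmax statistics a tiled kernel keeps instead; the function names and the choice of FP32 for the statistics are assumptions.

```python
def attention_matrix_bytes(seqlen, bytes_per_elem=2):
    """Memory for one materialized N x N score matrix (per batch element and head)."""
    return seqlen * seqlen * bytes_per_elem

def softmax_stats_bytes(seqlen, bytes_per_elem=4):
    """Memory for one per-row statistic, such as the log-sum-exp (fp32 assumed)."""
    return seqlen * bytes_per_elem

for n in (4096, 65536):
    full = attention_matrix_bytes(n)
    stats = softmax_stats_bytes(n)
    print(f"N={n}: {full / 2**20:.0f} MiB vs {stats / 2**10:.0f} KiB")
# N=4096: 32 MiB vs 16 KiB
# N=65536: 8192 MiB vs 256 KiB
```

At 64K tokens a single head would need 8 GiB for the score matrix alone, which is why long contexts are impractical without tiling.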

Latency Improvements

| Sequence Length | FA2 Latency | FA3 Latency | Improvement |
| --- | --- | --- | --- |
| 1K tokens | 2.1 ms | 1.4 ms | 1.5x |
| 4K tokens | 8.5 ms | 5.2 ms | 1.6x |
| 16K tokens | 45 ms | 26 ms | 1.7x |
| 64K tokens | 280 ms | 150 ms | 1.9x |
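
The improvement column is the ratio of the two latencies. As a quick arithmetic check, the snippet below recomputes it from the figures in the table; the dictionary layout is just one convenient way to hold the numbers.

```python
latencies_ms = {  # sequence length -> (FA2 latency, FA3 latency), copied from the table
    "1K": (2.1, 1.4),
    "4K": (8.5, 5.2),
    "16K": (45.0, 26.0),
    "64K": (280.0, 150.0),
}
speedups = {n: round(fa2 / fa3, 1) for n, (fa2, fa3) in latencies_ms.items()}
print(speedups)  # {'1K': 1.5, '4K': 1.6, '16K': 1.7, '64K': 1.9}
```

The speedup grows with sequence length, consistent with longer sequences giving the asynchronous pipeline more work to overlap.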

Implementation

Using FlashAttention-3

# Install with pip
# pip install flash-attn --prefer-binary
# Note: at the time of writing, the Hopper-specific FlashAttention-3 kernels are
# built separately from the hopper/ directory of the flash-attention repository;
# check the project README for the current install and import instructions.

import torch
from flash_attn import flash_attn_func

def example_usage():
    # Input tensors: [batch, seqlen, nheads, headdim]
    q = torch.randn(2, 512, 16, 64, dtype=torch.float16, device='cuda')
    k = torch.randn(2, 512, 16, 64, dtype=torch.float16, device='cuda')
    v = torch.randn(2, 512, 16, 64, dtype=torch.float16, device='cuda')

    # The kernel that runs depends on the installed version and the GPU
    output = flash_attn_func(
        q, k, v,
        softmax_scale=None,  # computed automatically as 1/sqrt(headdim)
        causal=True,         # lower-triangular mask
    )

    return output

# Check if FP8 is available
def check_fp8_support():
    """FP8 Tensor Cores require compute capability 8.9 (Ada) or 9.0 (Hopper) and above."""
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability()
    print(torch.cuda.get_device_name(), f"(compute capability {major}.{minor})")
    return (major, minor) >= (8, 9)
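
For readers unsure what the `causal=True` flag changes, here is a minimal single-head reference in plain Python. It is a sketch for intuition only: `causal_attention` is a hypothetical helper operating on small lists, and it uses the default scale of 1/sqrt(headdim).

```python
import math

def causal_attention(q, k, v):
    """Single-head attention on lists of vectors; position i sees keys 0..i only."""
    d = len(q[0])
    out = []
    for i, qi in enumerate(q):
        # Scores against the visible prefix of keys (the lower-triangular part)
        scores = [sum(a * b for a, b in zip(qi, k[j])) / math.sqrt(d)
                  for j in range(i + 1)]
        m = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * v[j][c] for j, w in enumerate(weights)) / total
                    for c in range(len(v[0]))])
    return out

q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
out = causal_attention(q, k, v)
print(out[0])  # [1.0, 2.0]: the first token can only attend to itself
```

A fused kernel produces the same result without ever building the masked score matrix, and it can skip tiles that lie entirely above the diagonal.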

Custom Kernel with FlashAttention-3 Principles

# Simplified outline of a custom kernel demonstrating FA3 principles
# (pseudocode in Python syntax: real kernels are written in CUDA C++, and
#  `block_idx` stands in for CUDA's blockIdx)
BLOCK_M, BLOCK_N, D = 128, 64, 128  # illustrative tile sizes and head dimension

def flash_attention_fused_kernel(
    block_idx,
    Q, K, V, O,
    softmax_scale,
    cu_seqlens_q,
    cu_seqlens_kv
):
    """
    Simplified FA3-style fused kernel
    """
    # Shared memory sizes (in elements)
    Q_SZ = BLOCK_M * D
    KV_SZ = BLOCK_N * D * 2  # one K tile plus one V tile

    # Get block indices: one thread block per (batch, head, query chunk)
    batch_id, head_id, chunk_id = block_idx

    # Warp-specialized operations
    # ... (complex synchronization logic)

    # Interleaved computation
    # 1. Load K,V asynchronously via TMA
    # 2. Compute partial QK while loading
    # 3. Partial softmax
    # 4. Continue with more K,V
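
Tile sizes in such a kernel are bounded by shared memory. The sketch below estimates the footprint of one Q tile plus multi-stage K and V buffers and compares it with an assumed per-block limit of 227 KiB on H100. The function name, the two-stage buffering, and the candidate tile sizes are illustrative assumptions.

```python
def smem_bytes(block_m, block_n, head_dim, num_stages=2, bytes_per_elem=2):
    """Shared memory for one Q tile plus num_stages buffered K and V tiles."""
    q_bytes = block_m * head_dim * bytes_per_elem
    kv_bytes = 2 * block_n * head_dim * bytes_per_elem  # one K tile + one V tile
    return q_bytes + num_stages * kv_bytes

SMEM_LIMIT = 227 * 1024  # assumed per-block shared memory limit on H100

for block_m, block_n in [(128, 64), (128, 128), (128, 256)]:
    used = smem_bytes(block_m, block_n, head_dim=128)
    verdict = "fits" if used <= SMEM_LIMIT else "too large"
    print(block_m, block_n, used // 1024, "KiB", verdict)
# 128 64 96 KiB fits
# 128 128 160 KiB fits
# 128 256 288 KiB too large
```

Larger tiles improve data reuse but leave less room for extra pipeline stages, so the best choice depends on head dimension and data type.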

Advanced Optimizations

Flash-Decoding Integration

import torch

def flash_decoding_with_fa3(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_splits: int = 4
):
    """
    Combine FlashAttention-3 with Flash-Decoding for long contexts.
    Flash-Decoding splits the KV cache along the sequence dimension, attends
    to each split in parallel, then merges the partial results.
    query: [batch, heads, 1, d]; key_cache, value_cache: [batch, heads, N, d]
    Plain PyTorch stands in here for the per-split attention kernel.
    """
    scale = query.shape[-1] ** -0.5

    outputs, lses = [], []
    k_chunks = key_cache.chunk(num_splits, dim=2)
    v_chunks = value_cache.chunk(num_splits, dim=2)
    for k_chunk, v_chunk in zip(k_chunks, v_chunks):
        # Attention within one split, keeping its log-sum-exp for the merge
        scores = (query @ k_chunk.transpose(-2, -1)) * scale
        lse = torch.logsumexp(scores, dim=-1, keepdim=True)
        outputs.append(torch.exp(scores - lse) @ v_chunk)
        lses.append(lse)

    # Merge splits: weight each partial output by its share of the total softmax mass
    weights = torch.softmax(torch.cat(lses, dim=-1), dim=-1)  # [batch, heads, 1, splits]
    return sum(w.unsqueeze(-1) * o for w, o in zip(weights.unbind(-1), outputs))

FP8 Training

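def fa3_training_step(model, batch, loss_fn):
    """
    Training step sketch with FlashAttention under mixed precision.
    torch.autocast accepts float16 and bfloat16 only, so BF16 is used here;
    FP8 has to be handled inside the attention kernel or by a dedicated FP8 library.
    `model.embed`, `layer.compute_qkv`, and `loss_fn` are placeholders.
    """
    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        # Forward pass with FlashAttention
        hidden_states = model.embed(batch['input_ids'])
        for layer in model.layers:
            q, k, v = layer.compute_qkv(hidden_states)

            # Low-precision details are handled inside the kernel
            attn_output = flash_attn_func(q, k, v, causal=True)

            # Continue forward
            hidden_states = layer(attn_output)
        loss = loss_fn(hidden_states, batch['labels'])

    # Backward pass also uses FlashAttention (with recomputation)
    loss.backward()
    return loss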

Comparison with Alternatives

| Feature | FlashAttention-3 | Standard Attention | xFormers |
| --- | --- | --- | --- |
| Memory | O(N) | O(N²) | O(N) |
| Precision | Exact | Exact | Exact |
| Hardware | H100 optimized | Universal | Universal |
| FLOP Utilization | 75% | 12% | 35% |
| FP8 Support | Native | No | No |
| Sequence Length | 64K+ | Limited by memory | Limited |

Best Practices

When to Use FlashAttention-3

  1. Training: Always recommended for transformer training
  2. Long Sequences: Critical for sequences > 4K tokens
  3. H100 Hardware: Maximizes hardware utilization
  4. Memory Constrained: Reduces memory by 20x for long sequences

Configuration Tips

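# Example settings for different scenarios
# (softmax_scale, dropout_p, alibi_slopes, and return_attn_probs are
#  flash_attn_func arguments; block_size is an internal kernel tile size
#  and is listed for illustration only)
configs = {
    'short_context': {
        'block_size': 128,
        'softmax_scale': None,  # auto
    },
    'long_context': {
        'block_size': 64,       # smaller for more granularity
        'return_attn_probs': False,
    },
    'memory_efficient': {
        'dropout_p': 0.0,       # No dropout for memory
        'alibi_slopes': None,   # No ALiBi if not needed
    }
}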

Conclusion

FlashAttention-3 represents the culmination of years of optimization work, specifically tailored for modern GPU architectures. It combines:

  • Warp specialization for overlapped computation
  • Interleaved operations for reduced latency
  • FP8 precision for higher Tensor Core throughput

Together, these bring attention to 75% FLOP utilization on H100 GPUs. As GPU architectures continue to evolve, these principles will shape the future of transformer optimization.

The technique is now widely adopted in training frameworks (PyTorch, Megatron) and inference engines (vLLM, TensorRT-LLM), making modern large language models practical to train and deploy.
