
Sparse Attention Algorithms: Efficient Transformers at Scale

Introduction

The Transformer architecture has revolutionized natural language processing and beyond, but its self-attention mechanism faces a fundamental challenge: quadratic computational and memory complexity with respect to sequence length. For a sequence of length n, standard attention requires O(n²) operations, which is prohibitive for long sequences like documents, genomes, or audio streams.

Sparse attention emerges as a powerful solution, reducing this complexity to O(n√n) or even O(n) while maintaining competitive model quality. By selectively attending to only a subset of positions rather than all pairs, sparse attention makes Transformers tractable for real-world applications with long contexts.

In 2026, sparse attention has become essential for deploying large language models efficiently. This comprehensive guide explores the mathematics, implementations, and practical applications of sparse attention algorithms.

The Attention Computation Problem

Standard Attention Complexity

Standard self-attention computes:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Where Q (queries), K (keys), and V (values) are matrices of shape [n, d]. The complexity breakdown:

  • QK^T multiplication: O(n² × d)
  • Softmax: O(n²)
  • Matrix multiplication with V: O(n² × d)

Total: O(n²d) time and O(n²) memory
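To make the quadratic blow-up concrete, here is a back-of-the-envelope cost model (the formulas are standard matmul FLOP counts; the specific sizes are illustrative):

```python
def attention_cost(n, d):
    """Rough FLOP and memory estimates for full self-attention.

    FLOPs: QK^T and attn @ V each cost ~2*n*n*d multiply-adds.
    Memory: the [n, n] score matrix in fp32 (4 bytes per entry), per head.
    """
    flops = 2 * 2 * n * n * d   # two n-by-n-by-d matmuls
    score_bytes = 4 * n * n     # fp32 attention matrix
    return flops, score_bytes

# Doubling the sequence length quadruples both cost terms
f1, m1 = attention_cost(4096, 64)
f2, m2 = attention_cost(8192, 64)
print(f2 // f1, m2 // m1)  # -> 4 4
```

At n = 8192 the score matrix alone is 256 MB per head in fp32, before any activations are stored for the backward pass.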

import torch
import torch.nn as nn
import torch.nn.functional as F

class StandardAttention(nn.Module):
    """
    Standard full attention - O(n²) complexity.
    """
    
    def __init__(self, d_model, num_heads=8):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
    def forward(self, x, mask=None):
        """
        Forward pass with O(n²) attention.
        
        Args:
            x: [batch, seq_len, d_model]
            mask: Optional attention mask
        
        Returns:
            output: [batch, seq_len, d_model]
        """
        batch_size, seq_len, _ = x.shape
        
        # Project to Q, K, V
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        
        # Attention scores: [batch, heads, seq_len, seq_len]
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
            
        attn_weights = F.softmax(scores, dim=-1)
        
        # Apply attention to values
        output = torch.matmul(attn_weights, V)
        
        # Reshape and project
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        
        return self.W_o(output)
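As a sanity check, the manual formulation above should match PyTorch's built-in fused kernel, F.scaled_dot_product_attention (available in PyTorch 2.x). A minimal sketch with illustrative shapes:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, N, D = 2, 4, 16, 8
Q, K, V = (torch.randn(B, H, N, D) for _ in range(3))

# Manual O(n^2) attention, as in the forward pass above
scores = Q @ K.transpose(-2, -1) / (D ** 0.5)
manual = F.softmax(scores, dim=-1) @ V

# Fused kernel (PyTorch >= 2.0); picks an efficient backend automatically
fused = F.scaled_dot_product_attention(Q, K, V)

assert torch.allclose(manual, fused, atol=1e-4)
```

In production code the fused kernel is preferable; the explicit version here exists to make the math visible.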

Sparse Attention Patterns

1. Fixed Sparse Patterns

Sliding Window Attention

Attend only to neighboring tokens within a fixed window:

class SlidingWindowAttention(nn.Module):
    """
    Sliding window attention - attend to local context only.
    
    Complexity: O(n × w) where w is window size.
    """
    
    def __init__(self, d_model, num_heads=8, window_size=512):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.window_size = window_size
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
    def forward(self, x, attn_mask=None):
        batch_size, seq_len, _ = x.shape
        
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        
        # Build the sliding-window (band) mask in one vectorized expression:
        # position i attends to positions j with |i - j| <= window_size // 2
        half_window = self.window_size // 2
        positions = torch.arange(seq_len, device=x.device)
        window_mask = (positions[None, :] - positions[:, None]).abs() <= half_window
        window_mask = window_mask.view(1, 1, seq_len, seq_len)
        
        if attn_mask is not None:
            window_mask = window_mask & attn_mask
        
        # Attend only within the window. (Note: for clarity this still
        # materializes dense n x n scores; a true O(n x w) implementation
        # computes them blockwise.)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
        scores = scores.masked_fill(~window_mask, float('-inf'))
        
        attn_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, V)
        
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        
        return self.W_o(output)
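The band mask is worth understanding on its own, independent of the module. A standalone sketch that builds it and checks its basic properties (sizes are illustrative):

```python
import torch

def band_mask(seq_len, window_size):
    """Boolean [seq_len, seq_len] mask: True where |i - j| <= window_size // 2."""
    half = window_size // 2
    pos = torch.arange(seq_len)
    return (pos[None, :] - pos[:, None]).abs() <= half

m = band_mask(8, 4)
# Each row attends to at most window_size + 1 positions (itself plus
# half a window on each side), regardless of sequence length
assert m.sum(dim=-1).max().item() == 5
assert bool(m.diagonal().all())  # every token attends to itself
```

Because the per-row count is bounded by the window, total work grows linearly in n rather than quadratically.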

Global + Local Attention

Combine global tokens (attended by all) with local windows:

class GlobalLocalAttention(nn.Module):
    """
    Global + Local attention pattern.
    
    Some tokens (global) attend to all positions.
    Other tokens (local) attend only within window.
    """
    
    def __init__(self, d_model, num_heads=8, window_size=128, 
                 num_global_tokens=2):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.window_size = window_size
        self.num_global = num_global_tokens
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
    def create_global_local_mask(self, batch_size, seq_len, device):
        """
        Create mask for global + local attention.
        """
        half_window = self.window_size // 2
        
        mask = torch.zeros(
            batch_size, self.num_heads, seq_len, seq_len,
            device=device, dtype=torch.bool
        )
        
        for i in range(seq_len):
            if i < self.num_global:
                # Global token attends to everything
                mask[:, :, i, :] = True
            else:
                # Local token attends to global tokens + its local window
                start = max(0, i - half_window)
                end = min(seq_len, i + half_window + 1)
                mask[:, :, i, :self.num_global] = True  # global tokens
                mask[:, :, i, start:end] = True         # local window
                
        return mask
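A quick standalone check of the pattern this mask encodes (a vectorized sketch with hypothetical sizes): global rows should be fully dense, while every local row covers only the global block plus at most window_size + 1 neighbors.

```python
import torch

def global_local_mask(seq_len, window_size, num_global):
    """[seq_len, seq_len] boolean mask for the global + local pattern."""
    half = window_size // 2
    pos = torch.arange(seq_len)
    local = (pos[None, :] - pos[:, None]).abs() <= half  # banded part
    mask = local | (pos[None, :] < num_global)           # everyone sees globals
    mask[:num_global, :] = True                          # globals see everyone
    return mask

m = global_local_mask(seq_len=16, window_size=4, num_global=2)
assert bool(m[0].all()) and bool(m[1].all())  # global rows are dense
assert m[10].sum().item() == 2 + 5            # globals + local band
```

This is the core pattern behind models like Longformer: a handful of dense rows and columns restore long-range connectivity while the bulk of the mask stays banded.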

2. Data-Dependent Sparse Patterns

Random Attention

Randomly sample positions to attend to:

class RandomAttention(nn.Module):
    """
    Random attention - sample random positions.
    
    Complexity: O(n × r) where r is the number of random samples.
    """
    
    def __init__(self, d_model, num_heads=8, num_random=32):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.num_random = num_random
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
    def forward(self, x, mask=None):
        # `mask` is accepted for API uniformity with the other attention
        # modules but is ignored: sparsity here comes from random sampling.
        batch_size, seq_len, _ = x.shape
        
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        
        # For each query position, sample random key positions: [B, H, N, R]
        random_indices = torch.randint(
            0, seq_len,
            (batch_size, self.num_heads, seq_len, self.num_random),
            device=x.device
        )
        
        # Gather the sampled keys and values: [B, H, N, R, d_k].
        # The expand below is a view, so no [N, N] copy is materialized.
        idx = random_indices.unsqueeze(-1).expand(-1, -1, -1, -1, self.d_k)
        K_random = torch.gather(K.unsqueeze(2).expand(-1, -1, seq_len, -1, -1), 3, idx)
        V_random = torch.gather(V.unsqueeze(2).expand(-1, -1, seq_len, -1, -1), 3, idx)
        
        # Attention over the sampled positions only: [B, H, N, 1, R]
        scores = torch.matmul(Q.unsqueeze(-2), K_random.transpose(-2, -1)) / (self.d_k ** 0.5)
        attn_weights = F.softmax(scores, dim=-1)
        
        # Weighted sum of the sampled values: [B, H, N, d_k]
        output = torch.matmul(attn_weights, V_random).squeeze(-2)
        
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        
        return self.W_o(output)

3. Linear Attention: O(n) Complexity

Replace softmax with kernel approximation:

class LinearAttention(nn.Module):
    """
    Linear attention using feature map approximation.
    
    Complexity: O(n × d²), or O(n × d) with further optimizations.
    
    Key insight: softmax(QK^T)V ≈ φ(Q)(φ(K)^T V), computed without
    ever materializing the n × n attention matrix.
    """
    
    def __init__(self, d_model, num_heads=8):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
    def forward(self, x, mask=None):
        # `mask` is accepted for API uniformity and ignored here.
        batch_size, seq_len, _ = x.shape
        
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        
        # Strictly positive feature map (elu(x) + 1), so the
        # normalizer below can never vanish
        def phi(t):
            return F.elu(t) + 1.0
            
        Q_prime = phi(Q)  # [B, H, N, d_k]
        K_prime = phi(K)
        
        # Associativity trick: contract K' with V first into a d_k x d_k
        # summary, instead of forming the n x n matrix Q'K'^T
        KV = torch.einsum('bhnd,bhne->bhde', K_prime, V)
        numerator = torch.einsum('bhnd,bhde->bhne', Q_prime, KV)
        
        # Normalizer: Q'(K'^T 1), shape [B, H, N]
        K_sum = K_prime.sum(dim=2)
        denominator = torch.einsum('bhnd,bhd->bhn', Q_prime, K_sum)
        
        output = numerator / denominator.unsqueeze(-1).clamp(min=1e-6)
        
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        
        return self.W_o(output)
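The O(n) reordering is exact, not an approximation on top of the feature map: (φ(Q)φ(K)^T)V equals φ(Q)(φ(K)^T V) by associativity of matrix multiplication. A small numerical check (single head, illustrative sizes):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, D = 32, 8
Q, K, V = torch.randn(N, D), torch.randn(N, D), torch.randn(N, D)

phi = lambda t: F.elu(t) + 1.0  # strictly positive feature map
Qp, Kp = phi(Q), phi(K)

# Row-normalizer is the same in both orders: Qp @ (Kp^T 1)
norm = Qp @ Kp.sum(dim=0, keepdim=True).T  # [N, 1]

# Quadratic order: build the n x n matrix explicitly
quadratic = (Qp @ Kp.T) @ V / norm

# Linear order: contract K with V first (d x d), never form n x n
linear = Qp @ (Kp.T @ V) / norm

assert torch.allclose(quadratic, linear, atol=1e-4)
```

The quality trade-off comes entirely from replacing the softmax kernel with φ, not from the reordering itself.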

Complete Sparse Transformer

class SparseTransformerLayer(nn.Module):
    """
    Complete sparse transformer layer.
    """
    
    def __init__(self, d_model, num_heads, d_ff, sparse_type='window', 
                 window_size=512, dropout=0.1):
        super().__init__()
        
        # Sparse attention
        if sparse_type == 'window':
            self.attention = SlidingWindowAttention(d_model, num_heads, window_size)
        elif sparse_type == 'random':
            self.attention = RandomAttention(d_model, num_heads, num_random=32)
        elif sparse_type == 'linear':
            self.attention = LinearAttention(d_model, num_heads)
        else:
            self.attention = StandardAttention(d_model, num_heads)
            
        # Feed-forward
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout)
        )
        
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        
    def forward(self, x, mask=None):
        # Attention with residual
        attn_out = self.attention(self.norm1(x), mask)
        x = x + attn_out
        
        # FFN with residual
        ffn_out = self.ffn(self.norm2(x))
        x = x + ffn_out
        
        return x

Pattern Comparison

Pattern          Complexity   Best For               Quality
Full             O(n²)        Short sequences        Best
Sliding Window   O(n×w)       Long contexts          Very Good
Random           O(n×r)       Fast inference         Good
Linear           O(n×d)       Very long sequences    Good
Routing          O(n×c)       Adaptive               Very Good
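Plugging concrete numbers into the table makes the gap vivid. A sketch counting attended key positions per query at n = 16,384 (w, r, d are illustrative choices):

```python
n, w, r, d = 16_384, 512, 32, 64

keys_per_query = {
    "full": n,        # O(n^2) score entries in total
    "window": w + 1,  # O(n x w): half a window each side, plus self
    "random": r,      # O(n x r)
    "linear": d,      # O(n x d): a d x d summary stands in for real keys
}

for name, k in keys_per_query.items():
    print(f"{name:>7}: {k:>6} keys/query, {n * k:,} score entries total")
```

At these settings the sliding window does roughly 30x less score work than full attention, and random sampling roughly 500x less, which is why pattern choice dominates the quality/cost trade-off.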

Practical Applications

Long Document Processing

class LongDocumentModel(nn.Module):
    """
    Process long documents using sparse attention.
    """
    
    def __init__(self, config):
        super().__init__()
        
        # Assumes config provides vocab_size, d_model, and num_heads
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.encoder = SlidingWindowAttention(
            config.d_model, config.num_heads, window_size=512
        )
        
    def process_long_document(self, input_ids):
        """
        Process a document in overlapping chunks, then pool.
        """
        chunk_size = 8192
        overlap = 512
        
        chunks = []
        for i in range(0, len(input_ids), chunk_size - overlap):
            chunks.append(input_ids[i:i + chunk_size])
            
        # Encode each chunk and keep its last hidden state as a summary
        chunk_embeddings = []
        for chunk in chunks:
            hidden = self.encoder(self.embedding(chunk).unsqueeze(0))
            chunk_embeddings.append(hidden[:, -1, :])
            
        # Mean-pool the chunk summaries into one document vector
        combined = torch.stack(chunk_embeddings).mean(dim=0)
        
        return combined
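The chunking index arithmetic is easy to get subtly wrong, so here it is isolated as a standalone helper with assertions on the intended overlap behavior (sizes are illustrative):

```python
def chunk_spans(length, chunk_size=8192, overlap=512):
    """Yield (start, end) spans where neighbors share `overlap` tokens."""
    stride = chunk_size - overlap
    spans = []
    for start in range(0, length, stride):
        spans.append((start, min(start + chunk_size, length)))
        if spans[-1][1] == length:  # stop once the document is covered
            break
    return spans

spans = chunk_spans(20_000)
assert spans[0] == (0, 8192)
assert spans[1][0] == 8192 - 512  # neighbors share `overlap` tokens
assert spans[-1][1] == 20_000     # full coverage, no trailing gap
```

The overlap gives tokens near chunk boundaries at least some bidirectional context; without it, boundary tokens see only one side of their neighborhood.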

Best Practices

  1. Warmup: Train with standard attention first, then switch to a sparse pattern
  2. Pattern Selection: Data-dependent patterns must be learned, so budget training time for them
  3. Hybrid: Use sparse attention for most positions and full attention for a few critical ones
  4. Hardware: Match the sparse pattern to your hardware's strengths (e.g. block-sparse kernels)

Future Directions in 2026

  • Hardware-Optimized Sparse: Custom kernels for sparse patterns
  • Dynamic Sparsity: Adapt sparsity pattern based on content
  • Hierarchical Attention: Multi-level sparse at different granularities
  • Sparse + MoE: Combine with mixture of experts

Conclusion

Sparse attention represents a crucial breakthrough in making Transformers practical for real-world applications with long sequences. By carefully designing which positions to attend to, we can dramatically reduce computational and memory requirements while maintaining model quality.

The choice of sparse pattern depends on your specific use case: sliding window for autoregressive generation, random for efficiency, linear for very long contexts, and learned patterns for maximum flexibility. As hardware continues to evolve, sparse attention will remain essential for scaling Transformers efficiently.
