Introduction
The Transformer architecture has revolutionized natural language processing and beyond, but its self-attention mechanism faces a fundamental challenge: quadratic computational and memory complexity with respect to sequence length. For a sequence of length n, standard attention requires O(n²) operations, which is prohibitive for long sequences like documents, genomes, or audio streams.
Sparse attention emerges as a powerful solution, reducing this complexity to O(n√n) or even O(n) while maintaining competitive model quality. By selectively attending to only a subset of positions rather than all pairs, sparse attention makes Transformers tractable for real-world applications with long contexts.
In 2026, sparse attention has become essential for deploying large language models efficiently. This comprehensive guide explores the mathematics, implementations, and practical applications of sparse attention algorithms.
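To make the quadratic memory cost concrete, here is a back-of-envelope sketch (assuming fp16 storage and one dense attention matrix per head; the function name is ours, not from any library):

```python
def attn_matrix_bytes(n: int, bytes_per_elem: int = 2) -> int:
    """Memory for one dense n x n attention matrix (fp16 = 2 bytes/entry)."""
    return n * n * bytes_per_elem

for n in (1_024, 32_768, 131_072):
    print(f"n={n:>7}: {attn_matrix_bytes(n) / 2**30:8.2f} GiB per head")
```

At a 128K-token context, a single head's attention matrix alone would need 32 GiB, which is why dense attention is never materialized at that scale.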
The Attention Computation Problem
Standard Attention Complexity
Standard self-attention computes:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

where Q (queries), K (keys), and V (values) are matrices of shape [n, d]. The complexity breakdown:
- QK^T multiplication: O(n² × d)
- Softmax: O(n²)
- Matrix multiplication with V: O(n² × d)

Total: O(n²d) time and O(n²) memory.
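The two large matrix multiplications dominate. A rough FLOP count (counting 2 FLOPs per multiply-accumulate and ignoring the lower-order softmax; the helper below is illustrative, not from any library):

```python
def attn_flops(n: int, d: int) -> int:
    """Approximate FLOPs for QK^T plus attn @ V: two n x n x d matmuls."""
    return 2 * (2 * n * n * d)

print(f"{attn_flops(4096, 64) / 1e9:.2f} GFLOPs per head at n=4096, d=64")
```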
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class StandardAttention(nn.Module):
    """Standard full attention - O(n²) complexity."""

    def __init__(self, d_model, num_heads=8):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        """
        Forward pass with O(n²) attention.

        Args:
            x: [batch, seq_len, d_model]
            mask: optional attention mask
        Returns:
            output: [batch, seq_len, d_model]
        """
        batch_size, seq_len, _ = x.shape
        # Project to Q, K, V and split heads: [batch, heads, seq_len, d_k]
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        # Attention scores: [batch, heads, seq_len, seq_len]
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn_weights = F.softmax(scores, dim=-1)
        # Apply attention to values
        output = torch.matmul(attn_weights, V)
        # Merge heads and project
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.W_o(output)
```
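As a sanity check, this manual computation should match PyTorch's fused kernel, `F.scaled_dot_product_attention` (available since torch 2.0). A quick numerical comparison on random tensors:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
Q, K, V = (torch.randn(2, 8, 16, 32) for _ in range(3))  # [batch, heads, seq, d_k]

# Manual attention, mirroring the forward pass above
scores = Q @ K.transpose(-2, -1) / (32 ** 0.5)
manual = F.softmax(scores, dim=-1) @ V

# PyTorch's fused implementation (default scale is 1/sqrt(d_k), matching above)
fused = F.scaled_dot_product_attention(Q, K, V)

assert torch.allclose(manual, fused, atol=1e-5)
```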
Sparse Attention Patterns
1. Fixed Sparse Patterns
Sliding Window Attention
Attend only to neighboring tokens within a fixed window:
```python
class SlidingWindowAttention(nn.Module):
    """
    Sliding window attention - attend to local context only.
    Complexity: O(n × w) where w is the window size.
    """

    def __init__(self, d_model, num_heads=8, window_size=512):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.window_size = window_size
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        # Banded sliding-window mask: position i attends to j iff |i - j| <= w/2.
        # Built in one vectorized comparison; broadcasts over batch and heads.
        half_window = self.window_size // 2
        positions = torch.arange(seq_len, device=x.device)
        window_mask = (positions[None, :] - positions[:, None]).abs() <= half_window
        if mask is not None:
            window_mask = window_mask & mask  # combine with any external mask
        # Note: this reference version still materializes dense O(n²) scores;
        # production kernels exploit the band structure to avoid that.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
        scores = scores.masked_fill(~window_mask, float('-inf'))
        attn_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, V)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.W_o(output)
```
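The banded mask is worth seeing on its own: one vectorized comparison of position indices replaces any per-row Python loop. A standalone sketch with toy sizes:

```python
import torch

seq_len, window_size = 8, 3              # half-window of 1 token on each side
half_window = window_size // 2
positions = torch.arange(seq_len)

# Banded mask: position i attends to j iff |i - j| <= half_window
mask = (positions[None, :] - positions[:, None]).abs() <= half_window

print(f"attended pairs: {mask.sum().item()} of {seq_len * seq_len}")
```

For seq_len=8 with a half-window of 1, only 22 of 64 pairs are attended; the fraction shrinks linearly in window size as sequences grow.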
Global + Local Attention
Combine global tokens (attended by all) with local windows:
```python
class GlobalLocalAttention(nn.Module):
    """
    Global + Local attention pattern.
    Some tokens (global) attend to all positions.
    Other tokens (local) attend only within a window, plus the global tokens.
    """

    def __init__(self, d_model, num_heads=8, window_size=128,
                 num_global_tokens=2):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.window_size = window_size
        self.num_global = num_global_tokens
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def create_global_local_mask(self, seq_len, device):
        """Build a [seq_len, seq_len] boolean mask for global + local attention."""
        half_window = self.window_size // 2
        positions = torch.arange(seq_len, device=device)
        # Local band: |i - j| <= half_window
        mask = (positions[None, :] - positions[:, None]).abs() <= half_window
        # Global tokens (the first num_global positions) attend to everything,
        # and every token attends to the global tokens
        mask[:self.num_global, :] = True
        mask[:, :self.num_global] = True
        return mask

    def forward(self, x):
        batch_size, seq_len, _ = x.shape
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        mask = self.create_global_local_mask(seq_len, x.device)  # broadcasts over batch/heads
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
        scores = scores.masked_fill(~mask, float('-inf'))
        attn_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, V)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.W_o(output)
```
2. Data-Dependent Sparse Patterns
Random Attention
Randomly sample positions to attend to:
```python
class RandomAttention(nn.Module):
    """
    Random attention - each query attends to a random subset of keys.
    Complexity: O(n × r) where r is the number of random samples.
    """

    def __init__(self, d_model, num_heads=8, num_random=32):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.num_random = num_random
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x):
        batch_size, seq_len, _ = x.shape
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        # For each query position, sample r random key indices
        random_indices = torch.randint(
            0, seq_len,
            (batch_size, self.num_heads, seq_len, self.num_random),
            device=x.device
        )
        # Gather the sampled keys/values with advanced indexing:
        # K_random[b, h, i, r] = K[b, h, random_indices[b, h, i, r]]
        b_idx = torch.arange(batch_size, device=x.device)[:, None, None, None]
        h_idx = torch.arange(self.num_heads, device=x.device)[None, :, None, None]
        K_random = K[b_idx, h_idx, random_indices]  # [B, H, N, r, d_k]
        V_random = V[b_idx, h_idx, random_indices]  # [B, H, N, r, d_k]
        # Attention over the sampled keys only: scores are [B, H, N, r]
        scores = torch.einsum('bhnd,bhnrd->bhnr', Q, K_random) / (self.d_k ** 0.5)
        attn = F.softmax(scores, dim=-1)
        output = torch.einsum('bhnr,bhnrd->bhnd', attn, V_random)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.W_o(output)
```
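The per-query gather is the delicate step: `torch.gather` along the sequence dimension needs an index tensor shaped like the source, so plain advanced indexing is simpler and clearer. A standalone check of the indexing pattern, with toy shapes:

```python
import torch

torch.manual_seed(0)
B, H, N, R, D = 2, 4, 6, 3, 8
K = torch.randn(B, H, N, D)
idx = torch.randint(0, N, (B, H, N, R))        # R random key ids per query

# K_random[b, h, i, r] = K[b, h, idx[b, h, i, r]]
b = torch.arange(B)[:, None, None, None]
h = torch.arange(H)[None, :, None, None]
K_random = K[b, h, idx]                        # [B, H, N, R, D]

assert K_random.shape == (B, H, N, R, D)
assert torch.equal(K_random[1, 2, 3, 0], K[1, 2, idx[1, 2, 3, 0]])
```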
3. Linear Attention: O(n) Complexity
Replace softmax with kernel approximation:
```python
class LinearAttention(nn.Module):
    """
    Linear attention using a feature map approximation.
    Complexity: O(n × d²), or O(n × d) with further optimizations.
    Key insight: softmax(QK^T)V ≈ φ(Q)(φ(K)^T V), reassociating the product.
    """

    def __init__(self, d_model, num_heads=8):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x):
        batch_size, seq_len, _ = x.shape
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        # Positive feature map φ (ReLU here; elu(x) + 1 is also common)
        Q_prime = F.relu(Q) + 1e-6
        K_prime = F.relu(K) + 1e-6
        # Reassociate: first K'^T V, a [d_k, d_k] summary, then Q' times it
        KV = torch.einsum('bhnd,bhne->bhde', K_prime, V)          # O(n·d²)
        numerator = torch.einsum('bhnd,bhde->bhne', Q_prime, KV)  # O(n·d²)
        # Normalizer: Q' dotted with the sum of K' over all positions
        K_sum = K_prime.sum(dim=2)                                # [B, H, d_k]
        denominator = torch.einsum('bhnd,bhd->bhn', Q_prime, K_sum).clamp(min=1e-10)
        output = numerator / denominator.unsqueeze(-1)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.W_o(output)
```
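The reassociation can be verified numerically: with a positive feature map, normalizing φ(Q)φ(K)^T row-wise and then multiplying by V is algebraically identical to computing φ(Q)(φ(K)^T V) with per-row normalizers. A single-head, unbatched sketch:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, D = 16, 8
Q = F.relu(torch.randn(N, D)) + 0.1   # positive feature maps phi(Q), phi(K)
K = F.relu(torch.randn(N, D)) + 0.1
V = torch.randn(N, D)

# Quadratic form: normalize phi(Q) phi(K)^T row-wise, then multiply by V
A = Q @ K.T                            # [N, N] unnormalized weights
quadratic = (A / A.sum(-1, keepdim=True)) @ V

# Linear form: reassociate as phi(Q) (phi(K)^T V), O(N d^2)
KV = K.T @ V                           # [D, D] summary
z = Q @ K.sum(0)                       # [N] normalizers
linear = (Q @ KV) / z[:, None]

assert torch.allclose(quadratic, linear, atol=1e-5)
```

The two paths compute the same quantity; only the order of multiplication (and hence the complexity) differs.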
Complete Sparse Transformer
```python
class SparseTransformerLayer(nn.Module):
    """Complete sparse transformer layer (pre-norm residual blocks)."""

    def __init__(self, d_model, num_heads, d_ff, sparse_type='window',
                 window_size=512, dropout=0.1):
        super().__init__()
        # Sparse attention
        if sparse_type == 'window':
            self.attention = SlidingWindowAttention(d_model, num_heads, window_size)
        elif sparse_type == 'random':
            self.attention = RandomAttention(d_model, num_heads, num_random=32)
        elif sparse_type == 'linear':
            self.attention = LinearAttention(d_model, num_heads)
        else:
            self.attention = StandardAttention(d_model, num_heads)
        # Feed-forward
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        # Attention with residual; only pass the mask to variants that accept one
        if isinstance(self.attention, (StandardAttention, SlidingWindowAttention)):
            attn_out = self.attention(self.norm1(x), mask)
        else:
            attn_out = self.attention(self.norm1(x))
        x = x + attn_out
        # FFN with residual
        x = x + self.ffn(self.norm2(x))
        return x
```
Pattern Comparison
| Pattern | Complexity | Best For | Quality |
|---|---|---|---|
| Full | O(n²) | Short sequences | Best |
| Sliding Window | O(n×w) | Long contexts | Very Good |
| Random | O(n×r) | Fast inference | Good |
| Linear | O(n×d) | Very long sequences | Good |
| Routing | O(n×c) | Adaptive | Very Good |
Practical Applications
Long Document Processing
```python
class LongDocumentModel(nn.Module):
    """Process long documents by chunking with overlap plus sparse attention."""

    def __init__(self, config):
        super().__init__()
        # Assumes config also provides vocab_size alongside d_model / num_heads
        self.embed = nn.Embedding(config.vocab_size, config.d_model)
        self.encoder = SlidingWindowAttention(
            config.d_model, config.num_heads, window_size=512
        )

    def process_long_document(self, input_ids):
        """
        Encode a long document in overlapping chunks, then pool.

        Args:
            input_ids: 1-D LongTensor of token ids
        """
        chunk_size = 8192
        overlap = 512
        chunks = []
        for i in range(0, len(input_ids), chunk_size - overlap):
            chunks.append(input_ids[i:i + chunk_size])
        # Encode each chunk and keep its final hidden state
        chunk_embeddings = []
        for chunk in chunks:
            hidden = self.encoder(self.embed(chunk).unsqueeze(0))
            chunk_embeddings.append(hidden[:, -1, :])
        # Mean-pool the chunk representations into one document vector
        return torch.stack(chunk_embeddings).mean(dim=0)
```
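The chunking arithmetic is worth spelling out: with stride = chunk_size − overlap, consecutive chunks share exactly `overlap` tokens. A standalone sketch for a hypothetical 20,000-token document:

```python
chunk_size, overlap, doc_len = 8192, 512, 20_000
stride = chunk_size - overlap                  # 7680 new tokens per chunk
starts = range(0, doc_len, stride)
chunks = [(s, min(s + chunk_size, doc_len)) for s in starts]
print(chunks)   # [(0, 8192), (7680, 15872), (15360, 20000)]
```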
Best Practices
- Warmup: train with standard attention first, then switch to a sparse pattern
- Pattern selection: data-dependent patterns must be learned during training
- Hybrid: use sparse attention for most positions and full attention for critical ones
- Hardware: match the sparse pattern to your hardware's capabilities
Future Directions in 2026
- Hardware-Optimized Sparse: Custom kernels for sparse patterns
- Dynamic Sparsity: Adapt sparsity pattern based on content
- Hierarchical Attention: Multi-level sparse at different granularities
- Sparse + MoE: Combine with mixture of experts
Conclusion
Sparse attention represents a crucial breakthrough in making Transformers practical for real-world applications with long sequences. By carefully designing which positions to attend to, we can dramatically reduce computational and memory requirements while maintaining model quality.
The choice of sparse pattern depends on your specific use case: sliding window for autoregressive generation, random for efficiency, linear for very long contexts, and learned patterns for maximum flexibility. As hardware continues to evolve, sparse attention will remain essential for scaling Transformers efficiently.