Introduction
The context length limitation has been a major bottleneck for large language models. Standard attention requires O(N²) memory and computation for sequence length N, making it impractical to process long documents, codebases, or entire books. While many optimizations reduce the per-token cost, scaling to extremely long contexts requires fundamentally different approaches that span multiple GPUs.
Ring Attention and Unified Sequence Parallelism (USP) solve this problem by distributing the attention computation across multiple devices. Combined with FlashAttention, these techniques enable processing context lengths of 1 million tokens and beyond, opening new possibilities for document analysis, code understanding, and long-horizon reasoning.
The Context Length Challenge
Standard Attention Complexity
def standard_attention_complexity():
"""
Standard attention memory and compute requirements
"""
# For a sequence of length N:
    # Memory: O(N²) for attention matrices
    # Compute: O(N²) for attention computation
# Example: 128K context
sequence_length = 128 * 1024 # 128K
hidden_dim = 4096
num_heads = 32
    # Attention matrix size: N × N
    attention_matrix = sequence_length ** 2  # ~17 billion entries
memory_bytes = attention_matrix * 2 # FP16 = 2 bytes
memory_gb = memory_bytes / (1024**3) # ~32 GB just for attention!
# This exceeds single GPU memory for most setups
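As a quick sanity check on these numbers, and to see how partitioning over P devices shrinks the live working set (the helper function name here is illustrative):

```python
def attention_matrix_gb(seq_len: int, dtype_bytes: int = 2) -> float:
    """GB needed to materialize one full N x N attention score matrix."""
    return seq_len ** 2 * dtype_bytes / 1024 ** 3

n = 128 * 1024
print(attention_matrix_gb(n))  # 32.0 GB on one device in FP16

# Ring Attention over P GPUs: each device holds N/P queries and sees
# one (N/P) x (N/P) score block at a time while KV streams past
p = 8
block_gb = (n // p) ** 2 * 2 / 1024 ** 3
print(block_gb)  # 0.5 GB of live scores per device per ring step
```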
Existing Solutions and Their Limits
existing_solutions = {
'flash_attention': {
        'memory': 'O(N) instead of O(N²)',
'limitation': 'Still single GPU, limited by GPU memory'
},
'paged_attention': {
'memory': 'Reduces fragmentation',
'limitation': 'Still single GPU'
},
'sparse_attention': {
'memory': 'Reduces computation',
'limitation': 'Loses information, accuracy trade-off'
},
    'ring_attention': {
        'memory': 'O(N/P) per device: KV cache distributed across P GPUs',
        'limitation': 'Adds inter-GPU communication; needs fast interconnect'
    }
}
Ring Attention
Core Concept
Ring Attention distributes the KV cache across multiple GPUs in a ring topology:
class RingAttention:
"""
Ring Attention: Distribute KV cache across GPUs in a ring
Each GPU holds a portion of the KV cache and streams it
to neighbors as needed during attention computation
"""
def __init__(self, num_gpus, num_heads):
self.num_gpus = num_gpus
self.num_heads = num_heads
self.gpu_id = None
def setup_ring(self, device_id):
"""
Setup ring topology
"""
self.gpu_id = device_id
# Each GPU has two neighbors in the ring
self.prev_gpu = (device_id - 1) % self.num_gpus
self.next_gpu = (device_id + 1) % self.num_gpus
def forward(self, query, key, value):
"""
Ring Attention forward pass
"""
# Split sequence across GPUs
seq_len = query.size(1)
chunk_size = seq_len // self.num_gpus
# Each GPU computes attention for its local chunk
local_query = query[:, chunk_size*self.gpu_id:chunk_size*(self.gpu_id+1)]
local_key = key[:, chunk_size*self.gpu_id:chunk_size*(self.gpu_id+1)]
local_value = value[:, chunk_size*self.gpu_id:chunk_size*(self.gpu_id+1)]
        # Running statistics for online softmax: naively summing
        # per-chunk softmax outputs would not equal softmax over
        # the full sequence
        output = torch.zeros_like(local_query)
        running_max = local_query.new_full(
            local_query.shape[:-1] + (1,), float('-inf'))
        running_sum = torch.zeros_like(running_max)
        # Iterate through all KV chunks in ring order
        for offset in range(self.num_gpus):
            # Compute which KV chunk we need
            kv_gpu = (self.gpu_id + offset) % self.num_gpus
            # Fetch KV from that GPU (placeholder for real ring
            # communication; offset 0 returns our own local chunk)
            remote_key = self.fetch_from_gpu(kv_gpu, 'key')
            remote_value = self.fetch_from_gpu(kv_gpu, 'value')
            # Scaled attention scores for this chunk
            attn_scores = torch.matmul(local_query, remote_key.transpose(-2, -1))
            attn_scores = attn_scores / (local_query.size(-1) ** 0.5)
            # Online softmax update (same rescaling trick as FlashAttention)
            new_max = torch.maximum(running_max, attn_scores.amax(-1, keepdim=True))
            correction = torch.exp(running_max - new_max)
            exp_scores = torch.exp(attn_scores - new_max)
            running_sum = running_sum * correction + exp_scores.sum(-1, keepdim=True)
            output = output * correction + torch.matmul(exp_scores, remote_value)
            running_max = new_max
        # Normalize by the accumulated softmax denominator
        return output / running_sum
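The ring schedule can be sanity-checked on a single process: splitting K/V into chunks and combining them with the online-softmax rule must reproduce full attention exactly. A minimal sketch (function names are mine, not part of any library):

```python
import torch

def full_attention(q, k, v):
    """Reference: standard softmax attention over the whole sequence."""
    scores = q @ k.transpose(-2, -1) / q.size(-1) ** 0.5
    return torch.softmax(scores, dim=-1) @ v

def ring_attention_sim(q, k_chunks, v_chunks):
    """Process KV chunks one at a time with online-softmax accumulation."""
    out = torch.zeros_like(q)
    run_max = torch.full(q.shape[:-1] + (1,), float('-inf'))
    run_sum = torch.zeros_like(run_max)
    for k, v in zip(k_chunks, v_chunks):
        scores = q @ k.transpose(-2, -1) / q.size(-1) ** 0.5
        new_max = torch.maximum(run_max, scores.amax(-1, keepdim=True))
        corr = torch.exp(run_max - new_max)       # rescale old accumulators
        exp_s = torch.exp(scores - new_max)
        run_sum = run_sum * corr + exp_s.sum(-1, keepdim=True)
        out = out * corr + exp_s @ v
        run_max = new_max
    return out / run_sum

q = torch.randn(2, 16, 64)
k = torch.randn(2, 64, 64)
v = torch.randn(2, 64, 64)
ref = full_attention(q, k, v)
sim = ring_attention_sim(q, k.chunk(4, dim=1), v.chunk(4, dim=1))
assert torch.allclose(ref, sim, atol=1e-5)  # chunked == full attention
```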
Optimized Ring Attention with FlashAttention
import torch
from flash_attn import flash_attn_func
class FlashRingAttention:
"""
FlashAttention combined with Ring Attention
"""
def __init__(self, num_gpus):
self.num_gpus = num_gpus
def ring_attention_fused(
self,
queries, # List of query tensors per GPU
keys, # List of key tensors per GPU
values, # List of value tensors per GPU
cu_seqlens # Cumulative sequence lengths
):
"""
Fused Ring Attention using FlashAttention kernel
Each GPU processes its local queries but needs
access to all keys/values through the ring
"""
outputs = [None] * self.num_gpus
for gpu_id in range(self.num_gpus):
q = queries[gpu_id]
            # Collect K and V from all GPUs (single-process simulation
            # of P sequential ring passes; softmax attention is
            # invariant to the concatenation order of KV chunks)
k_full = []
v_full = []
for step in range(self.num_gpus):
source_gpu = (gpu_id + step) % self.num_gpus
k_full.append(keys[source_gpu])
v_full.append(values[source_gpu])
k_full = torch.cat(k_full, dim=1)
v_full = torch.cat(v_full, dim=1)
# Use FlashAttention for local computation
output = flash_attn_func(q, k_full, v_full, causal=False)
outputs[gpu_id] = output
return outputs
def ring_attention_streaming(
self,
local_query,
kv_buffer,
max_chunk_size=4096
):
"""
Streaming version for very long sequences
Instead of waiting for all KV, process in chunks
while KV streams through the ring
"""
seq_len = local_query.size(1)
num_chunks = (seq_len + max_chunk_size - 1) // max_chunk_size
output = torch.zeros_like(local_query)
for chunk_idx in range(num_chunks):
# Get local query chunk
start = chunk_idx * max_chunk_size
end = min(start + max_chunk_size, seq_len)
q_chunk = local_query[:, start:end]
# Accumulate attention from streaming KV
for _ in range(self.num_gpus):
# Receive KV from previous GPU
k_chunk = kv_buffer.receive()
v_chunk = kv_buffer.receive()
                # Compute attention for this KV chunk
                attn_chunk = flash_attn_func(q_chunk, k_chunk, v_chunk)
                # NOTE: a correct combine rescales each chunk's output by
                # its softmax statistics (log-sum-exp), as in the online
                # softmax ring loop; plain += is a simplification here
                output[:, start:end] += attn_chunk
# Forward KV to next GPU
kv_buffer.send(k_chunk, v_chunk)
return output
Communication Optimization
class RingCommunicationOptimizer:
"""
Optimizing communication in Ring Attention
"""
@staticmethod
def overlap_compute_communication(
compute_stream,
comm_stream,
key_buffer,
value_buffer
):
"""
Overlap computation and communication using CUDA streams
"""
# While computing attention for chunk i
# Simultaneously receive KV for chunk i+1
for i in range(num_chunks):
# Start async receive for next chunk KV
if i < num_chunks - 1:
future_kv = comm_stream.stream_receive(
key_buffer, value_buffer,
src=gpu_ring[i+1]
)
# Compute attention for current chunk
output_chunk = attention_kernel(query_chunk[i], key_chunk[i], value_chunk[i])
# Wait for KV to arrive if needed
if i < num_chunks - 1:
future_kv.wait()
# Send current KV to next GPU (async)
if i > 0:
comm_stream.async_send(key_chunk[i], value_chunk[i], dst=gpu_ring[i+1])
    @staticmethod
    def reduce_communication(
        keys,
        values,
        compress_proj,  # learned low-rank projection [dim, dim // ratio]
    ):
        """
        Compress KV before sending through the ring to cut bandwidth.
        """
        # Project to a lower dimension
        compressed_keys = keys @ compress_proj    # [batch, seq, heads, dim // ratio]
        compressed_values = values @ compress_proj
        # Send the compressed tensors through the ring and
        # decompress on the fly when computing attention
        return compressed_keys, compressed_values
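A minimal single-process sketch of the projection idea (the projection matrices here are random stand-ins; a real system would learn them jointly with the model):

```python
import torch

batch, seq, heads, dim = 2, 128, 8, 64
ratio = 4
keys = torch.randn(batch, seq, heads, dim)

# Random stand-ins for learned down/up projections
compress = torch.randn(dim, dim // ratio) / dim ** 0.5
decompress = torch.randn(dim // ratio, dim) / (dim // ratio) ** 0.5

compressed = keys @ compress        # [2, 128, 8, 16]: 4x less ring traffic
restored = compressed @ decompress  # lossy reconstruction on the receiver
print(compressed.numel() / keys.numel())  # 0.25
```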
Unified Sequence Parallelism (USP)
Combining Ring and Ulysses
class UnifiedSequenceParallelism:
"""
USP combines the best of Ring Attention and DeepSpeed-Ulysses
- Ring Attention: Splits sequence, keeps attention heads together
- Ulysses: Splits attention heads, combines sequences
USP: Allows both dimensions to be split flexibly
"""
def __init__(self, num_gpus, num_heads, parallelism='hybrid'):
self.num_gpus = num_gpus
self.num_heads = num_heads
self.parallelism = parallelism
def usp_forward(
self,
hidden_states,
attention_mask,
ring_dim=2,
ulysses_dim=2
):
"""
USP combines ring and ulysses parallelism
Args:
ring_dim: How to split along sequence
ulysses_dim: How to split along attention heads
"""
# Step 1: Sequence splitting (Ring-style)
# Split hidden states across GPUs along sequence dimension
hidden_chunks = self.split_sequence(hidden_states, ring_dim)
# Each GPU gets a chunk of the sequence
local_hidden = hidden_chunks[self.gpu_id]
# Step 2: Attention head splitting (Ulysses-style)
# Further split heads across GPUs within each sequence chunk
        query, key, value = self.compute_qkv(local_hidden)
        query_heads = self.split_heads(query, ulysses_dim)
        key_heads = self.split_heads(key, ulysses_dim)
        value_heads = self.split_heads(value, ulysses_dim)
        # Step 3: All-to-All for attention heads
        # Exchange heads across GPUs while keeping sequence together:
        # this is the "Ulysses" part, sharding the head dimension
        query_ulysses = self.all_to_all_heads(query_heads)
        key_ulysses = self.all_to_all_heads(key_heads)
        value_ulysses = self.all_to_all_heads(value_heads)
# Step 4: Local attention computation
# Now each GPU has a subset of heads for a subset of sequence
# Attention is LOCAL within each GPU (sequence is "together" per GPU)
local_output = self.local_attention(
query_ulysses, key_ulysses, value_ulysses, attention_mask
)
# Step 5: All-to-All to return heads to original GPUs
output_heads = self.all_to_all_heads(local_output)
# Step 6: Gather sequence chunks (Ring reverse)
output = self.gather_sequences(output_heads)
return output
def split_sequence(self, x, num_chunks):
"""Split sequence dimension across GPUs"""
seq_len = x.size(1)
chunk_size = seq_len // num_chunks
chunks = []
for i in range(num_chunks):
start = i * chunk_size
end = start + chunk_size if i < num_chunks - 1 else seq_len
chunks.append(x[:, start:end])
return chunks
def split_heads(self, x, num_splits):
"""Split attention heads across GPUs"""
batch, seq, hidden = x.shape
head_dim = hidden // self.num_heads
# Reshape to heads
x = x.view(batch, seq, self.num_heads, head_dim)
# Split heads
heads_per_gpu = self.num_heads // num_splits
return x.split(heads_per_gpu, dim=2)
    def all_to_all_heads(self, local_heads):
        """
        All-to-All communication for attention heads: each GPU sends
        one head group to every other GPU and receives one group back.
        """
        # torch.distributed.all_to_all writes into a preallocated
        # output tensor list rather than returning a result
        output_heads = [torch.empty_like(t) for t in local_heads]
        torch.distributed.all_to_all(
            output_heads, list(local_heads), group=self.gpu_group)
        return output_heads
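The Ulysses all-to-all can be visualized on a single process: before the exchange, each of P ranks holds its sequence shard with all H heads; afterwards, each rank holds the full sequence for H/P heads. A self-contained reshuffle (names and sizes are illustrative, not a distributed implementation):

```python
import torch

P, B, S, H, D = 4, 2, 32, 8, 16  # ranks, batch, local seq, heads, head_dim
# Before all-to-all: each of P ranks holds [B, S, H, D]
shards = [torch.randn(B, S, H, D) for _ in range(P)]

def ulysses_all_to_all(shards):
    """Rank r sends head group g to rank g; rank g ends up with
    the full sequence for its H//P heads: [B, P*S, H//P, D]."""
    P = len(shards)
    Hp = shards[0].size(2) // P
    out = []
    for g in range(P):
        pieces = [s[:, :, g * Hp:(g + 1) * Hp, :] for s in shards]
        out.append(torch.cat(pieces, dim=1))  # full sequence, H/P heads
    return out

after = ulysses_all_to_all(shards)
print(after[0].shape)  # torch.Size([2, 128, 2, 16])
```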
Adaptive USP
class AdaptiveUSP:
"""
Adaptive USP chooses the best parallelism strategy based on workload
"""
def __init__(self, model_config):
self.model_config = model_config
def select_parallelism(
self,
sequence_length,
num_gpus,
num_heads,
memory_per_gpu
):
"""
Automatically select optimal parallelism strategy
"""
# Calculate what's needed
attention_memory = self.estimate_attention_memory(sequence_length, num_heads)
ffn_memory = self.estimate_ffn_memory(sequence_length)
total_memory = attention_memory + ffn_memory
# Decision tree for parallelism
if total_memory <= memory_per_gpu:
# Can fit in one GPU, no parallelism needed
return {'strategy': 'none', 'ring_dim': 1, 'ulysses_dim': 1}
elif attention_memory > memory_per_gpu:
# Attention is the bottleneck, use Ring
# Need sequence parallelism
ring_dim = int(attention_memory / memory_per_gpu) + 1
ring_dim = min(ring_dim, num_gpus)
return {
'strategy': 'ring',
'ring_dim': ring_dim,
'ulysses_dim': 1
}
        else:
            # FFN is the bottleneck; splitting heads across GPUs
            # (Ulysses) balances it without resharding the sequence
            return {
                'strategy': 'ulysses',
                'ring_dim': 1,
                'ulysses_dim': num_gpus
            }
def hybrid_strategy_selector(
self,
sequence_length,
available_memory,
network_bandwidth
):
"""
Select optimal hybrid strategy
"""
        # Cost models: per-layer communication volume estimates
        ring_comm = self.estimate_ring_comm(sequence_length)
        ulysses_comm = self.estimate_ulysses_comm(sequence_length)
        # Network-bound: prefer whichever pattern moves fewer bytes
        # Memory-bound: prefer Ring
        if network_bandwidth < 100e9:  # < 100 GB/s
            return 'ulysses' if ulysses_comm < ring_comm else 'ring'
        elif available_memory < 20e9:  # < 20 GB
            return 'ring'  # lowest per-device memory
        else:
            return 'hybrid'
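The decision tree above can be exercised with a standalone sketch (function name and memory figures are illustrative assumptions, not measured numbers):

```python
def select_strategy(attn_gb, ffn_gb, gpu_gb, num_gpus):
    """Mirror of the select_parallelism decision tree above."""
    if attn_gb + ffn_gb <= gpu_gb:
        return {'strategy': 'none', 'ring_dim': 1, 'ulysses_dim': 1}
    if attn_gb > gpu_gb:
        # Attention dominates: split the sequence (Ring)
        ring_dim = min(int(attn_gb / gpu_gb) + 1, num_gpus)
        return {'strategy': 'ring', 'ring_dim': ring_dim, 'ulysses_dim': 1}
    # FFN dominates: split heads (Ulysses)
    return {'strategy': 'ulysses', 'ring_dim': 1, 'ulysses_dim': num_gpus}

print(select_strategy(attn_gb=120, ffn_gb=10, gpu_gb=80, num_gpus=8))
# {'strategy': 'ring', 'ring_dim': 2, 'ulysses_dim': 1}
```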
Practical Implementation
Using with vLLM
# Deploy a long-context LLM across GPUs with vLLM
# (model name and context length are illustrative; see vLLM's
# documentation for the long-context options your version supports)
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3-70B-128K",  # placeholder long-context checkpoint
    tensor_parallel_size=8,               # use 8 GPUs
    max_num_seqs=64,
    max_model_len=128 * 1024,             # 128K context
    enable_chunked_prefill=False,         # full-sequence attention
)

sampling_params = SamplingParams(temperature=0.7, max_tokens=1024)

# Generate with long context
outputs = llm.generate(
    prompts=["Summarize this entire book: [very long text...]"],
    sampling_params=sampling_params,
)
Custom Implementation with PyTorch
import torch
import torch.nn as nn
import torch.distributed as dist
class RingAttentionLayer(nn.Module):
"""
Complete Ring Attention layer
"""
def __init__(self, hidden_size, num_heads, num_gpus):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.num_gpus = num_gpus
# Standard attention projections
self.q_proj = nn.Linear(hidden_size, hidden_size)
self.k_proj = nn.Linear(hidden_size, hidden_size)
self.v_proj = nn.Linear(hidden_size, hidden_size)
self.o_proj = nn.Linear(hidden_size, hidden_size)
self.norm = nn.LayerNorm(hidden_size)
    def forward(self, x):
        """
        Ring Attention forward pass.

        `x` is this rank's local sequence chunk, shape
        [batch, local_len, hidden]; K/V chunks rotate around the ring.
        """
        batch_size, local_len, _ = x.shape
        gpu_id = dist.get_rank()

        # Compute Q, K, V and move heads ahead of the sequence axis so
        # matmul contracts over head_dim: [batch, heads, local_len, head_dim]
        def to_heads(t):
            return t.view(batch_size, local_len, self.num_heads,
                          self.head_dim).transpose(1, 2)

        local_q = to_heads(self.q_proj(x))
        current_k = to_heads(self.k_proj(x))
        current_v = to_heads(self.v_proj(x))

        # Online-softmax accumulators: naively summing per-chunk softmax
        # outputs would not equal softmax over the full sequence
        output = torch.zeros_like(local_q)
        running_max = local_q.new_full(
            (batch_size, self.num_heads, local_len, 1), float('-inf'))
        running_sum = torch.zeros_like(running_max)

        for step in range(self.num_gpus):
            scores = torch.matmul(
                local_q, current_k.transpose(-2, -1)
            ) / (self.head_dim ** 0.5)
            new_max = torch.maximum(running_max, scores.amax(-1, keepdim=True))
            correction = torch.exp(running_max - new_max)
            exp_scores = torch.exp(scores - new_max)
            running_sum = running_sum * correction + exp_scores.sum(-1, keepdim=True)
            output = output * correction + torch.matmul(exp_scores, current_v)
            running_max = new_max

            # Rotate KV one hop around the ring for the next step
            if step < self.num_gpus - 1:
                self.send_to_gpu((gpu_id + 1) % self.num_gpus, current_k, 'key')
                self.send_to_gpu((gpu_id + 1) % self.num_gpus, current_v, 'value')
                current_k = self.recv_from_gpu((gpu_id - 1) % self.num_gpus, 'key')
                current_v = self.recv_from_gpu((gpu_id - 1) % self.num_gpus, 'value')

        # Normalize, restore [batch, local_len, hidden], project, residual
        output = (output / running_sum).transpose(1, 2).reshape(
            batch_size, local_len, self.hidden_size)
        output = self.norm(self.o_proj(output))
        return output + x
def recv_from_gpu(self, src, dtype):
"""Receive tensor from another GPU"""
# Would use torch.distributed.isend/irecv
pass
def send_to_gpu(self, dst, tensor, dtype):
"""Send tensor to another GPU"""
pass
Performance Benchmarks
Context Length Scaling
# Benchmark: Time and memory vs context length
benchmarks = {
    'sequence_length': ['32K', '64K', '128K', '256K', '512K', '1M'],
'ring_attention_8gpu': {
'memory_gb': [8, 12, 20, 36, 68, 130],
'time_per_token_ms': [2.1, 2.3, 2.8, 4.2, 7.5, 15.2],
},
'standard_single_gpu': {
'memory_gb': [8, 'OOM', 'OOM', 'OOM', 'OOM', 'OOM'],
'time_per_token_ms': [2.1, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A'],
},
'usp_8gpu': {
'memory_gb': [6, 8, 12, 18, 32, 58],
'time_per_token_ms': [1.8, 1.9, 2.1, 2.8, 4.5, 8.2],
}
}
Scaling Efficiency
# Scaling efficiency with number of GPUs
scaling_results = {
'2_gpus': {'efficiency': 0.92, 'max_context': '256K'},
'4_gpus': {'efficiency': 0.88, 'max_context': '512K'},
'8_gpus': {'efficiency': 0.85, 'max_context': '1M'},
'16_gpus': {'efficiency': 0.78, 'max_context': '2M'},
# Note: Efficiency decreases due to communication overhead
# But maximum context increases proportionally
}
Comparison with Alternatives
| Feature | Standard | Ring Attention | USP | Local + Global |
|---|---|---|---|---|
| Memory | O(N²) | O(N²/P) | O(N²/P) | O(N²) |
| Compute | O(N²) | O(N²/P) | O(N²/P) | O(N) |
| Max Context | Limited | 1M+ | 1M+ | 128K |
| Communication | None | Ring | All-to-All | Tree |
| Complexity | Simple | Medium | High | Medium |
Use Cases
Long Document Analysis
use_cases = {
'document_qa': {
'context': 'Full books, legal documents',
'length': '100K - 1M tokens',
'example': 'Answer questions about entire codebases'
},
'code_understanding': {
'context': 'Large code repositories',
'length': '500K+ tokens',
'example': 'Understand entire monorepo, cross-file analysis'
},
'long_conversation': {
'context': 'Multi-hour chat sessions',
'length': '128K+ tokens',
'example': 'AI assistant with perfect long-term memory'
},
'research_synthesis': {
'context': 'Reading multiple papers',
'length': '500K+ tokens',
'example': 'Synthesize insights from 100s of papers'
}
}
Future Directions
future_directions = {
'heterogeneous_systems': {
'description': 'Mix CPU, GPU, and specialized accelerators',
'challenge': 'Different memory and compute characteristics'
},
'hierarchical_ring': {
'description': 'Multiple rings (within node, across nodes)',
'challenge': 'Bandwidth differences at different levels'
},
'adaptive_attention': {
'description': 'Dynamically choose attention pattern per layer',
'challenge': 'When to use what pattern'
},
'1B_context': {
'description': 'Scale to billion token context',
'challenge': 'Communication becomes bottleneck'
}
}
Conclusion
Ring Attention and Unified Sequence Parallelism represent the cutting edge of long-context LLM deployment:
- Breaks the barrier: 1M+ token contexts are now practical
- Scalable: no theoretical ceiling, just add GPUs
- Efficient: per-GPU memory shrinks linearly versus a single GPU
- Composable: combines with FlashAttention and PagedAttention
These techniques are changing how we handle long documents, codebases, and complex reasoning tasks. As GPU clusters grow and communication optimizations mature, million-token contexts will become a standard capability.
Resources
- Ring Attention Paper
- USP Paper: Unified Sequence Parallelism
- Long Context with Ring Attention Tutorial
- vLLM Long Context Documentation
- FlashAttention with Ring Communication