⚡ Calmops

Ring Attention and USP: Scaling Transformer Context to Millions of Tokens

Introduction

The context length limitation has been a major bottleneck for large language models. Standard attention requires O(N²) memory and computation in the sequence length N, which makes it impractical to process long documents, codebases, or entire books on a single device. Many optimizations reduce the per-token cost, but scaling to extremely long contexts across multiple GPUs requires fundamentally different approaches.

Ring Attention and Unified Sequence Parallelism (USP) address this problem by distributing the attention computation across multiple devices. Combined with FlashAttention, these techniques enable context lengths of 1 million tokens and beyond, opening new possibilities for document analysis, code understanding, and long-horizon reasoning.

The Context Length Challenge

Standard Attention Complexity

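def standard_attention_complexity():
    """
    Memory and compute for naive attention that materializes the
    full N x N score matrix
    """
    # For a sequence of length N:
    # Memory: O(N²) for the attention score matrix (per head)
    # Compute: O(N² · d) for QK^T and the weighted sum over V

    # Example: 128K context
    sequence_length = 128 * 1024  # 131,072 tokens
    hidden_dim = 4096
    num_heads = 32

    # Attention matrix size: N × N per head
    attention_matrix = sequence_length ** 2  # 2**34 ≈ 17.2 billion entries
    memory_bytes = attention_matrix * 2      # FP16 = 2 bytes per entry
    memory_gb = memory_bytes / (1024**3)     # 32 GiB for ONE head of ONE sequence!
    all_heads_gb = memory_gb * num_heads     # 1024 GiB (1 TiB) for all 32 heads

    # FLOPs for QK^T plus PV across all heads, for one layer
    flops = 4 * sequence_length ** 2 * hidden_dim  # ≈ 2.8e14

    # One head already fills a large share of a GPU; all heads far exceed any single GPU
    return memory_gb, all_heads_gb, flops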
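The arithmetic above generalizes: the score matrix grows with the square of the sequence length, so every doubling of N quadruples its size. The helper below is a small sketch (the name `attention_matrix_gib` is illustrative, not from any library) that tabulates the memory needed to materialize one FP16 N × N matrix per head, which is exactly the quantity FlashAttention-style kernels avoid storing.

```python
def attention_matrix_gib(seq_len, num_heads=1, batch=1, bytes_per_elem=2):
    """GiB needed to materialize the full seq_len x seq_len score matrix."""
    return batch * num_heads * seq_len ** 2 * bytes_per_elem / 1024 ** 3


for n_k in (32, 64, 128, 256, 512, 1024):
    gib = attention_matrix_gib(n_k * 1024)
    print(f"{n_k:>5}K tokens: {gib:>6.0f} GiB per head (FP16)")
```

At 1024K (1M) tokens a single head would already need 2048 GiB, which is why the full score matrix must never be materialized at these lengths.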

Existing Solutions and Their Limits

existing_solutions = {
    'flash_attention': {
        'memory': 'O(N) instead of O(N²)',
        'limitation': 'Still single GPU, limited by GPU memory'
    },
    'paged_attention': {
        'memory': 'Reduces KV-cache fragmentation',
        'limitation': 'Still single GPU'
    },
    'sparse_attention': {
        'memory': 'Reduces computation',
        'limitation': 'May lose information; accuracy trade-off'
    },
    'ring_attention': {
        'memory': 'Distributes across GPUs',
        'limitation': 'Needs multiple GPUs and a fast interconnect; max context grows with GPU count'
    }
}
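FlashAttention reaches O(N) memory by never storing a full row of scores: it walks over key/value blocks and keeps only a running maximum, a running softmax denominator and a running weighted sum, rescaling them whenever a larger score appears. Ring Attention reuses the same idea across GPUs. The sketch below illustrates this online softmax in plain Python for a single query, assuming scalar values (one value dimension) to keep it short; the function names are illustrative.

```python
import math


def softmax_weighted_sum(scores, values):
    """Reference: softmax over ALL scores at once, then the weighted sum of values."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)


def online_softmax_weighted_sum(scores, values, block_size):
    """Same result, visiting the scores one block at a time.

    Only a running max (m), running denominator (denom) and running
    numerator (acc) are kept, so memory does not grow with sequence length.
    """
    m, denom, acc = float("-inf"), 0.0, 0.0
    for start in range(0, len(scores), block_size):
        s_blk = scores[start:start + block_size]
        v_blk = values[start:start + block_size]
        new_m = max(m, max(s_blk))
        scale = math.exp(m - new_m)  # rescale earlier partial sums to the new max
        denom = denom * scale + sum(math.exp(s - new_m) for s in s_blk)
        acc = acc * scale + sum(math.exp(s - new_m) * v for s, v in zip(s_blk, v_blk))
        m = new_m
    return acc / denom
```

Any block size gives the same answer up to floating-point rounding, which is what lets the blocks live on different devices.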

Ring Attention

Core Concept

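Ring Attention splits the sequence into blocks, one per GPU. Each GPU keeps its query block fixed and passes its key/value block to its neighbor in a ring topology, so every query block eventually sees every KV block while no GPU ever holds the full sequence: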

import torch
import torch.distributed as dist

class RingAttention:
    """
    Ring Attention: distribute the sequence (and its KV) across GPUs in a ring

    Each GPU keeps its own query block and, at every step, passes its
    current KV block to the next GPU while receiving one from the previous
    GPU. After num_gpus steps every query block has seen every KV block.
    """

    def __init__(self, num_gpus, num_heads):
        self.num_gpus = num_gpus
        self.num_heads = num_heads
        self.gpu_id = None

    def setup_ring(self, device_id):
        """
        Setup ring topology (device_id is this process's global rank;
        torch.distributed must already be initialized)
        """
        self.gpu_id = device_id

        # Each GPU has two neighbors in the ring
        self.prev_gpu = (device_id - 1) % self.num_gpus
        self.next_gpu = (device_id + 1) % self.num_gpus

    def ring_pass(self, tensor):
        """Send `tensor` to next_gpu and receive the same-shaped block from prev_gpu"""
        tensor = tensor.contiguous()
        received = torch.empty_like(tensor)
        ops = [
            dist.P2POp(dist.isend, tensor, self.next_gpu),
            dist.P2POp(dist.irecv, received, self.prev_gpu),
        ]
        for req in dist.batch_isend_irecv(ops):
            req.wait()
        return received

    def forward(self, local_query, local_key, local_value):
        """
        Ring Attention forward pass (no causal mask)

        Inputs are THIS GPU's blocks, shaped [batch, heads, block_len, head_dim].
        Partial results from each KV block are merged with an online softmax,
        so the output equals softmax(QK^T / sqrt(d)) V over the full sequence.
        """
        scale = local_query.size(-1) ** -0.5
        q = local_query.float()

        # Running numerator, max and denominator for every query row
        output = torch.zeros_like(q)
        row_max = torch.full((*q.shape[:-1], 1), float("-inf"), device=q.device)
        row_sum = torch.zeros_like(row_max)

        key, value = local_key, local_value
        for step in range(self.num_gpus):
            # Scores against the KV block currently held by this GPU
            scores = torch.matmul(q, key.float().transpose(-2, -1)) * scale

            # Online softmax: rescale old accumulators to the new running max
            new_max = torch.maximum(row_max, scores.amax(dim=-1, keepdim=True))
            correction = torch.exp(row_max - new_max)
            probs = torch.exp(scores - new_max)
            output = output * correction + torch.matmul(probs, value.float())
            row_sum = row_sum * correction + probs.sum(dim=-1, keepdim=True)
            row_max = new_max

            # Forward the KV block around the ring (not needed after the last step)
            if step < self.num_gpus - 1:
                key = self.ring_pass(key)
                value = self.ring_pass(value)

        return (output / row_sum).to(local_query.dtype)
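Ring Attention is only correct if the partial results from each KV block are merged with an online softmax; multiplying raw scores by V and summing the chunks does not give attention. To check that rotating KV blocks and merging them this way reproduces ordinary attention exactly, the ring can be simulated on a single process. In this sketch (plain Python, illustrative names, no causal mask, sequence length assumed divisible by the number of simulated devices) each "device" owns one block of queries and one block of keys/values, and after every step each device hands its KV block to its ring neighbour.

```python
import math


def full_attention(q, k, v):
    """Reference softmax(q k^T / sqrt(d)) v for lists of vectors, no mask."""
    d = len(q[0])
    out = []
    for qi in q:
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        out.append([sum(wi * vj[c] for wi, vj in zip(w, v)) / z for c in range(len(v[0]))])
    return out


def simulate_ring_attention(q, k, v, num_devices):
    """Ring Attention on one process: device r owns block r of q, k and v."""
    block = len(q) // num_devices
    d, dv = len(q[0]), len(v[0])

    def shard(x, r):
        return x[r * block:(r + 1) * block]

    q_blocks = [shard(q, r) for r in range(num_devices)]
    kv_blocks = [(shard(k, r), shard(v, r)) for r in range(num_devices)]
    # Online-softmax state per query row: [running max, denominator, numerator]
    state = [[[float("-inf"), 0.0, [0.0] * dv] for _ in range(block)]
             for _ in range(num_devices)]

    for _step in range(num_devices):
        for r in range(num_devices):
            k_blk, v_blk = kv_blocks[r]  # KV block currently held by device r
            for i, qi in enumerate(q_blocks[r]):
                m, denom, acc = state[r][i]
                scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k_blk]
                new_m = max(m, max(scores))
                corr = math.exp(m - new_m)  # rescale earlier partial sums
                p = [math.exp(s - new_m) for s in scores]
                acc = [acc[j] * corr + sum(pi * vj[j] for pi, vj in zip(p, v_blk))
                       for j in range(dv)]
                state[r][i] = [new_m, denom * corr + sum(p), acc]
        # Ring pass: every device receives the block its predecessor held
        kv_blocks = [kv_blocks[(r - 1) % num_devices] for r in range(num_devices)]

    return [[a / denom for a in acc] for dev in state for (_m, denom, acc) in dev]
```

After `num_devices` steps device r has seen blocks r, r-1, ..., r-P+1, i.e. every block exactly once, so the result matches `full_attention` for any device count that divides the sequence length.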

Optimized Ring Attention with FlashAttention

import torch
from flash_attn import flash_attn_func


def merge_partial(out, lse, block_out, block_lse):
    """
    Merge attention results computed over disjoint KV blocks

    out / block_out: [batch, seqlen, nheads, headdim]
    lse / block_lse: [batch, nheads, seqlen] log-sum-exp of the scores
    """
    if out is None:
        return block_out.float(), block_lse
    new_lse = torch.logaddexp(lse, block_lse)
    # Per-row weights, reshaped to broadcast over [batch, seqlen, nheads, headdim]
    w_old = torch.exp(lse - new_lse).transpose(1, 2).unsqueeze(-1)
    w_new = torch.exp(block_lse - new_lse).transpose(1, 2).unsqueeze(-1)
    return out * w_old + block_out.float() * w_new, new_lse


class FlashRingAttention:
    """
    FlashAttention combined with Ring Attention

    Tensors follow flash_attn's layout [batch, seqlen, nheads, headdim]
    (fp16/bf16 on CUDA).
    """

    def __init__(self, num_gpus):
        self.num_gpus = num_gpus

    def ring_attention_fused(
        self,
        queries,      # List of query tensors, one per GPU
        keys,         # List of key tensors, one per GPU
        values,       # List of value tensors, one per GPU
    ):
        """
        Single-process reference of the ring schedule

        Each GPU attends to ONE KV block per ring step and merges the
        partial result using FlashAttention's log-sum-exp, instead of
        concatenating every KV block (which would need the full KV on
        every GPU and defeat the purpose of the ring).
        """
        outputs = [None] * self.num_gpus

        for gpu_id in range(self.num_gpus):
            q = queries[gpu_id]
            out, lse = None, None

            for step in range(self.num_gpus):
                # KV block that reaches this GPU at this step of the ring
                source_gpu = (gpu_id - step) % self.num_gpus
                block_out, block_lse, _ = flash_attn_func(
                    q, keys[source_gpu], values[source_gpu],
                    causal=False, return_attn_probs=True,
                )
                out, lse = merge_partial(out, lse, block_out, block_lse)

            outputs[gpu_id] = out.to(q.dtype)

        return outputs

    def ring_attention_streaming(
        self,
        local_query,
        kv_buffer,
        max_chunk_size=4096
    ):
        """
        Streaming version for very long local shards

        Each KV block is received once and reused for every local query
        chunk before being forwarded. `kv_buffer` is a hypothetical
        communication helper with receive() -> (k, v) and send(k, v).
        """
        seq_len = local_query.size(1)
        starts = list(range(0, seq_len, max_chunk_size))
        outs = [None] * len(starts)
        lses = [None] * len(starts)

        for _ in range(self.num_gpus):
            # Receive one KV block from the previous GPU
            k_block, v_block = kv_buffer.receive()

            for i, start in enumerate(starts):
                end = min(start + max_chunk_size, seq_len)
                block_out, block_lse, _ = flash_attn_func(
                    local_query[:, start:end], k_block, v_block,
                    return_attn_probs=True,
                )
                outs[i], lses[i] = merge_partial(outs[i], lses[i], block_out, block_lse)

            # Forward the KV block to the next GPU
            kv_buffer.send(k_block, v_block)

        return torch.cat(outs, dim=1).to(local_query.dtype)
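Partial outputs from different KV blocks cannot simply be added, because each was normalized by its own softmax denominator. flash_attn's `flash_attn_func(..., return_attn_probs=True)` also returns the per-row log-sum-exp (LSE) of the scores, and two partial results over disjoint KV blocks can be combined exactly by weighting each one by exp(lse_block - lse_total). The plain-Python sketch below shows this merge rule for a single query with scalar values; the function names are illustrative.

```python
import math


def partial_attention(scores, values):
    """Attention of one query over one KV block: (normalized output, log-sum-exp)."""
    m = max(scores)
    lse = m + math.log(sum(math.exp(s - m) for s in scores))
    out = sum(math.exp(s - lse) * v for s, v in zip(scores, values))
    return out, lse


def merge(out_a, lse_a, out_b, lse_b):
    """Exactly combine two partial results computed over disjoint KV blocks."""
    m = max(lse_a, lse_b)
    lse = m + math.log(math.exp(lse_a - m) + math.exp(lse_b - m))
    out = out_a * math.exp(lse_a - lse) + out_b * math.exp(lse_b - lse)
    return out, lse


scores = [0.3, 2.5, -1.0, 1.7, 0.9, -0.4]
values = [1.0, -2.0, 0.5, 3.0, -1.5, 2.0]
full_out, full_lse = partial_attention(scores, values)
out_a, lse_a = partial_attention(scores[:3], values[:3])
out_b, lse_b = partial_attention(scores[3:], values[3:])
merged_out, merged_lse = merge(out_a, lse_a, out_b, lse_b)
print(abs(merged_out - full_out), abs(merged_lse - full_lse))
```

The merged LSE is itself the LSE over both blocks, so the rule can be applied repeatedly as blocks arrive around the ring.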

Communication Optimization

import torch
import torch.distributed as dist

class RingCommunicationOptimizer:
    """
    Optimizing communication in Ring Attention
    """

    @staticmethod
    def overlap_compute_communication(
        query,
        key,
        value,
        attend,
        num_steps,
        next_rank,
        prev_rank,
    ):
        """
        Overlap computation and communication (double buffering)

        While attention is computed on the KV block for step i, the block
        for step i+1 is already in flight. Non-blocking P2P ops let the
        NCCL backend run the transfer on its own stream.
        `attend(query, key, value, state) -> state` is supplied by the
        caller, e.g. an online-softmax update.
        """
        key, value = key.contiguous(), value.contiguous()
        state = None

        for i in range(num_steps):
            reqs = []
            if i < num_steps - 1:
                # Post async send of the current KV and receive of the next KV
                next_key, next_value = torch.empty_like(key), torch.empty_like(value)
                reqs = dist.batch_isend_irecv([
                    dist.P2POp(dist.isend, key, next_rank),
                    dist.P2POp(dist.isend, value, next_rank),
                    dist.P2POp(dist.irecv, next_key, prev_rank),
                    dist.P2POp(dist.irecv, next_value, prev_rank),
                ])

            # Compute attention for the current block while the transfer runs
            state = attend(query, key, value, state)

            # Wait for the next block to arrive before the following step
            for req in reqs:
                req.wait()
            if reqs:
                key, value = next_key, next_value

        return state

    @staticmethod
    def reduce_communication(
        keys,
        values,
        compress_proj,
    ):
        """
        Compress KV before sending through the ring (lossy)

        compress_proj is a hypothetical learned [head_dim, head_dim // ratio]
        projection; the receiver needs a matching decompression step,
        and accuracy may degrade.
        """
        # Project to lower dimension: [batch, seq, heads, dim] -> [..., dim // ratio]
        compressed_keys = keys @ compress_proj
        compressed_values = values @ compress_proj

        # Send compressed version through ring
        # Decompress on-the-fly when computing attention

        return compressed_keys, compressed_values
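Whether the ring transfer can be fully hidden behind computation comes down to simple arithmetic. Under the simplifying assumptions in this sketch (non-causal attention, K/V as wide as the hidden size d, i.e. no grouped-query attention, and FLOP rate F and link bandwidth B taken as sustained values), one ring step costs about 4·c²·d FLOPs of compute and 2·c·d elements of traffic for a per-GPU block of c tokens, so compute covers communication once c ≥ F·b / (2·B) for b bytes per element. The hidden size cancels out: each GPU's block just has to be long enough for its quadratic compute to outrun the linear transfer. The hardware numbers below are placeholders, not measurements, and the function names are illustrative.

```python
def step_times(block_len, hidden, flops_per_s, bytes_per_s, bytes_per_elem=2):
    """Seconds for one ring step: (attention compute, K/V transfer)."""
    compute_s = 4 * block_len ** 2 * hidden / flops_per_s           # QK^T and PV
    comm_s = 2 * block_len * hidden * bytes_per_elem / bytes_per_s  # K and V blocks
    return compute_s, comm_s


def min_block_for_overlap(flops_per_s, bytes_per_s, bytes_per_elem=2):
    """Smallest block length c with compute_s >= comm_s (hidden size cancels out)."""
    return flops_per_s * bytes_per_elem / (2 * bytes_per_s)


# Placeholder hardware: 400 TFLOP/s sustained, 50 GB/s per ring link, BF16 K/V
c_min = min_block_for_overlap(400e12, 50e9)
print(f"block length needed to hide communication: {c_min:.0f} tokens")
for block in (2048, 8192, 32768):
    compute_s, comm_s = step_times(block, 8192, 400e12, 50e9)
    print(block, "hidden" if compute_s >= comm_s else "exposed")
```

Causal masking lowers the average compute per step and so raises the threshold, while grouped-query attention shrinks K/V traffic and lowers it; plug in your own model and hardware numbers.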

Unified Sequence Parallelism (USP)

Combining Ring and Ulysses

import torch
import torch.distributed as dist

class UnifiedSequenceParallelism:
    """
    USP combines Ring Attention and DeepSpeed-Ulysses on a 2D device mesh
    (ring_degree * ulysses_degree = number of GPUs)

    - Ring Attention: splits the sequence; each GPU keeps all heads and
      KV blocks rotate around the ring
    - Ulysses: an all-to-all turns sequence shards into head shards, so each
      GPU sees the full (group) sequence for a subset of heads

    USP: allows both dimensions to be split flexibly
    """

    def __init__(self, num_heads, ulysses_group, ring_attention_fn):
        self.num_heads = num_heads
        # Process group spanning the Ulysses dimension of the mesh; rank j in the
        # group is assumed to hold sequence shard j of that group
        self.ulysses_group = ulysses_group
        self.ulysses_degree = dist.get_world_size(ulysses_group)
        assert num_heads % self.ulysses_degree == 0, "Ulysses degree must divide num_heads"
        # Ring attention over the ring dimension of the mesh, taking and returning
        # [batch, seq, heads, head_dim] tensors (caller-supplied)
        self.ring_attention_fn = ring_attention_fn

    def usp_forward(self, query, key, value):
        """
        query/key/value: this GPU's sequence shard, [batch, seq_local, heads, head_dim]
        """
        # Step 1 (Ulysses): all-to-all trades sequence shards for head shards
        # -> [batch, seq_local * U, heads / U, head_dim]
        q, k, v = (self.seq_to_heads(t) for t in (query, key, value))

        # Step 2 (Ring): the sequence is still split across the ring group,
        # so attention uses ring attention rather than a purely local kernel
        out = self.ring_attention_fn(q, k, v)

        # Step 3 (Ulysses, reverse): back to sequence shards with all heads
        return self.heads_to_seq(out)

    def split_sequence(self, x, num_chunks):
        """Split sequence dimension across GPUs"""
        seq_len = x.size(1)
        chunk_size = seq_len // num_chunks

        chunks = []
        for i in range(num_chunks):
            start = i * chunk_size
            end = start + chunk_size if i < num_chunks - 1 else seq_len
            chunks.append(x[:, start:end])

        return chunks

    def seq_to_heads(self, x):
        """All-to-all: [b, s_local, H, d] -> [b, s_local * U, H / U, d]"""
        b, s, h, d = x.shape
        u = self.ulysses_degree
        # Chunk j (dim 0) holds head group j and is sent to rank j of the group
        x = x.reshape(b, s, u, h // u, d).permute(2, 0, 1, 3, 4).contiguous()
        out = torch.empty_like(x)
        dist.all_to_all_single(out, x, group=self.ulysses_group)
        # Received chunk j is rank j's sequence shard: stack the shards in order
        return out.permute(1, 0, 2, 3, 4).reshape(b, u * s, h // u, d)

    def heads_to_seq(self, x):
        """Inverse all-to-all: [b, s_local * U, H / U, d] -> [b, s_local, H, d]"""
        b, s_total, h_local, d = x.shape
        u = self.ulysses_degree
        # Chunk j (dim 0) is rank j's sequence shard, sent back to rank j
        x = x.reshape(b, u, s_total // u, h_local, d).permute(1, 0, 2, 3, 4).contiguous()
        out = torch.empty_like(x)
        dist.all_to_all_single(out, x, group=self.ulysses_group)
        # Received chunk j is head group j of this rank's tokens
        return out.permute(1, 2, 0, 3, 4).reshape(b, s_total // u, u * h_local, d)
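The layout change performed by the Ulysses all-to-all is easy to misread. This plain-Python simulation (illustrative names; elements are `(token, head)` labels instead of tensors) shows what each rank holds before and after: P ranks start with a contiguous block of tokens and all heads, and end with every token but only `num_heads / P` heads. It assumes both the token count and the head count divide evenly by P.

```python
def seq_to_heads(shards, num_heads):
    """Sequence-sharded (all heads) -> head-sharded (all tokens).

    shards[r] is rank r's block: a list of tokens, each a list over ALL heads.
    Afterwards rank r holds every token, but only heads [r*h, (r+1)*h).
    """
    p = len(shards)
    h = num_heads // p
    result = []
    for dst in range(p):
        rows = []
        for src in range(p):           # the chunk received from rank src ...
            for token in shards[src]:  # ... covers src's tokens, in order
                rows.append(token[dst * h:(dst + 1) * h])
        result.append(rows)
    return result


def heads_to_seq(shards):
    """Inverse all-to-all: head-sharded (all tokens) -> sequence-sharded."""
    p = len(shards)
    block = len(shards[0]) // p
    result = []
    for dst in range(p):
        rows = []
        for t in range(dst * block, (dst + 1) * block):
            # Concatenate this token's head groups from every rank, in rank order
            rows.append([x for src in range(p) for x in shards[src][t]])
        result.append(rows)
    return result


P, HEADS, TOKENS = 4, 8, 8
per_rank = TOKENS // P
before = [[[(t, hd) for hd in range(HEADS)] for t in range(r * per_rank, (r + 1) * per_rank)]
          for r in range(P)]
after = seq_to_heads(before, HEADS)
print(after[1][5])  # token 5, heads 2..3 on rank 1 -> [(5, 2), (5, 3)]
```

Because the attention for each head only needs that head's Q, K and V, every rank can then run attention over its head group without further communication when the Ulysses degree equals the GPU count.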

Adaptive USP

class AdaptiveUSP:
    """
    Adaptive USP chooses the best parallelism strategy based on workload
    """
    
    def __init__(self, model_config):
        self.model_config = model_config
        
    def select_parallelism(
        self,
        sequence_length,
        num_gpus,
        num_heads,
        memory_per_gpu
    ):
        """
        Automatically select optimal parallelism strategy
        """
        
        # Calculate what's needed (estimate_* are model-specific cost models, not shown)
        attention_memory = self.estimate_attention_memory(sequence_length, num_heads)
        ffn_memory = self.estimate_ffn_memory(sequence_length)
        total_memory = attention_memory + ffn_memory
        
        # Decision tree for parallelism
        if total_memory <= memory_per_gpu:
            # Can fit in one GPU, no parallelism needed
            return {'strategy': 'none', 'ring_dim': 1, 'ulysses_dim': 1}
        
        elif attention_memory > memory_per_gpu:
            # Attention is the bottleneck, use Ring
            # Need sequence parallelism
            ring_dim = int(attention_memory / memory_per_gpu) + 1
            ring_dim = min(ring_dim, num_gpus)
            
            return {
                'strategy': 'ring',
                'ring_dim': ring_dim,
                'ulysses_dim': 1
            }
        
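        else:
            # Attention fits but other activations (e.g. FFN) do not: use Ulysses.
            # The Ulysses degree must divide both the GPU count and the head count.
            ulysses_dim = max(
                d for d in range(1, num_gpus + 1)
                if num_gpus % d == 0 and num_heads % d == 0
            )

            return {
                'strategy': 'ulysses',
                'ring_dim': 1,
                'ulysses_dim': ulysses_dim
            }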
    
    def hybrid_strategy_selector(
        self,
        sequence_length,
        available_memory,
        network_bandwidth
    ):
        """
        Select optimal hybrid strategy
        """
        
        # Cost models (model-specific helpers, not shown; thresholds below are simple heuristics)
        ring_comm = self.estimate_ring_comm(sequence_length)
        ulysses_comm = self.estimate_ulysses_comm(sequence_length)
        
        # Network-bound: prefer Ulysses (less communication)
        # Memory-bound: prefer Ring
        
        if network_bandwidth < 100e9:  # < 100 GB/s
            return 'ulysses'  # Less bandwidth
        elif available_memory < 20e9:  # < 20 GB
            return 'ring'  # More memory efficient
        else:
            return 'hybrid'
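The choice between Ring, Ulysses and a hybrid can be made a little more concrete. The sketch below (illustrative names; a rough forward-pass model assuming standard multi-head attention with K/V as wide as Q, and ignoring any overlap with compute) enumerates the valid USP meshes for a given GPU count, using the constraint that the Ulysses degree must divide the number of heads, and estimates the elements sent per GPU per attention layer. With P = u·r GPUs each local tensor holds N·d/P elements; Ulysses runs four all-to-alls (Q, K, V in, output back) that each send (u - 1)/u of it, and the ring sends one K and one V block of the same size on each of its r - 1 steps.

```python
def usp_configs(num_gpus, num_heads):
    """(ulysses_degree, ring_degree) pairs with u * r == num_gpus and u | num_heads."""
    return [
        (u, num_gpus // u)
        for u in range(1, num_gpus + 1)
        if num_gpus % u == 0 and num_heads % u == 0
    ]


def per_gpu_traffic(seq_len, hidden, ulysses, ring):
    """Approximate elements sent per GPU per attention layer (forward pass)."""
    local = seq_len * hidden / (ulysses * ring)          # N * d / P
    ulysses_part = 4 * local * (ulysses - 1) / ulysses   # Q, K, V, O all-to-alls
    ring_part = 2 * local * (ring - 1)                   # K and V per ring step
    return ulysses_part + ring_part


for u, r in usp_configs(num_gpus=8, num_heads=32):
    traffic = per_gpu_traffic(seq_len=1024 * 1024, hidden=4096, ulysses=u, ring=r)
    print(f"ulysses={u} ring={r}: {traffic / 1e9:.2f} G elements per GPU")
```

Lower traffic is not automatically faster: ring traffic is point-to-point and can be overlapped with compute, while all-to-all traffic sits on the critical path, which is why hybrid layouts commonly put the Ulysses dimension on fast intra-node links.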

Practical Implementation

Using with vLLM

# Serve a long-context LLM with vLLM across 8 GPUs
from vllm import LLM, SamplingParams

# Configure for long context
llm = LLM(
    model="meta-llama/Llama-3.1-70B-Instruct",  # a 128K-context model
    tensor_parallel_size=8,           # Shard weights and heads across 8 GPUs
    max_num_seqs=64,
    max_model_len=128 * 1024,         # 128K context
    enable_chunked_prefill=False,     # Prefill each prompt in one pass (scheduling only)
    # Note: tensor_parallel_size is tensor parallelism (weights/heads are split);
    # it does not shard the sequence across GPUs the way Ring Attention does.
)

sampling_params = SamplingParams(temperature=0.7, max_tokens=1024)

# Generate with long context
outputs = llm.generate(
    prompts=[
        "Summarize this entire book: [very long text...]"
    ],
    sampling_params=sampling_params
)
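Long contexts also stress the KV cache, independently of how attention is parallelized. As a rough sketch, assuming a Llama-3.1-70B-style shape (80 layers, 8 KV heads with grouped-query attention, head dimension 128) and a BF16 cache, the arithmetic below estimates the KV cache for one full-length 128K-token sequence. The helper name is illustrative, not a vLLM API.

```python
def kv_cache_bytes_per_token(num_layers, num_kv_heads, head_dim, bytes_per_elem=2):
    """K and V for every layer: 2 * layers * kv_heads * head_dim * bytes."""
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem


# Assumed Llama-3.1-70B-style shape: 80 layers, 8 KV heads (GQA), head_dim 128, BF16
per_token = kv_cache_bytes_per_token(num_layers=80, num_kv_heads=8, head_dim=128)
per_seq_gib = per_token * 128 * 1024 / 1024 ** 3
per_gpu_gib = per_seq_gib / 8  # tensor_parallel_size=8 splits the 8 KV heads

print(f"{per_token // 1024} KiB per token")
print(f"{per_seq_gib:.0f} GiB per 128K-token sequence, {per_gpu_gib:.0f} GiB per GPU")
```

At about 5 GiB per GPU per full-length sequence, free KV-cache space rather than `max_num_seqs` usually caps how many 128K-token requests run concurrently; vLLM's scheduler only admits requests while KV-cache blocks are available.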

Custom Implementation with PyTorch

import torch
import torch.nn as nn
import torch.distributed as dist

class RingAttentionLayer(nn.Module):
    """
    Complete Ring Attention layer (forward only, no causal mask)

    Each rank passes in ONLY its own sequence shard x: [batch, chunk_len, hidden].
    K/V shards rotate around the ring and partial results are merged with an
    online softmax, so the output matches full attention over the sequence.
    Training would additionally need a custom autograd.Function, because
    gradients do not flow through isend/irecv.
    """

    def __init__(self, hidden_size, num_heads, num_gpus):
        super().__init__()

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.num_gpus = num_gpus

        # Standard attention projections
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.o_proj = nn.Linear(hidden_size, hidden_size)

        self.norm = nn.LayerNorm(hidden_size)

    def _to_heads(self, t):
        """[batch, seq, hidden] -> [batch, heads, seq, head_dim]"""
        b, s, _ = t.shape
        return t.view(b, s, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, layer_idx=None):
        """
        Forward with Ring Attention (layer_idx is unused; kept for compatibility)
        """
        batch_size, chunk_len, _ = x.shape

        # Compute Q, K, V for the local shard only
        q = self._to_heads(self.q_proj(x))
        k = self._to_heads(self.k_proj(x))
        v = self._to_heads(self.v_proj(x))

        # Ring neighbours of the current GPU
        gpu_id = dist.get_rank()
        next_gpu = (gpu_id + 1) % self.num_gpus
        prev_gpu = (gpu_id - 1) % self.num_gpus
        scale = self.head_dim ** -0.5

        # Online-softmax accumulators (fp32 for numerical stability)
        acc = torch.zeros_like(q, dtype=torch.float32)
        row_max = torch.full((*q.shape[:-1], 1), float("-inf"), device=q.device)
        row_sum = torch.zeros_like(row_max)

        for step in range(self.num_gpus):
            # Post the transfer of the next KV block before computing on this one
            last = step == self.num_gpus - 1
            if not last:
                next_k, next_v, reqs = self.exchange_kv(k, v, next_gpu, prev_gpu)

            scores = torch.matmul(q, k.transpose(-2, -1)).float() * scale
            new_max = torch.maximum(row_max, scores.amax(dim=-1, keepdim=True))
            correction = torch.exp(row_max - new_max)
            probs = torch.exp(scores - new_max)
            acc = acc * correction + torch.matmul(probs, v.float())
            row_sum = row_sum * correction + probs.sum(dim=-1, keepdim=True)
            row_max = new_max

            if not last:
                for req in reqs:
                    req.wait()
                k, v = next_k, next_v

        # [batch, heads, chunk, head_dim] -> [batch, chunk, hidden]
        local_output = (acc / row_sum).to(x.dtype)
        local_output = local_output.transpose(1, 2).reshape(batch_size, chunk_len, self.hidden_size)

        # Project, normalize and add residual
        output = self.o_proj(local_output)
        output = self.norm(output)

        return output + x

    def exchange_kv(self, k, v, dst, src):
        """Non-blocking: send (k, v) to dst and receive the next (k, v) from src"""
        k, v = k.contiguous(), v.contiguous()
        recv_k, recv_v = torch.empty_like(k), torch.empty_like(v)
        reqs = dist.batch_isend_irecv([
            dist.P2POp(dist.isend, k, dst),
            dist.P2POp(dist.isend, v, dst),
            dist.P2POp(dist.irecv, recv_k, src),
            dist.P2POp(dist.irecv, recv_v, src),
        ])
        return recv_k, recv_v, reqs

Performance Benchmarks

Context Length Scaling

# Benchmark: Time and memory vs context length
benchmarks = {
    'sequence_length': ['32K', '64K', '128K', '256K', '512K', '1M'],
    
    'ring_attention_8gpu': {
        'memory_gb': [8, 12, 20, 36, 68, 130],
        'time_per_token_ms': [2.1, 2.3, 2.8, 4.2, 7.5, 15.2],
    },
    
    'standard_single_gpu': {
        'memory_gb': [8, 'OOM', 'OOM', 'OOM', 'OOM', 'OOM'],
        'time_per_token_ms': [2.1, 'N/A', 'N/A', 'N/A', 'N/A', 'N/A'],
    },
    
    'usp_8gpu': {
        'memory_gb': [6, 8, 12, 18, 32, 58],
        'time_per_token_ms': [1.8, 1.9, 2.1, 2.8, 4.5, 8.2],
    }
}

Scaling Efficiency

# Scaling efficiency with number of GPUs
scaling_results = {
    '2_gpus': {'efficiency': 0.92, 'max_context': '256K'},
    '4_gpus': {'efficiency': 0.88, 'max_context': '512K'},
    '8_gpus': {'efficiency': 0.85, 'max_context': '1M'},
    '16_gpus': {'efficiency': 0.78, 'max_context': '2M'},
    
    # Note: Efficiency decreases due to communication overhead
    # But maximum context increases proportionally
}
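Parallel efficiency converts directly into speedup: with P GPUs at efficiency E, throughput is roughly P·E times that of one GPU. The small sketch below applies this to the figures in the table and also shows that, reading K as 1024 tokens, the quoted maximum context corresponds to a constant 128K tokens per GPU, which is what "increases proportionally" means here. The dictionary layout and function name are illustrative.

```python
scaling = {  # num_gpus: (efficiency, max_context_tokens), from the table above
    2: (0.92, 256 * 1024),
    4: (0.88, 512 * 1024),
    8: (0.85, 1024 * 1024),
    16: (0.78, 2048 * 1024),
}


def effective_speedup(num_gpus, efficiency):
    """Throughput relative to a single GPU implied by a parallel-efficiency figure."""
    return num_gpus * efficiency


for gpus, (eff, max_ctx) in scaling.items():
    print(f"{gpus:>2} GPUs: {effective_speedup(gpus, eff):5.2f}x, "
          f"{max_ctx // gpus // 1024}K tokens per GPU")
```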

Comparison with Alternatives

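| Feature | Standard | Ring Attention | USP | Local + Global |
|---|---|---|---|---|
| Memory per GPU | O(N²) | O(N/P) | O(N/P) | O(N) |
| Compute per GPU | O(N²) | O(N²/P) | O(N²/P) | O(N) |
| Max context | Limited | 1M+ | 1M+ | 128K |
| Communication | None | Ring (point-to-point) | All-to-All + Ring | Tree |
| Complexity | Simple | Medium | High | Medium |

P is the number of GPUs. The Ring Attention and USP memory figures assume a blockwise (FlashAttention-style) kernel on each GPU; Local + Global assumes a fixed window size.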

Use Cases

Long Document Analysis

use_cases = {
    'document_qa': {
        'context': 'Full books, legal documents',
        'length': '100K - 1M tokens',
        'example': 'Answer questions about an entire book or contract'
    },
    
    'code_understanding': {
        'context': 'Large code repositories',
        'length': '500K+ tokens',
        'example': 'Understand entire monorepo, cross-file analysis'
    },
    
    'long_conversation': {
        'context': 'Multi-hour chat sessions',
        'length': '128K+ tokens',
        'example': 'AI assistant that keeps the whole conversation history in context'
    },
    
    'research_synthesis': {
        'context': 'Reading multiple papers',
        'length': '500K+ tokens',
        'example': 'Synthesize insights from 100s of papers'
    }
}

Future Directions

future_directions = {
    'heterogeneous_systems': {
        'description': 'Mix CPU, GPU, and specialized accelerators',
        'challenge': 'Different memory and compute characteristics'
    },
    
    'hierarchical_ring': {
        'description': 'Multiple rings (within node, across nodes)',
        'challenge': 'Bandwidth differences at different levels'
    },
    
    'adaptive_attention': {
        'description': 'Dynamically choose attention pattern per layer',
        'challenge': 'When to use what pattern'
    },
    
    '1B_context': {
        'description': 'Scale to billion token context',
        'challenge': 'Communication becomes bottleneck'
    }
}

Conclusion

Ring Attention and Unified Sequence Parallelism represent the cutting edge of long-context LLM deployment:

  • Breaking the limit: 1M+ token contexts are now practical
  • Scalable: in principle no upper limit, since more GPUs can be added
  • Efficient: compared with a single GPU, per-GPU memory falls linearly as GPUs are added
  • Compatible: can be combined with FlashAttention and PagedAttention

These techniques are changing how we handle long documents, codebases, and complex reasoning tasks. As GPU clusters grow and communication optimizations improve, million-token contexts are likely to become a standard configuration.
