Skip to main content
โšก Calmops

KV Cache Eviction Strategies for Long-Context LLM Inference

Introduction

As large language models push context windows from 4K to 128K and beyond, managing the Key-Value (KV) cache becomes increasingly critical. The KV cache stores attention keys and values for all previously processed tokens, enabling efficient autoregressive generation. However, memory grows linearly with sequence length, quickly exhausting GPU memory and limiting maximum context.

KV cache eviction strategies solve this by intelligently selecting which tokens to keep in memory and which to discard. This article explores the algorithms, implementations, and best practices for effective KV cache management.

The KV Cache Challenge

Memory Growth Problem

def kv_cache_memory():
    """
    Estimate KV cache memory for a 7B-class model at 128K context.

    Prints the per-layer and total KV cache sizes and returns them.

    Returns:
        tuple[float, float]: (per_layer_gb, total_gb).

    Fix: the original's trailing comment claimed ~4 GB per layer and
    ~128 GB total, but the formula below actually yields 2 GB per layer
    and 64 GB total for this configuration.
    """
    # Representative 7B-model configuration at 128K context.
    config = {
        'num_layers': 32,
        'num_heads': 32,
        'head_dim': 128,
        'batch_size': 1,
        'sequence_length': 128 * 1024,  # 128K tokens
    }

    # Bytes per layer: seq_len * num_heads * head_dim elements,
    # x2 for the K and V matrices, x2 bytes per element for FP16.
    per_layer = (
        config['sequence_length'] * 
        config['num_heads'] * 
        config['head_dim'] * 
        2  # K and V
        * 2  # FP16 bytes
    )
    
    total_per_layer_gb = per_layer / (1024 ** 3)
    total_gb = total_per_layer_gb * config['num_layers']
    
    print(f"KV Cache per layer: {total_per_layer_gb:.2f} GB")
    print(f"Total KV Cache: {total_gb:.2f} GB")
    # Results: 2 GB per layer, 64 GB total for 128K context.
    return total_per_layer_gb, total_gb

Eviction Requirements

# Why KV cache eviction is needed and what makes it hard.
eviction_requirements = dict(
    problem='Memory grows O(N) with sequence length',
    impact='128K+ context becomes impossible on single GPU',
    solution='Evict less important tokens while preserving quality',
    challenge='Which tokens to keep? Which to evict?',
)

Eviction Strategy Categories

1. Attention-Based Eviction

class AttentionBasedEviction:
    """
    Evict tokens with low attention scores, keeping the most-attended ones.
    """
    
    def __init__(self, keep_ratio=0.5):
        # Fraction of the sequence retained by evict().
        self.keep_ratio = keep_ratio
    
    def compute_attention_importance(self, attention_weights):
        """
        Compute a per-token importance score from attention weights.

        Averages the attention mass each key position receives across the
        first two dimensions (assumed heads and query positions — TODO
        confirm layout with caller).
        """
        importance = attention_weights.mean(dim=(0, 1))  # [seq_len]
        
        return importance
    
    def evict(self, kv_cache, attention_weights, target_size):
        """
        Keep the tokens with the highest attention importance.

        Fixes vs. original:
        - `target_size` was accepted but ignored; it now caps the number
          of tokens kept.
        - topk indices are sorted so the retained tokens stay in their
          original temporal order instead of importance order.

        Returns:
            (pruned_kv_cache, keep_indices)
        """
        importance = self.compute_attention_importance(attention_weights)
        
        # Keep at most keep_ratio of the sequence, capped by target_size.
        num_keep = min(int(len(importance) * self.keep_ratio), target_size)
        keep_indices = torch.topk(importance, num_keep).indices
        
        # Restore positional order of the survivors.
        keep_indices = keep_indices.sort().values
        
        return kv_cache[keep_indices], keep_indices

2. Score-Based Eviction (KEYDIFF)

class ScoreBasedEviction:
    """
    KEYDIFF: Key Similarity-based KV Cache Eviction.
    Uses geometric features of keys to determine importance.
    """
    
    def __init__(self, keep_ratio=0.5):
        # Fraction of the sequence retained by evict().
        self.keep_ratio = keep_ratio
    
    def compute_key_importance(self, keys):
        """
        Compute importance based on each key's cosine similarity to the
        mean key. Keys far from the mean (low similarity) are treated as
        more informative.

        Args:
            keys: assumed [batch, seq, heads, dim] — TODO confirm layout.

        Returns:
            similarity per token, shape [batch, seq].
        """
        # Mean key as the anchor.
        # Fix: comment said [1, seq, heads, dim]; mean over dim=1 with
        # keepdim actually yields [batch, 1, heads, dim].
        mean_key = keys.mean(dim=1, keepdim=True)  # [batch, 1, heads, dim]
        
        # Cosine similarity = dot product of L2-normalized vectors.
        keys_norm = F.normalize(keys, dim=-1)
        mean_key_norm = F.normalize(mean_key, dim=-1)
        
        similarity = (keys_norm * mean_key_norm).sum(dim=-1)  # [batch, seq, heads]
        
        # Average across heads.
        importance = similarity.mean(dim=-1)  # [batch, seq]
        
        return importance
    
    def evict(self, keys, values, target_size):
        """
        Evict based on key-similarity scores; tokens with LOW similarity
        to the mean key are kept (they often receive high attention).

        Fixes vs. original:
        - `target_size` was ignored; it now caps the number of tokens kept.
        - `.squeeze()` would also collapse the seq dim when seq_len == 1;
          `.squeeze(0)` removes only the (assumed size-1) batch dim.

        Returns:
            (pruned_keys, pruned_values)
        """
        importance = self.compute_key_importance(keys)
        
        # Keep at most keep_ratio of the sequence, capped by target_size.
        num_keep = min(int(importance.shape[1] * self.keep_ratio), target_size)
        
        # largest=False keeps the low-similarity (high-attention) tokens.
        keep_indices = torch.topk(
            importance.squeeze(0),
            num_keep,
            largest=False
        ).indices
        
        # Restore positional order of the survivors.
        keep_indices = keep_indices.sort().values
        
        return keys[:, keep_indices], values[:, keep_indices]

3. Recency-Based Eviction

class RecencyBasedEviction:
    """
    Sliding-window eviction: retain only the newest tokens.
    Simpler than score-based schemes but effective in practice.
    """

    def __init__(self, keep_recent_ratio=0.5):
        # Fraction of the sequence kept at the recent end.
        self.keep_recent_ratio = keep_recent_ratio

    def evict(self, kv_cache, current_position):
        """
        Drop everything except the trailing window of tokens.

        `current_position` is accepted for interface compatibility but is
        not used by this strategy.
        """
        total = kv_cache.shape[1]
        window = int(total * self.keep_recent_ratio)
        # Guard against a window larger than the sequence itself.
        cut = max(0, total - window)
        return kv_cache[:, cut:]

4. Hybrid Eviction Strategies

class HybridEviction:
    """
    Combine multiple eviction strategies into one weighted importance score.
    """
    
    def __init__(self, attention_weight=0.4, recency_weight=0.3, score_weight=0.3):
        self.attention_eviction = AttentionBasedEviction()
        self.recency_eviction = RecencyBasedEviction()
        # NOTE(review): score_eviction is instantiated but its score is never
        # folded into evict() below — that would require the raw keys, which
        # evict() does not receive. score_weight is currently unused; confirm
        # whether the key-similarity term should be wired in.
        self.score_eviction = ScoreBasedEviction()
        
        self.weights = {
            'attention': attention_weight,
            'recency': recency_weight,
            'score': score_weight
        }
    
    def evict(self, kv_cache, attention_weights, target_size):
        """
        Blend attention and recency scores, then keep the top tokens.

        `target_size` is interpreted as a keep FRACTION of the sequence
        (e.g. 0.5 keeps half), not an absolute token count.

        Fix vs. original: topk indices are sorted so the retained tokens
        stay in their original temporal order.
        """
        # Per-token attention importance.
        attn_importance = self.attention_eviction.compute_attention_importance(attention_weights)
        
        # L2-normalize so the attention term is scale-independent.
        attn_scores = F.normalize(attn_importance, dim=0)
        
        # Recency score grows linearly toward the newest position.
        seq_len = attn_importance.shape[0]
        recency_scores = torch.arange(seq_len).float() / seq_len
        recency_scores = recency_scores.to(attn_scores.device)
        
        # Weighted blend of the two signals.
        combined_scores = (
            self.weights['attention'] * attn_scores +
            self.weights['recency'] * recency_scores
        )
        
        num_keep = int(seq_len * target_size)
        keep_indices = torch.topk(combined_scores, num_keep).indices
        # Restore positional order of the survivors.
        keep_indices = keep_indices.sort().values
        
        return kv_cache[:, keep_indices]

Advanced Eviction Algorithms

LRU (Least Recently Used)

class LRUEviction:
    """
    Evict least recently used tokens.
    Tracks when each token was last attended to.
    """
    
    def __init__(self):
        self.last_used = {}  # token_idx -> last time the token was attended to
    
    def update_access(self, token_idx, current_time):
        """Record that token_idx was attended to at current_time."""
        self.last_used[token_idx] = current_time
    
    def evict(self, kv_cache, num_to_evict):
        """
        Evict the num_to_evict tokens not accessed for the longest time.

        NOTE(review): only tokens present in last_used can be kept —
        untracked positions are implicitly evicted; verify callers always
        call update_access for every live token.
        """
        if len(self.last_used) == 0:
            # No access history: fall back to dropping the oldest positions.
            return kv_cache[:, num_to_evict:]
        
        keep_count = kv_cache.shape[1] - num_to_evict
        if keep_count <= 0:
            # Fix: a slice like lst[-0:] returns the WHOLE list, so the
            # original kept everything when asked to evict everything.
            return kv_cache[:, :0]
        
        # Sort tracked tokens by last access time, oldest first.
        sorted_tokens = sorted(
            self.last_used.items(), 
            key=lambda x: x[1]
        )
        
        # Keep the most recently used; restore positional order.
        keep_indices = sorted(idx for idx, _ in sorted_tokens[-keep_count:])
        
        return kv_cache[:, keep_indices]

Learned Eviction

class LearnedEviction(nn.Module):
    """
    Trainable eviction strategy: a small MLP predicts a per-token
    keep-score from the hidden state.
    """
    
    def __init__(self, d_model):
        super().__init__()
        
        # Bottleneck MLP; Sigmoid maps scores into (0, 1).
        self.eviction_predictor = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(),
            nn.Linear(d_model // 2, 1),
            nn.Sigmoid()
        )
    
    def forward(self, hidden_states, keep_ratio=0.5):
        """
        Predict per-token importance and return indices of tokens to keep.

        hidden_states: assumed [batch, seq, d_model] — TODO confirm with caller.

        Fix vs. original: indices are sorted so downstream cache slicing
        preserves the tokens' temporal order.
        """
        importance = self.eviction_predictor(hidden_states).squeeze(-1)  # [batch, seq]
        
        num_keep = int(importance.shape[1] * keep_ratio)
        keep_indices = torch.topk(importance, num_keep).indices
        # Restore positional order of the survivors.
        keep_indices = keep_indices.sort(dim=-1).values
        
        return keep_indices

Implementation with vLLM

# Using eviction strategies with vLLM
from vllm import LLM, SamplingParams

# Configure for long context with eviction.
# NOTE(review): constructing LLM loads the model and allocates GPU memory —
# this is a runnable example, not a no-op.
llm = LLM(
    model="meta-llama/Llama-2-70b-hf",
    max_model_len=128 * 1024,
    # Eviction configuration
    kv_cache_dtype="auto",  # let vLLM pick the KV cache precision
    gpu_memory_utilization=0.9,  # fraction of GPU memory vLLM may claim
)

# vLLM automatically handles eviction
# Default: block-based LRU eviction
# Can customize through engine arguments

Custom Eviction Policy

class CustomEvictionPolicy:
    """
    Select and drive an eviction strategy for specific use cases.
    """
    
    def __init__(self, policy='hybrid'):
        """
        Args:
            policy: one of 'importance', 'recency', 'hybrid'.

        Raises:
            ValueError: for an unknown policy name. Fix: the original left
            self.evictor unset, deferring the failure to evict() time as a
            confusing AttributeError.
        """
        self.policy = policy
        
        if policy == 'importance':
            self.evictor = AttentionBasedEviction(keep_ratio=0.3)
        elif policy == 'recency':
            self.evictor = RecencyBasedEviction(keep_recent_ratio=0.3)
        elif policy == 'hybrid':
            self.evictor = HybridEviction()
        else:
            raise ValueError(f"Unknown eviction policy: {policy!r}")
    
    def evict(self, cache_manager, required_space):
        """
        Shrink the cache until it fits within required_space.

        required_space is treated as a target cache size; nothing happens
        if the cache is already at or below it.
        """
        current_size = cache_manager.get_current_size()
        
        # <= so being exactly at the target is also a no-op.
        if current_size <= required_space:
            return  # Already enough space
        
        target_size = required_space
        
        # Evict one block at a time until under the target.
        while cache_manager.get_current_size() > target_size:
            cache_manager.evict(
                self.evictor,
                num_to_evict=cache_manager.block_size
            )

Performance Comparison

Eviction Strategy Benchmarks

# Benchmark figures for each eviction strategy on a 32K-token QA task.
benchmarks = {
    'long_context_qa': {
        'context_length': '32K',
        **{
            strategy: {'accuracy': accuracy, 'memory': memory}
            for strategy, accuracy, memory in [
                ('full_cache', 85.2, 'OOM'),
                ('attention_eviction', 82.1, '16GB'),
                ('recency_eviction', 78.5, '16GB'),
                ('hybrid_eviction', 81.8, '16GB'),
                ('score_eviction', 83.2, '16GB'),
            ]
        },
    }
}

Quality vs Memory Tradeoffs

# Accuracy retained vs memory saved at different keep ratios.
tradeoffs = {
    f'keep_ratio_{ratio}': {
        'accuracy_retention': retention,
        'memory_savings': savings,
        'use_case': use_case,
    }
    for ratio, retention, savings, use_case in [
        (50, '95%', '50%', 'Standard long-context tasks'),
        (25, '88%', '75%', 'Memory-constrained scenarios'),
        (10, '75%', '90%', 'Extreme memory constraints'),
    ]
}

Best Practices

When to Use Eviction

# Practical guidance: when to evict, which strategy, and how much to keep.
guidelines = dict(
    use_eviction_when=[
        'Context length > 16K tokens',
        'GPU memory limited',
        'Batch processing multiple requests',
        'Long-running conversations',
    ],
    choose_strategy=dict(
        general='Hybrid (attention + recency)',
        qa_tasks='Attention-based',
        conversations='Recency (keep recent context)',
        research='Score-based (KEYDIFF)',
    ),
    # Keys start with digits, so these stay as string-literal keys.
    keep_ratio={
        '16K_context': 0.8,
        '32K_context': 0.5,
        '64K_context': 0.3,
        '128K_context': 0.2,
    },
)

Integration Tips

# Operational tips for wiring eviction into a serving stack.
integration_tips = dict(
    prefetch='Keep system prompt fully cached',
    streaming='Evict old tokens gradually',
    batch='Individual eviction per request',
    monitor='Track eviction rates and adjust',
)

Conclusion

KV cache eviction is essential for practical long-context LLM deployment:

  • Enables Scale: Process 128K+ context on limited GPU memory
  • Multiple Strategies: Attention-based, recency, score-based, or hybrid
  • Quality Tradeoffs: More aggressive eviction = some quality loss
  • Active Research: New methods like KEYDIFF show promise

As context windows continue to grow, sophisticated eviction strategies become critical for real-world applications.

Resources

Comments