Introduction
As large language models push context windows from 4K to 128K and beyond, managing the Key-Value (KV) cache becomes increasingly critical. The KV cache stores attention keys and values for all previously processed tokens, enabling efficient autoregressive generation. However, memory grows linearly with sequence length, quickly exhausting GPU memory and limiting maximum context.
KV cache eviction strategies solve this by intelligently selecting which tokens to keep in memory and which to discard. This article explores the algorithms, implementations, and best practices for effective KV cache management.
The KV Cache Challenge
Memory Growth Problem
def kv_cache_memory():
    """Estimate KV-cache memory for a 7B-class model at 128K context.

    Prints the per-layer and total cache sizes for a hard-coded config
    and returns them so callers can use the numbers programmatically.

    Returns:
        tuple[float, float]: (per_layer_gb, total_gb).
    """
    # For a 7B model with 128K context:
    config = {
        'num_layers': 32,
        'num_heads': 32,
        'head_dim': 128,
        'batch_size': 1,
        'sequence_length': 128 * 1024,  # 128K tokens
    }
    # Memory per layer:
    #   batch * seq_len * num_heads * head_dim * 2 (K and V) * 2 bytes (FP16)
    # NOTE: the original version omitted batch_size from the product.
    per_layer = (
        config['batch_size'] *
        config['sequence_length'] *
        config['num_heads'] *
        config['head_dim'] *
        2  # K and V matrices
        * 2  # FP16 bytes per element
    )
    total_per_layer_gb = per_layer / (1024 ** 3)
    total_gb = total_per_layer_gb * config['num_layers']
    print(f"KV Cache per layer: {total_per_layer_gb:.2f} GB")
    print(f"Total KV Cache: {total_gb:.2f} GB")
    # Results: 2 GB per layer, 64 GB total for 128K context at batch size 1.
    return total_per_layer_gb, total_gb
Eviction Requirements
# The eviction problem in a nutshell: why it exists and what it must solve.
eviction_requirements = dict(
    problem='Memory grows O(N) with sequence length',
    impact='128K+ context becomes impossible on single GPU',
    solution='Evict less important tokens while preserving quality',
    challenge='Which tokens to keep? Which to evict?',
)
Eviction Strategy Categories
1. Attention-Based Eviction
class AttentionBasedEviction:
    """Evict tokens with low accumulated attention scores.

    Tokens that receive little attention are assumed to contribute little
    to future generation and can be dropped from the KV cache.
    """

    def __init__(self, keep_ratio=0.5):
        # Fallback fraction of tokens to keep when evict() gets no target_size.
        self.keep_ratio = keep_ratio

    def compute_attention_importance(self, attention_weights):
        """Compute per-token importance from attention weights.

        Args:
            attention_weights: tensor whose last dimension is the key
                position — assumed [heads, query_pos, key_pos]; TODO confirm
                against the attention implementation.

        Returns:
            1-D tensor of shape [seq_len] with mean attention per token.
        """
        # Average attention across heads and query positions.
        importance = attention_weights.mean(dim=(0, 1))  # [seq_len]
        return importance

    def evict(self, kv_cache, attention_weights, target_size=None):
        """Keep the tokens with the highest attention importance.

        Args:
            kv_cache: cache tensor indexed by token position on dim 0.
            attention_weights: see compute_attention_importance.
            target_size: number of tokens to keep; falls back to
                keep_ratio * seq_len when None. (The original version
                ignored this parameter entirely.)

        Returns:
            (pruned_cache, keep_indices) with keep_indices sorted ascending
            so the pruned cache preserves the original token order.
        """
        importance = self.compute_attention_importance(attention_weights)
        if target_size is not None:
            num_keep = target_size
        else:
            num_keep = int(len(importance) * self.keep_ratio)
        # Select top-k tokens, then sort so temporal order is preserved.
        keep_indices = torch.topk(importance, num_keep).indices
        keep_indices = keep_indices.sort().values
        return kv_cache[keep_indices], keep_indices
2. Score-Based Eviction (KEYDIFF)
class ScoreBasedEviction:
    """KEYDIFF: key-similarity-based KV cache eviction.

    Uses geometric features of the key vectors to estimate importance:
    keys that deviate from the mean key direction tend to attract more
    attention and should be retained.
    """

    def __init__(self, keep_ratio=0.5):
        # Fallback fraction of tokens to keep when evict() gets no target_size.
        self.keep_ratio = keep_ratio

    def compute_key_importance(self, keys):
        """Score tokens by cosine similarity of their keys to the mean key.

        Args:
            keys: tensor assumed [batch, seq, heads, dim] — TODO confirm
                layout against the attention implementation.

        Returns:
            [batch, seq] tensor; LOWER values mean the key is farther from
            the mean direction and is more likely to carry high attention.
        """
        # Mean key over the sequence acts as the anchor direction.
        mean_key = keys.mean(dim=1, keepdim=True)  # [batch, 1, heads, dim]
        # Cosine similarity = dot product of unit vectors.
        keys_norm = F.normalize(keys, dim=-1)
        mean_key_norm = F.normalize(mean_key, dim=-1)
        similarity = (keys_norm * mean_key_norm).sum(dim=-1)  # [batch, seq, heads]
        # Average across heads.
        importance = similarity.mean(dim=-1)  # [batch, seq]
        return importance

    def evict(self, keys, values, target_size=None):
        """Prune the cache, keeping tokens with LOW key similarity.

        Args:
            keys, values: cache tensors with the sequence on dim 1.
            target_size: number of tokens to keep; falls back to
                keep_ratio * seq_len when None. (The original version
                ignored this parameter entirely.)

        Returns:
            (pruned_keys, pruned_values) with original token order preserved.
        """
        importance = self.compute_key_importance(keys)
        if target_size is not None:
            num_keep = target_size
        else:
            num_keep = int(importance.shape[1] * self.keep_ratio)
        # NOTE(review): squeeze() assumes batch size 1 — confirm callers.
        keep_indices = torch.topk(
            importance.squeeze(),
            num_keep,
            largest=False  # keep low similarity (high expected attention)
        ).indices
        keep_indices = keep_indices.sort().values
        # Return the pruned cache.
        return keys[:, keep_indices], values[:, keep_indices]
3. Recency-Based Eviction
class RecencyBasedEviction:
    """Sliding-window eviction: retain only the most recent tokens.

    The simplest strategy — no scoring at all — yet often a strong
    baseline for conversational workloads.
    """

    def __init__(self, keep_recent_ratio=0.5):
        # Fraction of the sequence to retain, counted from the end.
        self.keep_recent_ratio = keep_recent_ratio

    def evict(self, kv_cache, current_position):
        """Return the trailing window of the cache (sequence on dim 1).

        NOTE(review): current_position is currently unused — confirm
        whether callers expect it to influence the window.
        """
        seq_len = kv_cache.shape[1]
        window = int(seq_len * self.keep_recent_ratio)
        cutoff = seq_len - window
        if cutoff < 0:
            cutoff = 0
        return kv_cache[:, cutoff:]
4. Hybrid Eviction Strategies
class HybridEviction:
    """Combine attention-, recency- and key-score-based eviction signals.

    Each strategy contributes a normalized per-token score; the weighted
    sum decides which tokens survive.
    """

    def __init__(self, attention_weight=0.4, recency_weight=0.3, score_weight=0.3):
        self.attention_eviction = AttentionBasedEviction()
        self.recency_eviction = RecencyBasedEviction()
        self.score_eviction = ScoreBasedEviction()
        self.weights = {
            'attention': attention_weight,
            'recency': recency_weight,
            'score': score_weight
        }

    def evict(self, kv_cache, attention_weights, target_size, keys=None):
        """Prune the cache using a weighted combination of strategies.

        Args:
            kv_cache: cache tensor with the sequence on dim 1.
            attention_weights: passed to AttentionBasedEviction.
            target_size: fraction of tokens to KEEP (a ratio, not a count —
                note this differs from the other evictors).
            keys: optional [batch, seq, heads, dim] key tensor. When given,
                the KEYDIFF score term is included with weight
                self.weights['score']. (The original version constructed
                score_eviction but never used it.) When None, behavior is
                identical to the original two-signal combination.

        Returns:
            Pruned kv_cache.
        """
        # Attention-based importance, L2-normalized to a comparable scale.
        attn_importance = self.attention_eviction.compute_attention_importance(attention_weights)
        attn_scores = F.normalize(attn_importance, dim=0)
        # Recency scores: linearly higher for more recent positions.
        seq_len = attn_importance.shape[0]
        recency_scores = torch.arange(seq_len).float() / seq_len
        recency_scores = recency_scores.to(attn_scores.device)
        # Weighted combination of the available signals.
        combined_scores = (
            self.weights['attention'] * attn_scores +
            self.weights['recency'] * recency_scores
        )
        if keys is not None:
            # Low key similarity correlates with high attention (KEYDIFF),
            # so negate the similarity before adding it as importance.
            key_importance = self.score_eviction.compute_key_importance(keys).squeeze()
            combined_scores = combined_scores + (
                self.weights['score'] * -F.normalize(key_importance, dim=0)
            )
        # Keep the top fraction of tokens.
        num_keep = int(seq_len * target_size)
        keep_indices = torch.topk(combined_scores, num_keep).indices
        return kv_cache[:, keep_indices]
Advanced Eviction Algorithms
LRU (Least Recently Used)
class LRUEviction:
    """Evict the least-recently-used tokens.

    Tracks when each token position was last attended to and drops the
    stalest entries first.
    """

    def __init__(self):
        # Maps token index -> time of last access.
        self.last_used = {}

    def update_access(self, token_idx, current_time):
        """Record that token_idx was accessed at current_time."""
        self.last_used[token_idx] = current_time

    def evict(self, kv_cache, num_to_evict):
        """Drop the num_to_evict tokens not accessed for the longest time.

        Args:
            kv_cache: cache tensor with the sequence on dim 1.
            num_to_evict: number of token positions to remove.

        Returns:
            Pruned kv_cache with original token order preserved.
        """
        seq_len = kv_cache.shape[1]
        if len(self.last_used) == 0:
            # No access history at all: fall back to evicting the oldest
            # positions.
            return kv_cache[:, num_to_evict:]
        # Tokens with no recorded access count as least recently used.
        # (The original version silently dropped every untracked token,
        # which could evict far more than requested.)
        ages = {idx: self.last_used.get(idx, float('-inf')) for idx in range(seq_len)}
        # Sort ascending by last access time; ties resolve to lower index
        # first (dict iteration order is insertion order here).
        sorted_tokens = sorted(ages.items(), key=lambda item: item[1])
        keep_count = seq_len - num_to_evict
        keep_indices = sorted(idx for idx, _ in sorted_tokens[seq_len - keep_count:])
        # Remap surviving access times to their new positions so future
        # evictions see consistent indices.
        self.last_used = {
            new_idx: self.last_used[old_idx]
            for new_idx, old_idx in enumerate(keep_indices)
            if old_idx in self.last_used
        }
        return kv_cache[:, keep_indices]
Learned Eviction
class LearnedEviction(nn.Module):
    """Trainable eviction strategy.

    A small MLP predicts a per-token keep-probability from the hidden
    states; the highest-scoring tokens are retained.
    """

    def __init__(self, d_model):
        super().__init__()
        # Two-layer bottleneck MLP mapping each hidden state to a score
        # in (0, 1).
        self.eviction_predictor = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(),
            nn.Linear(d_model // 2, 1),
            nn.Sigmoid()
        )

    def forward(self, hidden_states, keep_ratio=0.5):
        """Predict importance and select tokens to keep.

        Args:
            hidden_states: [batch, seq, d_model] tensor.
            keep_ratio: fraction of tokens to keep per batch element.

        Returns:
            [batch, num_keep] tensor of kept token indices, sorted
            ascending per row so downstream gathers preserve token order.
            (The original version returned them in unsorted top-k order,
            inconsistent with the other evictors.)
        """
        # Per-token importance in (0, 1).
        importance = self.eviction_predictor(hidden_states).squeeze(-1)  # [batch, seq]
        num_keep = int(importance.shape[1] * keep_ratio)
        keep_indices = torch.topk(importance, num_keep).indices
        # Sort per row to preserve the original token order.
        keep_indices = keep_indices.sort(dim=-1).values
        return keep_indices
Implementation with vLLM
# Using eviction strategies with vLLM.
# NOTE(review): SamplingParams is imported but unused below — confirm
# whether a generation example was intended to follow.
from vllm import LLM, SamplingParams

# Configure for long context with eviction.
llm = LLM(
    model="meta-llama/Llama-2-70b-hf",
    max_model_len=128 * 1024,  # 128K-token context window
    # Eviction configuration.
    kv_cache_dtype="auto",  # let vLLM pick the KV-cache precision
    gpu_memory_utilization=0.9,  # fraction of GPU memory vLLM may claim
)
# vLLM automatically handles eviction.
# Default: block-based LRU eviction.
# Can customize through engine arguments.
Custom Eviction Policy
class CustomEvictionPolicy:
    """Select and apply an eviction strategy for a cache manager.

    Wraps one of the concrete evictors behind a named policy so callers
    can configure eviction with a single string.
    """

    def __init__(self, policy='hybrid'):
        """Build the evictor for the given policy name.

        Args:
            policy: one of 'importance', 'recency', or 'hybrid'.

        Raises:
            ValueError: for an unknown policy name. (The original version
            silently left self.evictor unset, deferring the failure to an
            AttributeError at eviction time.)
        """
        self.policy = policy
        if policy == 'importance':
            self.evictor = AttentionBasedEviction(keep_ratio=0.3)
        elif policy == 'recency':
            self.evictor = RecencyBasedEviction(keep_recent_ratio=0.3)
        elif policy == 'hybrid':
            self.evictor = HybridEviction()
        else:
            raise ValueError(f"unknown eviction policy: {policy!r}")

    def evict(self, cache_manager, required_space):
        """Evict cache blocks until the cache fits within required_space.

        Args:
            cache_manager: object exposing get_current_size(), block_size,
                and evict(evictor, num_to_evict) — presumably a vLLM-style
                block manager; TODO confirm the expected interface.
            required_space: target cache size; eviction stops once the
                current size is at or below it.
        """
        current_size = cache_manager.get_current_size()
        if current_size < required_space:
            return  # Already within the target size; nothing to evict.
        target_size = required_space
        # Evict one block at a time until we reach the target.
        while cache_manager.get_current_size() > target_size:
            cache_manager.evict(
                self.evictor,
                num_to_evict=cache_manager.block_size
            )
Performance Comparison
Eviction Strategy Benchmarks
# Long-context QA benchmark: accuracy and memory per eviction strategy.
_qa_results = {
    'full_cache': (85.2, 'OOM'),
    'attention_eviction': (82.1, '16GB'),
    'recency_eviction': (78.5, '16GB'),
    'hybrid_eviction': (81.8, '16GB'),
    'score_eviction': (83.2, '16GB'),
}
benchmarks = {
    'long_context_qa': {
        'context_length': '32K',
        **{
            strategy: {'accuracy': acc, 'memory': mem}
            for strategy, (acc, mem) in _qa_results.items()
        },
    }
}
Quality vs Memory Tradeoffs
# Quality-vs-memory tradeoffs at representative keep ratios.
tradeoffs = {
    f'keep_ratio_{ratio}': {
        'accuracy_retention': retained,
        'memory_savings': saved,
        'use_case': scenario,
    }
    for ratio, retained, saved, scenario in (
        (50, '95%', '50%', 'Standard long-context tasks'),
        (25, '88%', '75%', 'Memory-constrained scenarios'),
        (10, '75%', '90%', 'Extreme memory constraints'),
    )
}
Best Practices
When to Use Eviction
# Practical guidance: when to evict, which strategy, and how much to keep.
guidelines = dict(
    use_eviction_when=[
        'Context length > 16K tokens',
        'GPU memory limited',
        'Batch processing multiple requests',
        'Long-running conversations',
    ],
    choose_strategy=dict(
        general='Hybrid (attention + recency)',
        qa_tasks='Attention-based',
        conversations='Recency (keep recent context)',
        research='Score-based (KEYDIFF)',
    ),
    keep_ratio={
        '16K_context': 0.8,
        '32K_context': 0.5,
        '64K_context': 0.3,
        '128K_context': 0.2,
    },
)
Integration Tips
# One-line operational tips for wiring eviction into a serving stack.
integration_tips = dict(
    prefetch='Keep system prompt fully cached',
    streaming='Evict old tokens gradually',
    batch='Individual eviction per request',
    monitor='Track eviction rates and adjust',
)
Conclusion
KV cache eviction is essential for practical long-context LLM deployment:
- Enables Scale: Process 128K+ context on limited GPU memory
- Multiple Strategies: Attention-based, recency, score-based, or hybrid
- Quality Tradeoffs: More aggressive eviction = some quality loss
- Active Research: New methods like KEYDIFF show promise
As context windows continue to grow, sophisticated eviction strategies become critical for real-world applications.
Resources
- KEYDIFF: Key Similarity-Based Eviction
- vLLM Documentation
- Long Context LLM Survey
- PagedAttention Paper
Comments