Introduction
Large Language Model inference remains computationally intensive, with autoregressive decoding being inherently sequential: each token depends on all previous tokens. This creates a memory-bandwidth bottleneck where most time is spent loading model weights rather than computing.
Speculative decoding solves this elegantly: instead of generating tokens one-by-one, we use a smaller “draft” model to propose multiple tokens in parallel, then verify them with the larger target model. When predictions match (which happens ~70-90% of the time), we get multiple tokens for the cost of one verification pass.
In 2026, speculative decoding has become essential for production LLM deployment, achieving 2-3x speedups while maintaining identical output quality. This guide explores the algorithms, implementations, and practical applications.
The Problem: Autoregressive Bottleneck
Standard Autoregressive Generation
import torch

class StandardAutoregressive:
    """
    Standard LLM generation - sequential, token by token.
    """
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def generate(self, prompt, max_tokens=100):
        """
        Generate tokens one at a time.
        """
        input_ids = self.tokenizer.encode(prompt)
        for _ in range(max_tokens):
            # Forward pass over ALL previous tokens
            logits = self.model(torch.tensor([input_ids]))
            # Greedily pick the next token
            next_token = logits[0, -1].argmax().item()
            input_ids.append(next_token)
            if next_token == self.tokenizer.eos_token_id:
                break
        return self.tokenizer.decode(input_ids)
The problem: each token requires a full forward pass through the entire model, loading all weights from memory. For a 70B model, that’s 140GB of memory bandwidth per token.
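A quick back-of-the-envelope calculation shows why bandwidth, not compute, bounds decoding speed. The numbers below are illustrative: 3.35 TB/s is roughly an H100's HBM bandwidth, and the estimate ignores KV-cache traffic.

```python
# Rough roofline estimate: decoding is bandwidth-bound, so the upper limit on
# tokens/second is memory bandwidth divided by bytes moved per token.
def max_tokens_per_second(param_count: float, bytes_per_param: float,
                          bandwidth_bytes_per_s: float) -> float:
    bytes_per_token = param_count * bytes_per_param  # full weight read per token
    return bandwidth_bytes_per_s / bytes_per_token

# 70B parameters in fp16 (2 bytes) on a GPU with ~3.35 TB/s of HBM bandwidth
limit = max_tokens_per_second(70e9, 2, 3.35e12)
print(f"{limit:.0f} tokens/s upper bound")  # ~24 tokens/s
```

This is exactly the ceiling speculative decoding attacks: one target forward pass (one weight read) is amortized over several tokens.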
Speculative Decoding Fundamentals
Core Algorithm
import torch
import torch.nn.functional as F
class SpeculativeDecoder:
    """
    Speculative decoding with the draft-verification paradigm.

    1. Draft model generates multiple candidate tokens
    2. Target model verifies all candidates in parallel
    3. Accepted tokens advance; rejected tokens trigger correction
    """
    def __init__(self, target_model, draft_model, max_draft=8):
        self.target = target_model  # Large, accurate model
        self.draft = draft_model    # Small, fast model
        self.max_draft = max_draft  # Draft length

    def generate_next_token(self, input_ids):
        """
        Generate the next token(s) using speculation.
        """
        # Step 1: Draft model generates candidates
        draft_tokens = self._draft(input_ids, self.max_draft)
        # Step 2: Target model verifies them in a single forward pass
        verified_tokens, accepted = self._verify(input_ids, draft_tokens)
        return verified_tokens, accepted

    def _draft(self, input_ids, num_tokens):
        """
        Use the draft model to propose tokens.
        """
        draft_ids = list(input_ids)
        draft_tokens = []
        for _ in range(num_tokens):
            # One forward pass per draft token (small model = fast)
            logits = self.draft(torch.tensor([draft_ids]))
            next_token = logits[0, -1].argmax().item()
            draft_tokens.append(next_token)
            draft_ids.append(next_token)
            if next_token == self.draft.eos_token:
                break
        return draft_tokens

    def _verify(self, input_ids, draft_tokens):
        """
        Verify draft tokens with the target model.
        """
        if not draft_tokens:
            return [], True
        # Construct input with the draft tokens appended
        extended_ids = input_ids + draft_tokens
        # Single forward pass through the target model
        logits = self.target(torch.tensor([extended_ids]))
        # Compare each draft token against the target's greedy choice
        accepted = []
        rejected = False
        for i, draft_token in enumerate(draft_tokens):
            # Logits at position t predict token t+1, so the prediction for
            # draft_tokens[i] lives at index len(input_ids) + i - 1
            target_idx = len(input_ids) + i - 1
            target_token = logits[0, target_idx].argmax().item()
            if draft_token == target_token:
                accepted.append(draft_token)
            else:
                # First rejection: the target diverges, substitute its token
                accepted.append(target_token)
                rejected = True
                break
        # If every draft token was accepted, add one bonus token from the target
        if not rejected:
            bonus_token = logits[0, -1].argmax().item()
            accepted.append(bonus_token)
        return accepted, not rejected
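To see the accept/reject mechanics without real models, the verification loop can be exercised with stand-in sequences. Everything below is a hypothetical toy, not part of the class above: both "models" are fixed token lists, and the draft diverges from the target at one position.

```python
# Toy greedy speculation over fixed sequences: the draft agrees with the
# target except at index 3, so one verification pass should accept the first
# three draft tokens and substitute the target's token at the divergence.
target_seq = [10, 11, 12, 13, 14]
draft_seq = [10, 11, 12, 99, 14]  # diverges at index 3

def speculate(prefix_len, max_draft=4):
    """Return (tokens produced this iteration, whether all drafts matched)."""
    draft = draft_seq[prefix_len:prefix_len + max_draft]
    accepted = []
    for i, tok in enumerate(draft):
        if tok == target_seq[prefix_len + i]:  # target "verifies" greedily
            accepted.append(tok)
        else:
            accepted.append(target_seq[prefix_len + i])  # correction token
            return accepted, False
    return accepted, True

print(speculate(0))  # ([10, 11, 12, 13], False): 3 accepted + 1 correction
```

Four tokens come out of a single "verification", which is the entire point: the cost of one target pass is amortized over several tokens.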
Sampling-Based Speculation
class StochasticSpeculativeDecoder:
    """
    Speculative decoding with sampling instead of greedy decoding.

    Note: matching independently drawn samples, as done here, is a
    simplification; the lossless algorithm accepts each draft token
    with probability min(1, p/q).
    """
    def __init__(self, target_model, draft_model, temperature=0.6):
        self.target = target_model
        self.draft = draft_model
        self.temperature = temperature

    def generate(self, input_ids, max_tokens=64, max_draft=8):
        """Generate up to max_tokens with probabilistic speculation."""
        all_tokens = []
        while len(all_tokens) < max_tokens:
            # Draft with temperature
            draft_tokens = self._sample_draft(input_ids, max_draft)
            if not draft_tokens:
                break
            # Verify with the target
            accepted, target_token = self._verify_sampling(
                input_ids, draft_tokens
            )
            # Add accepted tokens
            all_tokens.extend(accepted)
            # Update the input for the next iteration
            input_ids = input_ids + accepted
            # If the target produced a new token, add it and continue
            if target_token is not None:
                all_tokens.append(target_token)
                input_ids = input_ids + [target_token]
            else:
                break
        return all_tokens

    def _sample_draft(self, input_ids, num_tokens):
        """Sample draft tokens from the draft model."""
        draft_ids = list(input_ids)
        tokens = []
        for _ in range(num_tokens):
            logits = self.draft(torch.tensor([draft_ids]))
            probs = F.softmax(logits[0, -1] / self.temperature, dim=-1)
            token = torch.multinomial(probs, 1).item()
            tokens.append(token)
            draft_ids.append(token)
            if token == self.draft.eos_token:
                break
        return tokens

    def _verify_sampling(self, input_ids, draft_tokens):
        """
        Verify with sampling - allows controlled divergence.
        """
        extended_ids = input_ids + draft_tokens
        logits = self.target(torch.tensor([extended_ids]))
        accepted = []
        for i, draft_token in enumerate(draft_tokens):
            # Logits at position t predict token t+1
            target_idx = len(input_ids) + i - 1
            probs = F.softmax(logits[0, target_idx] / self.temperature, dim=-1)
            # Sample from the target distribution
            target_token = torch.multinomial(probs, 1).item()
            if draft_token == target_token:
                accepted.append(draft_token)
            else:
                # Rejection: use the target's choice instead
                return accepted, target_token
        # All accepted - sample one bonus token from the final position
        final_token = torch.multinomial(
            F.softmax(logits[0, -1] / self.temperature, dim=-1), 1
        ).item()
        return accepted, final_token
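The exact-match check above is a simplification. The lossless rule from the speculative sampling literature accepts a draft token x with probability min(1, p(x)/q(x)), where p is the target distribution and q the draft distribution, and on rejection resamples from the normalized residual max(0, p - q). A minimal sketch over explicit probability vectors (pure Python; the distributions are hypothetical toys):

```python
import random

def accept_or_resample(p, q, x, rng=random.random):
    """Lossless acceptance: keep draft token x with prob min(1, p[x]/q[x]);
    otherwise sample from the residual distribution max(0, p - q), renormalized."""
    if rng() < min(1.0, p[x] / q[x]):
        return x, True
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    total = sum(residual)
    r, acc = random.random() * total, 0.0
    for token, mass in enumerate(residual):
        acc += mass
        if r <= acc:
            return token, False
    return len(p) - 1, False  # numeric-edge fallback

# Toy 3-token vocabulary: target p and draft q disagree on tokens 0 and 2
p = [0.5, 0.3, 0.2]
q = [0.2, 0.3, 0.5]
token, accepted = accept_or_resample(p, q, x=0)  # p[0]/q[0] = 2.5, always kept
```

Averaged over draft randomness, tokens produced this way are distributed exactly as if sampled from the target alone, which is what makes speculative sampling lossless.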
Advanced Speculative Algorithms
1. Self-Speculative Decoding
class SelfSpeculativeDecoder:
    """
    Use the target model itself as both draft and verifier.
    Saves memory by not loading a separate draft model.
    """
    def __init__(self, model, max_draft=6):
        self.model = model
        self.max_draft = max_draft

    def generate(self, input_ids):
        """
        Self-speculation: use earlier layers as the draft, later layers as the verifier.
        """
        # Extract hidden states at the draft depth
        hidden = self._get_hidden_states(input_ids, draft_depth=15)
        # Generate draft tokens from the intermediate hidden states
        draft_tokens = self._hidden_to_tokens(hidden)
        # Verify with the full model
        accepted = self._verify(input_ids, draft_tokens)
        return accepted

    def _get_hidden_states(self, input_ids, draft_depth):
        """Extract intermediate hidden states (simplified sketch)."""
        output = self.model(torch.tensor([input_ids]), output_hidden_states=True)
        return output.hidden_states[draft_depth]

    def _hidden_to_tokens(self, hidden):
        """Early exit: project intermediate hidden states through the LM head."""
        logits = self.model.lm_head(hidden)
        return logits[0].argmax(dim=-1).tolist()[-self.max_draft:]

    def _verify(self, input_ids, draft_tokens):
        """Verify by running a full forward pass."""
        logits = self.model(torch.tensor([input_ids + draft_tokens])).logits
        accepted = []
        for i, draft in enumerate(draft_tokens):
            # Logits at position t predict token t+1
            target_idx = len(input_ids) + i - 1
            target_token = logits[0, target_idx].argmax().item()
            if draft == target_token:
                accepted.append(draft)
            else:
                break
        return accepted
2. Hierarchical Speculative Decoding
class HierarchicalSpeculativeDecoder:
    """
    Multi-level speculation: draft1 -> draft2 -> target.
    Uses progressively larger models for better acceptance.
    """
    def __init__(self, models, thresholds=[0.3, 0.7]):
        """
        models: [small, medium, target]
        thresholds: acceptance-rate thresholds for escalating a level
        """
        self.models = models
        self.thresholds = thresholds

    def generate(self, input_ids):
        """Generate with hierarchical speculation."""
        # Level 1: smallest model drafts, medium model verifies
        draft1 = self._generate_draft(input_ids, self.models[0], max_tokens=6)
        accepted1, next_input = self._verify_level(input_ids, draft1, self.models[1])
        acceptance = len(accepted1) / len(draft1) if draft1 else 0.0
        if acceptance < self.thresholds[0]:
            # Level 2: medium model drafts, target verifies
            draft2 = self._generate_draft(next_input, self.models[1], max_tokens=4)
            accepted2, _ = self._verify_level(next_input, draft2, self.models[2])
            return accepted1 + accepted2
        return accepted1

    def _generate_draft(self, input_ids, model, max_tokens):
        """Generate draft tokens greedily."""
        tokens = []
        for _ in range(max_tokens):
            logits = model(torch.tensor([input_ids + tokens]))
            token = logits[0, -1].argmax().item()
            tokens.append(token)
            if token == model.eos_token:
                break
        return tokens

    def _verify_level(self, input_ids, draft_tokens, verifier_model):
        """Verify drafts at the current level."""
        logits = verifier_model(torch.tensor([input_ids + draft_tokens]))
        accepted = []
        for i, draft_token in enumerate(draft_tokens):
            # Logits at position t predict token t+1
            target_idx = len(input_ids) + i - 1
            target_token = logits[0, target_idx].argmax().item()
            if draft_token == target_token:
                accepted.append(draft_token)
            else:
                break
        next_input = input_ids + accepted
        return accepted, next_input
3. Lookahead Speculation
class LookaheadSpeculation:
    """
    Lookahead: use n-gram patterns for speculation.
    Leverages local context patterns.
    """
    def __init__(self, model, ngram_size=5, max_draft=8):
        self.model = model
        self.ngram_size = ngram_size
        self.max_draft = max_draft
        self.ngram_cache = {}

    def generate(self, input_ids):
        """Generate up to max_draft new tokens with n-gram guided speculation."""
        tokens = list(input_ids)
        start = len(tokens)
        while len(tokens) - start < self.max_draft:
            # Build the (n-1)-gram key from the most recent tokens
            context = tuple(tokens[-(self.ngram_size - 1):])
            # Check the cache for possible continuations
            candidates = self._get_candidates(context)
            if candidates:
                verified = self._verify_candidates(tokens, candidates)
                if verified is not None:
                    tokens.append(verified)
                    continue
            # Fallback: ordinary single-token generation
            logits = self.model(torch.tensor([tokens]))
            token = logits[0, -1].argmax().item()
            tokens.append(token)
        return tokens

    def _get_candidates(self, context):
        """Get cached n-gram continuations."""
        return self.ngram_cache.get(context, [])

    def _verify_candidates(self, tokens, candidates):
        """Check whether the model's next-token prediction matches a cached candidate."""
        logits = self.model(torch.tensor([tokens]))
        predicted = logits[0, -1].argmax().item()
        for candidate in candidates[:3]:  # Try the top 3
            if predicted == candidate:
                return candidate
        return None

    def update_cache(self, token_ids):
        """Update the n-gram cache from previously generated token ids."""
        for i in range(len(token_ids) - self.ngram_size + 1):
            ngram = tuple(token_ids[i:i + self.ngram_size - 1])
            continuation = token_ids[i + self.ngram_size - 1]
            self.ngram_cache.setdefault(ngram, [])
            if continuation not in self.ngram_cache[ngram]:
                self.ngram_cache[ngram].append(continuation)
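The cache logic can be checked in isolation with plain integer token ids (a hypothetical toy, independent of any model):

```python
# Build a 3-gram cache: each 2-token context maps to the tokens observed after it.
def build_ngram_cache(token_ids, ngram_size=3):
    cache = {}
    for i in range(len(token_ids) - ngram_size + 1):
        context = tuple(token_ids[i:i + ngram_size - 1])
        continuation = token_ids[i + ngram_size - 1]
        cache.setdefault(context, [])
        if continuation not in cache[context]:
            cache[context].append(continuation)
    return cache

history = [1, 2, 3, 1, 2, 4]
cache = build_ngram_cache(history)
print(cache[(1, 2)])  # [3, 4]: both observed continuations of context (1, 2)
```

Because the cache is keyed on exact contexts, it shines on repetitive text (code, templated output) and contributes little on free-form prose.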
Implementation Framework
class SpeculativeGenerationPipeline:
    """
    Complete speculative decoding pipeline.
    """
    def __init__(self, config):
        # Load models and tokenizer (loader details depend on your framework)
        self.target = self._load_model(config.target_model)
        self.draft = self._load_model(config.draft_model)
        self.tokenizer = self._load_tokenizer(config.target_model)
        # Configuration
        self.max_draft = config.max_draft
        self.temperature = config.temperature
        # Choose the algorithm
        if config.algorithm == 'sampling':
            self.decoder = StochasticSpeculativeDecoder(
                self.target, self.draft, config.temperature
            )
        elif config.algorithm == 'self':
            self.decoder = SelfSpeculativeDecoder(self.target)
        else:  # 'standard' and any unrecognized value
            self.decoder = SpeculativeDecoder(self.target, self.draft)

    def generate(self, prompt, max_tokens=100):
        """Generate with speculative decoding."""
        input_ids = self.tokenizer.encode(prompt)
        output_tokens = []
        while len(output_tokens) < max_tokens:
            # Get the next batch of tokens
            new_tokens, _ = self.decoder.generate_next_token(input_ids)
            if not new_tokens:
                break
            output_tokens.extend(new_tokens)
            input_ids = input_ids + new_tokens
            # Check for end-of-sequence
            if new_tokens[-1] == self.tokenizer.eos_token_id:
                break
        return self.tokenizer.decode(output_tokens)

    def benchmark(self, prompts, baseline_time):
        """Benchmark speedup against a non-speculative baseline."""
        speculative_time, total_tokens = self._time_generation(prompts)
        return {
            'speedup': baseline_time / speculative_time,
            'tokens_per_second': total_tokens / speculative_time,
            'acceptance_rate': self._compute_acceptance_rate(),
        }
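The `config` object the pipeline reads is never defined above; a minimal stand-in could look like the sketch below. The field names simply mirror the attributes the constructor accesses, and the model identifiers are placeholders, not real checkpoints.

```python
from dataclasses import dataclass

@dataclass
class SpecDecodeConfig:
    # Field names mirror what SpeculativeGenerationPipeline reads;
    # model identifiers here are placeholders, not real checkpoints.
    target_model: str = "target-70b"
    draft_model: str = "draft-7b"
    algorithm: str = "standard"   # 'standard' | 'sampling' | 'self'
    max_draft: int = 6
    temperature: float = 0.6

config = SpecDecodeConfig(algorithm="sampling", temperature=0.8)
```

A dataclass keeps the knobs in one typed place and gives sensible defaults, so experiments only override the fields they care about.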
Optimizations and Tricks
1. Adaptive Draft Length
class AdaptiveSpeculativeDecoder:
    """
    Adjust the draft length based on the recent acceptance rate.
    """
    def __init__(self, target, draft):
        self.decoder = SpeculativeDecoder(target, draft)
        self.acceptance_history = []
        self.current_draft_len = 6

    def generate(self, input_ids):
        """Generate with an adaptive draft length."""
        # Use the current draft length
        self.decoder.max_draft = self.current_draft_len
        # Generate
        tokens, fully_accepted = self.decoder.generate_next_token(input_ids)
        # Track acceptance (capped at 1.0, since a bonus token can exceed the draft)
        acceptance = min(1.0, len(tokens) / self.current_draft_len)
        self.acceptance_history.append(acceptance)
        # Adapt the draft length
        if len(self.acceptance_history) > 10:
            avg_acceptance = sum(self.acceptance_history[-10:]) / 10
            if avg_acceptance > 0.9:
                self.current_draft_len = min(12, self.current_draft_len + 1)
            elif avg_acceptance < 0.6:
                self.current_draft_len = max(3, self.current_draft_len - 1)
        return tokens
2. Batch Speculation
class BatchSpeculativeDecoder:
    """
    Process multiple sequences with speculation in parallel.
    (_quick_draft and _verify follow the same logic as SpeculativeDecoder.)
    """
    def __init__(self, target, draft, tokenizer):
        self.target = target
        self.draft = draft
        self.tokenizer = tokenizer

    def generate_batch(self, prompts, max_tokens):
        """
        Generate multiple sequences simultaneously.
        """
        # Encode all prompts
        input_ids_list = [self.tokenizer.encode(p) for p in prompts]
        results = [[] for _ in prompts]
        finished = [False] * len(prompts)
        while not all(finished):
            # Draft for every sequence
            drafts = [self._quick_draft(ids) for ids in input_ids_list]
            # Verify all drafts (ideally batched into one target forward pass)
            for i, (ids, draft) in enumerate(zip(input_ids_list, drafts)):
                if finished[i]:
                    continue
                accepted = self._verify(ids, draft)
                results[i].extend(accepted)
                input_ids_list[i].extend(accepted)
                if (not accepted
                        or accepted[-1] == self.tokenizer.eos_token_id
                        or len(results[i]) >= max_tokens):
                    finished[i] = True
        return [self.tokenizer.decode(r) for r in results]
Performance Analysis
Expected Speedups
| Configuration | Draft Model | Acceptance Rate | Speedup |
|---|---|---|---|
| 70B → 7B | 7B | 85-95% | 2.5-3.0x |
| 70B → 3B | 3B | 70-85% | 2.0-2.5x |
| Self-speculate | Same model | 75-90% | 1.8-2.2x |
| Hierarchical | Multi-level | 90-95% | 2.5-3.5x |
def analyze_speedup(target_time_per_token, draft_time_per_token, acceptance_rate):
    """
    Compute the expected speedup from speculation.

    Traditional decoding costs target_time_per_token per token. Speculative
    decoding pays n draft passes plus one target pass per iteration, and
    yields on average n = 1 / (1 - acceptance_rate) tokens per iteration.
    """
    # Average tokens produced per iteration
    avg_tokens = 1 / (1 - acceptance_rate)
    # Wall-clock time per iteration: avg_tokens draft passes + one target pass
    speculative_time = (draft_time_per_token * avg_tokens +
                        target_time_per_token)
    # Throughput comparison
    speculative_rate = avg_tokens / speculative_time
    traditional_rate = 1 / target_time_per_token
    return speculative_rate / traditional_rate
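Plugging illustrative numbers into this model (30 ms/token for the target, 2 ms/token for the draft, 80% acceptance; all assumed for the sake of the example) shows where the speedup comes from. The calculation below re-derives it standalone:

```python
# With acceptance rate a, each iteration yields on average 1/(1-a) tokens.
# One iteration costs avg_tokens draft passes plus a single target pass.
def expected_speedup(target_ms, draft_ms, acceptance):
    avg_tokens = 1 / (1 - acceptance)                # 5 tokens at a = 0.8
    iter_ms = draft_ms * avg_tokens + target_ms      # 2*5 + 30 = 40 ms
    return (avg_tokens / iter_ms) / (1 / target_ms)  # vs. 30 ms per token

print(round(expected_speedup(30, 2, 0.8), 2))  # 3.75
```

Five tokens for 40 ms instead of one token per 30 ms: the draft overhead is small, so almost all of the acceptance-rate gain survives as end-to-end speedup.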
Best Practices
1. Draft Model Selection
def select_draft_model(target_model_size):
    """
    Select an appropriate draft model size (in parameters) for a target.
    """
    # Rule of thumb: ~10x smaller is usually a good starting point
    if target_model_size >= 70_000_000_000:
        return 7_000_000_000   # 7B draft for a 70B target
    elif target_model_size >= 10_000_000_000:
        return 3_000_000_000   # 3B draft for a 10B+ target
    else:
        return target_model_size // 4  # ~4x smaller
2. Handling Rejection
class RobustSpeculativeDecoder:
    """
    Handle rejection gracefully with multiple strategies.
    """
    def __init__(self, use_sampling=False, temp=1.0):
        self.use_sampling = use_sampling
        self.temp = temp

    def handle_rejection(self, input_ids, draft_tokens, target_logits):
        """
        When the draft diverges from the target, recover gracefully.
        """
        # Logits at the last accepted position predict the correction token
        rejection_logits = target_logits[len(input_ids) - 1]
        # Strategy 1: take the target's greedy token
        target_token = rejection_logits.argmax().item()
        # Strategy 2: temperature sampling for diversity
        if self.use_sampling:
            probs = F.softmax(rejection_logits / self.temp, dim=-1)
            target_token = torch.multinomial(probs, 1).item()
        # Strategy 3: beam-search continuation
        # (more complex - keeps multiple candidate continuations alive)
        return target_token
Future Directions in 2026
Emerging Innovations
- Self-Speculation: Using the model itself as the draft (no separate model)
- Multi-Query Attention: Batch speculation across multiple requests
- Hardware-Software Co-design: Specialized pipelines for GPUs
- Diffusion Speculation: Extending to non-autoregressive models
Conclusion
Speculative decoding represents a breakthrough in LLM inference efficiency. By leveraging the insight that most tokens are predictable, we can achieve 2-3x speedups without any quality loss.
The key is choosing the right approach: standard greedy for maximum speed, stochastic for creative generation, hierarchical for varied quality requirements, or self-speculation when memory is constrained.
As LLM deployment scales, speculative decoding will become standard practice. The technique is lossless, requires no model retraining, and provides immediate performance benefits. It’s one of the most practical optimization techniques in the modern AI toolkit.