⚡ Calmops

Speculative Decoding: Lossless LLM Inference Acceleration

Introduction

Large Language Model inference remains computationally intensive, with autoregressive decoding being inherently sequential: each token depends on all previous tokens. This creates a memory-bandwidth bottleneck where most time is spent loading model weights rather than computing.

Speculative decoding solves this elegantly: instead of generating tokens one-by-one, we use a smaller “draft” model to propose multiple tokens in parallel, then verify them with the larger target model. When predictions match (which happens ~70-90% of the time), we get multiple tokens for the cost of one verification pass.

In 2026, speculative decoding has become essential for production LLM deployment, achieving 2-3x speedups while maintaining identical output quality. This guide explores the algorithms, implementations, and practical applications.

The Problem: Autoregressive Bottleneck

Standard Autoregressive Generation

import torch

class StandardAutoregressive:
    """
    Standard LLM generation - sequential, token by token.
    """
    
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        
    def generate(self, prompt, max_tokens=100):
        """
        Generate tokens one at a time.
        """
        input_ids = self.tokenizer.encode(prompt)
        
        for _ in range(max_tokens):
            # Forward pass re-processes ALL previous tokens
            logits = self.model(torch.tensor([input_ids]))
            
            # Greedily pick the next token
            next_token = logits[0, -1].argmax().item()
            input_ids.append(next_token)
            
            if next_token == self.tokenizer.eos_token_id:
                break
                
        return self.tokenizer.decode(input_ids)

The problem: each token requires a full forward pass through the entire model, loading all weights from memory. For a 70B-parameter model in FP16, that is roughly 140 GB of weight traffic per generated token.
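A back-of-the-envelope calculation makes the bottleneck concrete. The 2 TB/s figure below is an assumption, roughly in line with A100-class HBM bandwidth, not a measurement:

```python
# Rough bandwidth-bound throughput estimate for a 70B model in FP16.
# The 2 TB/s memory bandwidth is an assumed, A100-class figure.
weight_bytes = 70e9 * 2          # 70B params x 2 bytes (FP16) = 140 GB
bandwidth = 2e12                 # bytes/second of HBM bandwidth (assumed)

time_per_token = weight_bytes / bandwidth       # seconds per decode step
tokens_per_second = 1 / time_per_token

print(f"{time_per_token * 1000:.0f} ms/token, {tokens_per_second:.1f} tok/s")
# Weight loading alone caps single-stream decoding at ~14 tokens/second.
```

Every token that speculation gets "for free" during verification directly multiplies this ceiling.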

Speculative Decoding Fundamentals

Core Algorithm

import torch
import torch.nn.functional as F

class SpeculativeDecoder:
    """
    Speculative decoding with the draft-verification paradigm.
    
    1. Draft model generates multiple candidate tokens
    2. Target model verifies all candidates in parallel
    3. Accepted tokens advance; the first rejection is replaced
       by the target model's own prediction
    """
    
    def __init__(self, target_model, draft_model, max_draft=8, eos_token_id=None):
        self.target = target_model   # Large, accurate model
        self.draft = draft_model     # Small, fast model
        self.max_draft = max_draft   # Draft length
        self.eos_token_id = eos_token_id
        
    def generate_next_token(self, input_ids):
        """
        Generate the next token(s) using speculation.
        """
        # Step 1: Draft model generates candidates
        draft_tokens = self._draft(input_ids, self.max_draft)
        
        # Step 2: Target model verifies them in a single forward pass
        verified_tokens, accepted = self._verify(input_ids, draft_tokens)
        
        return verified_tokens, accepted
    
    def _draft(self, input_ids, num_tokens):
        """
        Use the draft model to propose tokens.
        """
        draft_ids = list(input_ids)
        draft_tokens = []
        
        for _ in range(num_tokens):
            # One forward pass per draft token (small model = fast)
            logits = self.draft(torch.tensor([draft_ids]))
            next_token = logits[0, -1].argmax().item()
            
            draft_tokens.append(next_token)
            draft_ids.append(next_token)
            
            if next_token == self.eos_token_id:
                break
                
        return draft_tokens
    
    def _verify(self, input_ids, draft_tokens):
        """
        Verify draft tokens with the target model.
        """
        if not draft_tokens:
            return [], True
            
        # Construct input with draft tokens appended
        extended_ids = input_ids + draft_tokens
        
        # Single forward pass through the target model
        logits = self.target(torch.tensor([extended_ids]))
        
        # Logits at position j predict token j+1, so draft_tokens[i]
        # (sitting at position len(input_ids) + i) is verified against
        # the prediction made at the previous position
        accepted = []
        
        for i, draft_token in enumerate(draft_tokens):
            target_idx = len(input_ids) + i - 1
            target_token = logits[0, target_idx].argmax().item()
            
            if draft_token == target_token:
                accepted.append(draft_token)
            else:
                # First rejection: substitute the target's own prediction,
                # guaranteeing at least one token of progress
                accepted.append(target_token)
                return accepted, False
                
        # All drafts accepted - take one bonus token from the target
        final_token = logits[0, -1].argmax().item()
        accepted.append(final_token)
        
        return accepted, True
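To see the draft-then-verify loop work without real checkpoints, here is a self-contained toy sketch. Each "model" is a deterministic callable that predicts (last token + step) mod vocab, so two models agree exactly when their steps match; `toy_lm` and `speculate_once` are illustrative helpers, not part of any library:

```python
import torch

def toy_lm(step, vocab=16):
    """Toy deterministic 'model': at every position, predicts
    (token + step) % vocab as the next token, as one-hot logits."""
    def forward(ids):                          # ids: (1, seq) LongTensor
        seq = ids[0].tolist()
        logits = torch.zeros(1, len(seq), vocab)
        for i, t in enumerate(seq):
            logits[0, i, (t + step) % vocab] = 1.0
        return logits
    return forward

def speculate_once(target, draft, input_ids, k=4):
    """One greedy draft-then-verify round, mirroring the class above."""
    ids = list(input_ids)
    drafted = []
    for _ in range(k):                         # k cheap draft passes
        tok = draft(torch.tensor([ids]))[0, -1].argmax().item()
        drafted.append(tok)
        ids.append(tok)
    # One target pass over prompt + drafts verifies everything at once;
    # logits at position j predict the token at position j + 1
    logits = target(torch.tensor([input_ids + drafted]))
    accepted = []
    for i, tok in enumerate(drafted):
        if logits[0, len(input_ids) + i - 1].argmax().item() != tok:
            break
        accepted.append(tok)
    return accepted
```

When draft and target share the same step, all four drafts are accepted in one verification pass; when their steps differ, verification rejects immediately.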

Sampling-Based Speculation

class StochasticSpeculativeDecoder:
    """
    Speculative decoding with sampling instead of greedy decoding.
    
    Note: for brevity this sketch accepts a draft token only when the
    target's own sample happens to match it. The exact lossless rule
    accepts with probability min(1, p_target(x) / p_draft(x)) and
    resamples from the residual distribution on rejection.
    """
    
    def __init__(self, target_model, draft_model, temperature=0.6, max_draft=8):
        self.target = target_model
        self.draft = draft_model
        self.temperature = temperature
        self.max_draft = max_draft
        
    def generate(self, input_ids, max_new_tokens=64):
        """Generate with probabilistic sampling."""
        all_tokens = []
        
        while len(all_tokens) < max_new_tokens:
            # Draft with temperature
            draft_tokens = self._sample_draft(input_ids, self.max_draft)
            
            if not draft_tokens:
                break
                
            # Verify with the target
            accepted, target_token = self._verify_sampling(
                input_ids, draft_tokens
            )
            
            # Add accepted tokens
            all_tokens.extend(accepted)
            
            # Update input for the next iteration
            input_ids = input_ids + accepted
            
            # The target always contributes one token: the correction on
            # rejection, or a bonus token when everything was accepted
            if target_token is not None:
                all_tokens.append(target_token)
                input_ids = input_ids + [target_token]
            else:
                break
                
        return all_tokens
    
    def _sample_draft(self, input_ids, num_tokens):
        """Sample from the draft model."""
        draft_ids = list(input_ids)
        tokens = []
        
        for _ in range(num_tokens):
            logits = self.draft(torch.tensor([draft_ids]))
            probs = F.softmax(logits[0, -1] / self.temperature, dim=-1)
            token = torch.multinomial(probs, 1).item()
            
            tokens.append(token)
            draft_ids.append(token)
                
        return tokens
    
    def _verify_sampling(self, input_ids, draft_tokens):
        """
        Verify with sampling; on rejection, return the target's sample.
        """
        extended_ids = input_ids + draft_tokens
        logits = self.target(torch.tensor([extended_ids]))
        
        accepted = []
        
        for i, draft_token in enumerate(draft_tokens):
            # Logits at position j predict token j+1
            target_idx = len(input_ids) + i - 1
            probs = F.softmax(logits[0, target_idx] / self.temperature, dim=-1)
            
            # Sample from the target distribution
            target_token = torch.multinomial(probs, 1).item()
            
            if draft_token == target_token:
                accepted.append(draft_token)
            else:
                # Rejection: use the target's choice
                return accepted, target_token
                
        # All accepted - sample one bonus token
        final_token = torch.multinomial(
            F.softmax(logits[0, -1] / self.temperature, dim=-1), 1
        ).item()
        
        return accepted, final_token
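Match-based acceptance as sketched above is simple but discards valid samples. The lossless rule from the speculative sampling literature accepts a drafted token x with probability min(1, p(x)/q(x)), where p is the target distribution and q the draft distribution, and on rejection resamples from the normalized residual max(0, p - q). A minimal sketch of that acceptance step:

```python
import torch

def speculative_accept(p, q, x):
    """Lossless acceptance step: p is the target distribution, q the
    draft distribution (both 1-D probability tensors), x the drafted
    token id. Returns (token, accepted). Tokens produced this way are
    distributed exactly according to p."""
    if torch.rand(()) < torch.clamp(p[x] / q[x], max=1.0):
        return x, True
    # Rejection: resample from the residual distribution norm(max(0, p - q))
    residual = torch.clamp(p - q, min=0.0)
    residual = residual / residual.sum()
    return torch.multinomial(residual, 1).item(), False
```

When p equals q the ratio is 1 and every draft token is accepted; the further the draft drifts from the target, the more often the residual resampling kicks in.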

Advanced Speculative Algorithms

1. Self-Speculative Decoding

class SelfSpeculativeDecoder:
    """
    Use the target model itself as both draft and verifier.
    Saves memory by not loading a separate draft model.
    """
    
    def __init__(self, model, max_draft=6, draft_depth=15):
        self.model = model
        self.max_draft = max_draft
        self.draft_depth = draft_depth  # layer used as the early "draft" exit
        
    def generate(self, input_ids):
        """
        Self-speculation: use earlier layers as draft, later layers as verifier.
        """
        # Extract hidden states at the draft depth
        hidden = self._get_hidden_states(input_ids, self.draft_depth)
        
        # Generate draft tokens from the intermediate hidden states
        draft_tokens = self._hidden_to_tokens(hidden)
        
        # Verify with the full model
        accepted = self._verify(input_ids, draft_tokens)
        
        return accepted
    
    def _get_hidden_states(self, input_ids, draft_depth):
        """Extract intermediate hidden states (simplified)."""
        output = self.model(torch.tensor([input_ids]), output_hidden_states=True)
        return output.hidden_states[draft_depth]
    
    def _hidden_to_tokens(self, hidden):
        """Project intermediate hidden states through the LM head."""
        logits = self.model.lm_head(hidden)
        return logits.argmax(dim=-1)[0].tolist()
    
    def _verify(self, input_ids, draft_tokens):
        """Verify by running a full forward pass."""
        full_output = self.model(torch.tensor([input_ids + draft_tokens]))
        
        # Compare tokens (logits at position j predict token j+1)
        accepted = []
        for i, draft in enumerate(draft_tokens):
            target_idx = len(input_ids) + i - 1
            target_token = full_output.logits[0, target_idx].argmax().item()
            
            if draft == target_token:
                accepted.append(draft)
            else:
                break
                
        return accepted

2. Hierarchical Speculative Decoding

class HierarchicalSpeculativeDecoder:
    """
    Multi-level speculation: draft1 -> draft2 -> target.
    Uses progressively larger models for better acceptance.
    
    Note: tokens verified only by the medium model must still be
    checked by the target model to keep generation lossless; this
    sketch shows the escalation logic only.
    """
    
    def __init__(self, models, thresholds=[0.3, 0.7]):
        """
        models: [small, medium, target]
        thresholds: acceptance-rate thresholds that trigger escalation
        """
        self.models = models
        self.thresholds = thresholds
        
    def generate(self, input_ids):
        """Generate with hierarchical speculation."""
        # Level 1: smallest model drafts, medium model verifies
        draft1 = self._generate_draft(input_ids, self.models[0], max_tokens=6)
        
        accepted1, next_input = self._verify_level(input_ids, draft1, self.models[1])
        
        # Escalate when level-1 acceptance is poor
        acceptance1 = len(accepted1) / max(len(draft1), 1)
        if acceptance1 < self.thresholds[0]:
            # Level 2: medium model drafts, target model verifies
            draft2 = self._generate_draft(next_input, self.models[1], max_tokens=4)
            accepted2, final_input = self._verify_level(
                next_input, draft2, self.models[2]
            )
            return accepted1 + accepted2
        else:
            return accepted1
    
    def _generate_draft(self, input_ids, model, max_tokens):
        """Generate draft tokens."""
        tokens = []
        for _ in range(max_tokens):
            logits = model(torch.tensor([input_ids + tokens]))
            tokens.append(logits[0, -1].argmax().item())
        return tokens
    
    def _verify_level(self, input_ids, draft_tokens, verifier_model):
        """Verify at the current level."""
        extended = input_ids + draft_tokens
        logits = verifier_model(torch.tensor([extended]))
        
        accepted = []
        for i in range(len(draft_tokens)):
            # Logits at position j predict token j+1
            target_idx = len(input_ids) + i - 1
            target_token = logits[0, target_idx].argmax().item()
            
            if draft_tokens[i] == target_token:
                accepted.append(draft_tokens[i])
            else:
                break
                
        next_input = input_ids + accepted
        return accepted, next_input

3. Lookahead Speculation

class LookaheadSpeculation:
    """
    Lookahead: use n-gram patterns for speculation.
    Leverages local context patterns.
    """
    
    def __init__(self, model, ngram_size=5, max_new_tokens=8):
        self.model = model
        self.ngram_size = ngram_size
        self.max_new_tokens = max_new_tokens
        self.ngram_cache = {}
        
    def generate(self, input_ids):
        """Generate with n-gram guided speculation."""
        tokens = list(input_ids)
        generated = 0
        
        while generated < self.max_new_tokens:
            # Build an n-gram key from the most recent tokens
            context = tuple(tokens[-(self.ngram_size - 1):])
            
            # Check the cache for possible continuations
            candidates = self._get_candidates(context)
            
            if candidates:
                # Verify candidates against the model
                verified = self._verify_candidates(tokens, candidates)
                if verified is not None:
                    tokens.append(verified)
                    generated += 1
                    continue
                    
            # Fallback: single-token generation
            logits = self.model(torch.tensor([tokens]))
            tokens.append(logits[0, -1].argmax().item())
            generated += 1
            
        return tokens
    
    def _get_candidates(self, context):
        """Get cached n-gram continuations."""
        return self.ngram_cache.get(context, [])
    
    def _verify_candidates(self, tokens, candidates):
        """Verify candidates: one forward pass gives the model's
        next-token prediction, compared against the cached options."""
        logits = self.model(torch.tensor([tokens]))
        predicted = logits[0, -1].argmax().item()
        
        if predicted in candidates[:3]:  # Consider the top 3
            return predicted
            
        return None
    
    def update_cache(self, token_ids):
        """Update the n-gram cache from previously generated token ids."""
        for i in range(len(token_ids) - self.ngram_size + 1):
            ngram = tuple(token_ids[i:i + self.ngram_size - 1])
            continuation = token_ids[i + self.ngram_size - 1]
            
            if ngram not in self.ngram_cache:
                self.ngram_cache[ngram] = []
            if continuation not in self.ngram_cache[ngram]:
                self.ngram_cache[ngram].append(continuation)
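The cache itself is just a dict mapping (n-1)-token contexts to the continuations observed after them. A word-level toy run (production code keys on token ids, as in the class above) makes the mechanics visible:

```python
# Toy n-gram cache (word-level for readability)
cache = {}
n = 3
words = "the cat sat on the mat and the cat sat down".split()

for i in range(len(words) - n + 1):
    context = tuple(words[i:i + n - 1])       # (n-1)-gram key
    continuation = words[i + n - 1]           # the word that followed it
    cache.setdefault(context, [])
    if continuation not in cache[context]:
        cache[context].append(continuation)

print(cache[("the", "cat")])   # seen twice, stored once
print(cache[("cat", "sat")])   # two distinct continuations
```

Repeated phrases collapse into single cache entries, while contexts with multiple observed continuations accumulate a small candidate list to speculate from.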

Implementation Framework

class SpeculativeGenerationPipeline:
    """
    Complete speculative decoding pipeline.
    """
    
    def __init__(self, config):
        # Load models and tokenizer
        self.target = self._load_model(config.target_model)
        self.draft = self._load_model(config.draft_model)
        self.tokenizer = self._load_tokenizer(config.target_model)
        
        # Configuration
        self.max_draft = config.max_draft
        self.temperature = config.temperature
        
        # Choose the algorithm
        if config.algorithm == 'sampling':
            self.decoder = StochasticSpeculativeDecoder(
                self.target, self.draft, config.temperature
            )
        elif config.algorithm == 'self':
            self.decoder = SelfSpeculativeDecoder(self.target)
        else:  # 'standard' and any unrecognized value
            self.decoder = SpeculativeDecoder(self.target, self.draft)
            
    def generate(self, prompt, max_tokens=100):
        """Generate with speculative decoding."""
        input_ids = self.tokenizer.encode(prompt)
        output_tokens = []
        
        while len(output_tokens) < max_tokens:
            # Get the next token(s)
            new_tokens, accepted_all = self.decoder.generate_next_token(input_ids)
            
            if not new_tokens:
                break
                
            output_tokens.extend(new_tokens)
            input_ids = input_ids + new_tokens
            
            # Check for EOS
            if new_tokens[-1] == self.tokenizer.eos_token_id:
                break
                
        return self.tokenizer.decode(output_tokens)
    
    def benchmark(self, prompts, baseline_time):
        """Benchmark speedup against a baseline run."""
        speculative_time, total_tokens = self._time_generation(prompts)
        
        return {
            'speedup': baseline_time / speculative_time,
            'tokens_per_second': total_tokens / speculative_time,
            'acceptance_rate': self._compute_acceptance_rate()
        }
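The pipeline above expects a config object with a handful of attributes. A minimal sketch using a dataclass; the model identifiers are illustrative placeholders, not real checkpoint names:

```python
from dataclasses import dataclass

@dataclass
class SpeculativeConfig:
    # Model identifiers are placeholders for illustration
    target_model: str = "target-70b"
    draft_model: str = "draft-7b"
    algorithm: str = "standard"      # 'standard' | 'sampling' | 'self'
    max_draft: int = 6
    temperature: float = 0.7

# Override only what differs from the defaults
config = SpeculativeConfig(algorithm="sampling", max_draft=8)
```

A dataclass keeps the configuration typed and self-documenting, and defaults mean callers only specify what they change.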

Optimizations and Tricks

1. Adaptive Draft Length

class AdaptiveSpeculativeDecoder:
    """
    Adjust draft length based on acceptance rate.
    """
    
    def __init__(self, target, draft):
        self.decoder = SpeculativeDecoder(target, draft)
        self.acceptance_history = []
        self.current_draft_len = 6
        
    def generate(self, input_ids):
        """Generate with adaptive draft length."""
        # Use current draft length
        self.decoder.max_draft = self.current_draft_len
        
        # Generate
        tokens, fully_accepted = self.decoder.generate_next_token(input_ids)
        
        # Track acceptance (capped at 1.0, since the bonus token can
        # push len(tokens) past the draft length)
        acceptance = min(1.0, len(tokens) / self.current_draft_len)
        self.acceptance_history.append(acceptance)
        
        # Adapt draft length
        if len(self.acceptance_history) > 10:
            avg_acceptance = sum(self.acceptance_history[-10:]) / 10
            
            if avg_acceptance > 0.9:
                self.current_draft_len = min(12, self.current_draft_len + 1)
            elif avg_acceptance < 0.6:
                self.current_draft_len = max(3, self.current_draft_len - 1)
                
        return tokens
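The adaptation rule can be pulled out into a pure function, which makes the hysteresis easy to unit-test in isolation (thresholds mirror the class above):

```python
def adapt_draft_length(draft_len, history, lo=0.6, hi=0.9,
                       min_len=3, max_len=12, window=10):
    """Grow the draft when recent acceptance is high, shrink when low."""
    if len(history) < window:
        return draft_len                       # not enough signal yet
    avg = sum(history[-window:]) / window
    if avg > hi:
        return min(max_len, draft_len + 1)     # cheap drafts are paying off
    if avg < lo:
        return max(min_len, draft_len - 1)     # too many wasted draft passes
    return draft_len
```

The dead zone between the two thresholds prevents the draft length from oscillating on noisy acceptance measurements.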

2. Batch Speculation

class BatchSpeculativeDecoder:
    """
    Process multiple sequences with speculation in parallel.
    """
    
    def __init__(self, target, draft, tokenizer):
        self.target = target
        self.draft = draft
        self.tokenizer = tokenizer
        
    def generate_batch(self, prompts, max_tokens):
        """
        Generate multiple sequences simultaneously.
        """
        # Encode all prompts
        input_ids_list = [self.tokenizer.encode(p) for p in prompts]
        
        results = [[] for _ in prompts]
        finished = [False] * len(prompts)
        
        while not all(finished) and len(results[0]) < max_tokens:
            # Draft for all sequences
            drafts = []
            for ids in input_ids_list:
                draft_tokens = self._quick_draft(ids)
                drafts.append(draft_tokens)
                
            # Verify all in batch
            for i, (ids, draft) in enumerate(zip(input_ids_list, drafts)):
                if finished[i]:
                    continue
                    
                accepted = self._verify(ids, draft)
                results[i].extend(accepted)
                input_ids_list[i].extend(accepted)
                
                if not accepted or accepted[-1] == self.tokenizer.eos_token_id:
                    finished[i] = True
                    
        return [self.tokenizer.decode(r) for r in results]

Performance Analysis

Expected Speedups

Configuration    Draft Model   Acceptance Rate   Speedup
70B → 7B         7B            85-95%            2.5-3.0x
70B → 3B         3B            70-85%            2.0-2.5x
Self-speculate   Same model    75-90%            1.8-2.2x
Hierarchical     Multi-level   90-95%            2.5-3.5x

def analyze_speedup(target_time_per_token, draft_time_per_token,
                    acceptance_rate, draft_len=6):
    """
    Compute the expected speedup.
    
    Traditional cost per token: target_time_per_token.
    Speculative cost per iteration: draft_len draft passes plus one
    target verification pass, yielding roughly 1/(1-acceptance_rate)
    tokens (capped by the draft length plus the bonus token).
    """
    # Average tokens produced per verification pass
    avg_tokens = min(1 / (1 - acceptance_rate), draft_len + 1)
    
    # Time per iteration
    speculative_time = (draft_time_per_token * draft_len + 
                        target_time_per_token)
    
    # Token throughput for each scheme
    speculative_rate = avg_tokens / speculative_time
    traditional_rate = 1 / target_time_per_token
    
    return speculative_rate / traditional_rate
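A sharper model, in the style of the original speculative decoding analyses, computes the expected yield exactly: with per-token acceptance rate α and draft length n, one verification pass produces (1 - α^(n+1)) / (1 - α) tokens on average (bonus token included), at the cost of n draft passes plus one target pass. Writing c for the draft/target cost ratio:

```python
def expected_tokens_per_pass(alpha, n):
    """E[tokens] per verification pass, including the bonus token:
    (1 - alpha**(n+1)) / (1 - alpha), for acceptance rate alpha < 1."""
    return (1 - alpha ** (n + 1)) / (1 - alpha)

def theoretical_speedup(alpha, n, c):
    """c = draft forward-pass cost / target forward-pass cost."""
    return expected_tokens_per_pass(alpha, n) / (n * c + 1)

# With 80% acceptance, 5 drafted tokens, and a draft 20x cheaper:
print(round(theoretical_speedup(0.8, 5, 0.05), 2))
```

Note the diminishing returns built into the geometric series: past a certain draft length, each extra drafted token adds cost almost surely but extra yield only with probability α^n.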

Best Practices

1. Draft Model Selection

def select_draft_model(target_model_size):
    """
    Select appropriate draft model based on target.
    """
    # Rule of thumb: 10x smaller is usually good
    if target_model_size >= 70_000_000_000:
        return 7_000_000_000   # 7B draft for 70B target
    elif target_model_size >= 10_000_000_000:
        return 3_000_000_000   # 3B draft for 10B target
    else:
        return target_model_size // 4  # ~4x smaller

2. Handling Rejection

class RobustSpeculativeDecoder:
    """
    Handle rejection gracefully with multiple strategies.
    """
    
    def __init__(self, use_sampling=False, temp=1.0):
        self.use_sampling = use_sampling
        self.temp = temp
    
    def handle_rejection(self, input_ids, draft_tokens, target_logits):
        """
        When the draft diverges from the target, recover gracefully.
        """
        # Strategy 1: use the target's greedy token
        target_token = target_logits[len(input_ids)].argmax().item()
        
        # Strategy 2: temperature sampling for diversity
        if self.use_sampling:
            probs = F.softmax(target_logits[len(input_ids)] / self.temp, dim=-1)
            target_token = torch.multinomial(probs, 1).item()
            
        # Strategy 3: beam-search continuation
        # (more complex - keeps multiple candidate continuations alive)
        
        return target_token

Future Directions in 2026

Emerging Innovations

  1. Self-Speculation: Using model itself as draft (no separate model)
  2. Multi-Query Attention: Batch speculation across multiple requests
  3. Hardware-Software Co-design: Specialized pipelines for GPUs
  4. Diffusion Speculation: Extending to non-autoregressive models


Conclusion

Speculative decoding represents a breakthrough in LLM inference efficiency. By leveraging the insight that most tokens are predictable, we can achieve 2-3x speedups without any quality loss.

The key is choosing the right approach: standard greedy for maximum speed, stochastic for creative generation, hierarchical for varied quality requirements, or self-speculation when memory is constrained.

As LLM deployment scales, speculative decoding will become standard practice. The technique is lossless, requires no model retraining, and provides immediate performance benefits. It’s one of the most practical optimization techniques in the modern AI toolkit.
