
Knowledge Distillation: Model Compression Techniques

Introduction

As deep learning models grow larger and more powerful, a fundamental challenge emerges: how do we deploy these massive models on resource-constrained devices? Knowledge distillation offers an elegant solution by transferring knowledge from a large “teacher” model to a smaller “student” model that can be deployed efficiently.

The core idea is simple: instead of training the student model only on hard labels, we also train it to match the soft probability outputs of the teacher model. This lets the student learn not just which answer is correct but how the teacher spreads probability across the alternatives, capturing the “dark knowledge” embedded in the teacher’s predictions.

In 2026, knowledge distillation has become essential for deploying large language models and vision transformers efficiently. This comprehensive guide explores the mathematics, implementations, and practical applications of knowledge distillation.

Foundations of Knowledge Distillation

The Basic Concept

Knowledge distillation transfers information from a teacher model to a student model:

import torch
import torch.nn as nn
import torch.nn.functional as F

class KnowledgeDistillation:
    """
    Basic knowledge distillation framework.
    
    The student learns from:
    1. Hard labels (standard cross-entropy)
    2. Soft labels from teacher (KL divergence)
    """
    
    def __init__(self, teacher_model, student_model, temperature=4.0, alpha=0.5):
        self.teacher = teacher_model
        self.student = student_model
        self.temperature = temperature  # Higher = softer probabilities
        self.alpha = alpha  # Balance between hard and soft labels
        
    def distillation_loss(self, student_logits, teacher_logits, labels):
        """
        Compute combined distillation loss.
        
        Loss = α × T² × KL(soft_teacher || soft_student) + (1 - α) × CE(student, labels)
        """
        # Soft targets from teacher
        soft_teacher = F.softmax(teacher_logits / self.temperature, dim=-1)
        soft_student = F.log_softmax(student_logits / self.temperature, dim=-1)
        
        # KL divergence (scaled by T² as per Hinton et al.)
        soft_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean')
        soft_loss = soft_loss * (self.temperature ** 2)
        
        # Hard label loss
        hard_loss = F.cross_entropy(student_logits, labels)
        
        # Combined loss
        return self.alpha * soft_loss + (1 - self.alpha) * hard_loss
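The `T ** 2` factor in the soft loss compensates for the way temperature shrinks gradients. The sketch below checks this with random tensors standing in for real logits (the batch size of 32 and the 10 classes are arbitrary choices): it measures the gradient norm of the soft loss with respect to the student logits at several temperatures, with and without the factor. Without it, gradients fall off roughly as 1/T² at large T; with it, they stay on a comparable scale, so `alpha` keeps roughly the same meaning when T changes.

```python
import torch
import torch.nn.functional as F

def soft_loss(student_logits, teacher_logits, T, scale):
    # KL(p_T || p_S) at temperature T, optionally multiplied by T**2
    p_t = F.softmax(teacher_logits / T, dim=-1)
    log_p_s = F.log_softmax(student_logits / T, dim=-1)
    loss = F.kl_div(log_p_s, p_t, reduction="batchmean")
    return loss * T ** 2 if scale else loss

def grad_norm(T, scale, seed=0):
    # Random logits stand in for real teacher/student outputs
    g = torch.Generator().manual_seed(seed)
    teacher = torch.randn(32, 10, generator=g)
    student = torch.randn(32, 10, generator=g).requires_grad_()
    soft_loss(student, teacher, T, scale).backward()
    return student.grad.norm().item()

for T in (1.0, 4.0, 8.0):
    print(f"T={T}: unscaled {grad_norm(T, False):.4f}  scaled {grad_norm(T, True):.4f}")
```

Exact numbers depend on the random logits; the point is the trend across temperatures.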

The Mathematics

The distillation loss is:

$$\mathcal{L}_{KD} = \alpha \cdot T^2 \cdot \text{KL}(p_T \,\|\, p_S) + (1-\alpha) \cdot \mathcal{L}_{CE}\big(y, \text{softmax}(z_S)\big)$$

Where:

  • $p_T = \text{softmax}(z_T/T)$: teacher soft predictions
  • $p_S = \text{softmax}(z_S/T)$: student soft predictions
  • $z_T, z_S$: teacher and student logits
  • $T$: temperature (higher = softer distributions)
  • $\alpha$: balances hard and soft targets
def temperature_softmax(logits, temperature):
    """
    Apply temperature-scaled softmax.
    
    Higher temperature smooths the distribution,
    revealing relationships between classes.
    """
    return F.softmax(logits / temperature, dim=-1)
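A worked example makes the effect of temperature concrete. For three illustrative logits, raising T from 1 to 4 keeps the ranking of the classes but moves probability mass onto the non-argmax classes, which is the extra signal the student learns from:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([4.0, 1.0, 0.0])  # illustrative teacher logits for 3 classes

for T in (1.0, 4.0):
    probs = F.softmax(logits / T, dim=-1)
    print(T, [round(p, 4) for p in probs.tolist()])
# 1.0 [0.9362, 0.0466, 0.0171]
# 4.0 [0.5434, 0.2567, 0.1999]
```

At T = 1 the second class gets under 5% of the mass; at T = 4 it gets about 26%, so the student can see that the teacher considers it the most plausible alternative.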

Distillation Strategies

1. Response-Based Distillation

Match the final outputs:

class ResponseDistillation(nn.Module):
    """
    Response-based: student matches teacher's output layer.
    
    Simplest form of distillation.
    """
    
    def __init__(self, teacher, student, temperature=4.0):
        super().__init__()
        self.teacher = teacher
        self.student = student
        self.temperature = temperature
        
    def forward(self, x):
        """
        Forward pass computing both outputs.
        """
        with torch.no_grad():
            teacher_output = self.teacher(x)
            
        student_output = self.student(x)
        
        return teacher_output, student_output
    
    def loss(self, teacher_output, student_output, labels):
        """
        Compute response-based distillation loss.
        """
        # Soft loss
        soft_teacher = F.softmax(teacher_output / self.temperature, dim=-1)
        soft_student = F.log_softmax(student_output / self.temperature, dim=-1)
        soft_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean')
        soft_loss = soft_loss * (self.temperature ** 2)
        
        # Hard loss
        hard_loss = F.cross_entropy(student_output, labels)
        
        return soft_loss + hard_loss
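A typical training step around this loss switches the teacher to eval mode (so dropout and batch-norm statistics behave as at inference), freezes it, and gives only the student's parameters to the optimizer. In this minimal sketch, tiny `nn.Linear` models and random data stand in for real networks, and the loss mirrors `ResponseDistillation.loss`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
num_classes, T = 5, 4.0
teacher = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, num_classes))  # stand-in teacher
student = nn.Linear(16, num_classes)  # stand-in student

teacher.eval()  # inference behaviour for dropout / batch norm
for p in teacher.parameters():
    p.requires_grad_(False)  # the teacher is never updated

optimizer = torch.optim.SGD(student.parameters(), lr=0.1)
x = torch.randn(8, 16)
labels = torch.randint(0, num_classes, (8,))

with torch.no_grad():
    teacher_out = teacher(x)
student_out = student(x)

soft_loss = F.kl_div(
    F.log_softmax(student_out / T, dim=-1),
    F.softmax(teacher_out / T, dim=-1),
    reduction="batchmean",
) * T ** 2
hard_loss = F.cross_entropy(student_out, labels)
loss = soft_loss + hard_loss

optimizer.zero_grad()
loss.backward()
optimizer.step()
```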

2. Feature-Based Distillation

Match intermediate representations:

class FeatureDistillation(nn.Module):
    """
    Feature-based: student matches teacher's intermediate features.
    
    Useful when architectures differ, because a learned adapter maps
    student features into the teacher's feature space. Assumes
    teacher.layers is indexable, student.feature_dim is defined, and
    student.encoder(x) returns the student's features.
    """
    
    def __init__(self, teacher, student, feature_dim=768):
        super().__init__()
        self.teacher = teacher
        self.student = student
        self.feature_dim = feature_dim
        
        # Adapter to match dimensions
        self.adapter = nn.Linear(student.feature_dim, feature_dim)
        
    def get_teacher_features(self, x):
        """
        Extract the output of the teacher's last layer via a forward hook.
        """
        features = []
        
        def hook(module, input, output):
            features.append(output)
            
        # Register a hook on the teacher's last layer
        handle = self.teacher.layers[-1].register_forward_hook(hook)
        try:
            with torch.no_grad():  # the teacher is fixed; no gradients needed
                _ = self.teacher(x)
        finally:
            handle.remove()  # always remove the hook, even on error
            
        return features[-1]
        
    def get_student_features(self, x):
        """
        Extract intermediate features from student.
        """
        return self.student.encoder(x)
    
    def feature_loss(self, teacher_features, student_features, use_cosine=False):
        """
        Match features between teacher and student.
        
        MSE matches both magnitude and direction; the cosine loss
        matches direction only.
        """
        # Project student features
        student_proj = self.adapter(student_features)
        
        if use_cosine:
            return 1 - F.cosine_similarity(
                student_proj, teacher_features, dim=-1
            ).mean()
        
        return F.mse_loss(student_proj, teacher_features)
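The hook-based extraction in `get_teacher_features` can be tried on a toy model. This sketch (all module sizes are made up) registers a forward hook on one intermediate layer of an `nn.Sequential` teacher, captures its output, removes the hook, and matches a learned projection of the student's features to it with MSE:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
teacher = nn.Sequential(
    nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 3)
)
student_encoder = nn.Sequential(nn.Linear(10, 8), nn.ReLU())  # student features have dim 8
adapter = nn.Linear(8, 32)  # projects student features into the teacher's 32-dim space

captured = []

def save_output(module, inputs, output):
    captured.append(output)

# Hook the teacher's second hidden layer (index 3 is the ReLU after Linear(32, 32))
handle = teacher[3].register_forward_hook(save_output)
x = torch.randn(4, 10)
with torch.no_grad():
    teacher(x)
handle.remove()  # stop collecting features on later forward passes

teacher_features = captured[-1]
student_features = adapter(student_encoder(x))
loss = F.mse_loss(student_features, teacher_features)
loss.backward()  # gradients reach the adapter and the student encoder only
```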

3. Relation-Based Distillation

Capture relationships between examples:

class RelationDistillation(nn.Module):
    """
    Relation-based: student learns how teacher relates examples.
    
    Captures inter-sample relationships.
    """
    
    def __init__(self, teacher, student):
        super().__init__()
        self.teacher = teacher
        self.student = student
        
    def relation_loss(self, teacher_logits, student_logits):
        """
        Compute relation-based loss.
        
        Teacher knows that similar inputs should have similar relationships.
        Student should learn this too.
        """
        batch_size = teacher_logits.size(0)
        
        # Compute similarity matrices
        teacher_sim = F.cosine_similarity(
            teacher_logits.unsqueeze(1), 
            teacher_logits.unsqueeze(0),
            dim=-1
        )
        
        student_sim = F.cosine_similarity(
            student_logits.unsqueeze(1),
            student_logits.unsqueeze(0),
            dim=-1
        )
        
        # Match similarity structures
        loss = F.mse_loss(student_sim, teacher_sim)
        
        return loss
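Because it compares cosine-similarity matrices, the relation-based loss ignores how large the student's logits are and looks only at how samples relate to one another. A quick check with random tensors as stand-ins: a student whose logits are a positive rescaling of the teacher's gets (numerically) zero loss, while an unrelated student does not.

```python
import torch
import torch.nn.functional as F

def relation_loss(teacher_logits, student_logits):
    # (B, B) matrices of pairwise cosine similarity between samples
    t_sim = F.cosine_similarity(teacher_logits.unsqueeze(1), teacher_logits.unsqueeze(0), dim=-1)
    s_sim = F.cosine_similarity(student_logits.unsqueeze(1), student_logits.unsqueeze(0), dim=-1)
    return F.mse_loss(s_sim, t_sim)

torch.manual_seed(0)
teacher = torch.randn(6, 10)
print(relation_loss(teacher, 3.0 * teacher).item())       # ~0: same relational structure
print(relation_loss(teacher, torch.randn(6, 10)).item())  # > 0: different structure
```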

Advanced Distillation Techniques

1. Self-Distillation

Train student to match a deeper version of itself:

class SelfDistillation(nn.Module):
    """
    Self-distillation: one model acts as its own teacher.
    
    The deepest layer's output supervises every shallower layer, so no
    separate teacher model is needed. Assumes model.layers is a sequence
    of modules that each output hidden_dim features.
    """
    
    def __init__(self, model, hidden_dim):
        super().__init__()
        self.model = model
        self.num_layers = len(model.layers)
        # One learned projector per shallow layer, registered as submodules
        # so the optimizer trains them (creating nn.Linear inside the loss
        # would give fresh random weights on every call)
        self.projectors = nn.ModuleList(
            nn.Linear(hidden_dim, hidden_dim) for _ in range(self.num_layers - 1)
        )
        
    def extract_layer_outputs(self, x):
        """
        Extract outputs from each layer.
        """
        outputs = []
        
        for layer in self.model.layers:
            x = layer(x)
            outputs.append(x)
            
        return outputs
    
    def self_distillation_loss(self, x):
        """
        The deepest layer teaches each shallower layer.
        """
        layer_outputs = self.extract_layer_outputs(x)
        
        # Detach the teacher signal so gradients flow only into shallow layers
        teacher_out = layer_outputs[-1].detach()
        
        loss = 0.0
        for projector, student_out in zip(self.projectors, layer_outputs[:-1]):
            loss = loss + F.mse_loss(projector(student_out), teacher_out)
            
        return loss / max(self.num_layers - 1, 1)

2. Multi-Teacher Distillation

Learn from multiple teachers:

class MultiTeacherDistillation(nn.Module):
    """
    Multi-teacher: student learns from ensemble of teachers.
    
    Combines knowledge from different models.
    """
    
    def __init__(self, teachers, student, temperature=4.0):
        super().__init__()
        self.teachers = nn.ModuleList(teachers)
        self.student = student
        self.temperature = temperature
        
    def multi_teacher_loss(self, x, labels):
        """
        Combine knowledge from multiple teachers.
        """
        # Get all teacher outputs (teachers are fixed)
        teacher_logits = []
        with torch.no_grad():
            for teacher in self.teachers:
                teacher_logits.append(teacher(x))
                
        # Simple average of logits; a weighted combination is also possible
        avg_teacher = torch.stack(teacher_logits).mean(dim=0)
        
        # Student output
        student_logits = self.student(x)
        
        # Soft loss against the averaged teacher, scaled by T² like the single-teacher loss
        T = self.temperature
        soft_teacher = F.softmax(avg_teacher / T, dim=-1)
        soft_student = F.log_softmax(student_logits / T, dim=-1)
        soft_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * T ** 2
        
        # Hard labels
        hard_loss = F.cross_entropy(student_logits, labels)
        
        return soft_loss + hard_loss
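The "average or weighted combination" step leaves a choice open: averaging teacher logits (as the class does) and averaging teacher probabilities produce different soft targets. This sketch uses two made-up teachers that disagree sharply and hypothetical equal weights:

```python
import torch
import torch.nn.functional as F

T = 4.0
# Two hypothetical teachers' logits for one example with 3 classes
teacher_logits = torch.tensor([[[6.0, 0.0, 0.0]],
                               [[0.0, 6.0, 0.0]]])  # shape (num_teachers, batch, classes)
weights = torch.tensor([0.5, 0.5])  # per-teacher weights, summing to 1
w = weights.view(-1, 1, 1)  # broadcast over batch and classes

# Option 1: average logits, then soften
avg_logits_target = F.softmax((w * teacher_logits).sum(0) / T, dim=-1)

# Option 2: soften each teacher, then average probabilities
avg_probs_target = (w * F.softmax(teacher_logits / T, dim=-1)).sum(0)

print(avg_logits_target)
print(avg_probs_target)
```

Here averaging logits puts about 0.19 on the third class, while averaging probabilities puts about 0.15 there, so the choice changes what the student is asked to imitate.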

3. Layer-by-Layer Distillation

Progressive knowledge transfer:

class ProgressiveDistillation(nn.Module):
    """
    Progressive: distill one student layer at a time, in order.
    
    Assumes teacher.layers and student.layers have the same length and
    accept the same input; student_dim / teacher_dim are their output sizes.
    """
    
    def __init__(self, teacher, student, student_dim, teacher_dim):
        super().__init__()
        self.teacher = teacher
        self.student = student
        self.current_layer = 0
        # One registered adapter per layer, trained together with the student
        self.adapters = nn.ModuleList(
            nn.Linear(student_dim, teacher_dim) for _ in student.layers
        )
        
    def distill_single_layer(self, x, layer_idx):
        """
        Match student layer `layer_idx` to teacher layer `layer_idx`.
        """
        # Run the teacher up to and including layer_idx
        with torch.no_grad():
            teacher_out = x
            for layer in self.teacher.layers[:layer_idx + 1]:
                teacher_out = layer(teacher_out)
                
        # Run the student the same way, so layer_idx sees its real input
        student_out = x
        for layer in self.student.layers[:layer_idx + 1]:
            student_out = layer(student_out)
            
        return F.mse_loss(self.adapters[layer_idx](student_out), teacher_out)
    
    def train_step(self, x):
        """
        Compute the loss for the current layer, with earlier layers frozen.
        """
        for i, layer in enumerate(self.student.layers):
            for param in layer.parameters():
                param.requires_grad = i >= self.current_layer
                
        return self.distill_single_layer(x, self.current_layer)
    
    def advance_layer(self):
        """
        Move to the next layer once the current one has converged.
        """
        self.current_layer = min(self.current_layer + 1, len(self.student.layers) - 1)

Distillation for Different Architectures

1. BERT Distillation

class BERTDistillation:
    """
    Distill BERT to smaller model.
    """
    
    def __init__(self, teacher_model, student_model, temperature=2.0):
        self.teacher = teacher_model
        self.student = student_model
        self.temperature = temperature
        
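    def distill_bert(self, input_ids, attention_mask, labels):
        """
        Distill a masked-language-model BERT.
        
        Assumes Hugging Face-style models sharing one vocabulary, whose
        outputs expose .logits of shape (batch, seq_len, vocab_size), and
        MLM labels set to -100 at positions that were not masked.
        """
        # Teacher predictions
        with torch.no_grad():
            teacher_output = self.teacher(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
            
        # Student predictions
        student_output = self.student(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        
        # Flatten to (num_tokens, vocab_size) instead of hard-coding 30522
        vocab_size = student_output.logits.size(-1)
        s_logits = student_output.logits.reshape(-1, vocab_size)
        t_logits = teacher_output.logits.reshape(-1, vocab_size)
        flat_labels = labels.reshape(-1)
        
        # MLM loss on masked positions (-100 is ignored)
        mlm_loss = F.cross_entropy(s_logits, flat_labels, ignore_index=-100)
        
        # Per-token distillation loss on the same masked positions;
        # batchmean over tokens keeps it on the same scale as mlm_loss
        mask = flat_labels != -100
        soft_teacher = F.softmax(t_logits[mask] / self.temperature, dim=-1)
        soft_student = F.log_softmax(s_logits[mask] / self.temperature, dim=-1)
        
        kd_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean')
        kd_loss = kd_loss * (self.temperature ** 2)
        
        # Combined
        total_loss = 0.5 * mlm_loss + 0.5 * kd_loss
        
        return total_loss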

2. Transformer to CNN Distillation

class CrossArchitectureDistillation(nn.Module):
    """
    Distill between different architectures.
    
    Assumes both models expose extract_features(x) returning pooled
    (batch, dim) features of size student_dim / teacher_dim.
    """
    
    def __init__(self, teacher, student, student_dim, teacher_dim, temperature=2.0):
        super().__init__()
        self.teacher = teacher  # Transformer
        self.student = student  # CNN
        self.temperature = temperature
        # Learned projection from student feature space to teacher feature space
        self.projector = nn.Linear(student_dim, teacher_dim)
        
    def distill_cross_architecture(self, x, labels):
        """
        Match features across architectures.
        """
        # Get teacher features and logits (the teacher is fixed)
        with torch.no_grad():
            teacher_features = self.teacher.extract_features(x)
            teacher_logits = self.teacher(x)
            
        # Get student features
        student_features = self.student.extract_features(x)
        student_logits = self.student(x)
        
        # Feature matching with projection
        projected_student = self.projector(student_features)
        feature_loss = F.mse_loss(projected_student, teacher_features)
        
        # Logit matching ('batchmean' gives the true KL; scale by T²)
        T = self.temperature
        soft_teacher = F.softmax(teacher_logits / T, dim=-1)
        soft_student = F.log_softmax(student_logits / T, dim=-1)
        logit_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * T ** 2
        
        # Hard label loss
        hard_loss = F.cross_entropy(student_logits, labels)
        
        return feature_loss + logit_loss + hard_loss

Complete Training Framework

class DistillationTrainer:
    """
    Complete training framework for knowledge distillation.
    """
    
    def __init__(self, teacher, student, config):
        self.teacher = teacher
        self.student = student
        self.config = config
        
        # Freeze teacher
        for param in teacher.parameters():
            param.requires_grad = False
            
        self.optimizer = torch.optim.AdamW(
            student.parameters(), 
            lr=config.lr,
            weight_decay=config.weight_decay
        )
        
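    def train_epoch(self, dataloader):
        """
        Train for one epoch.
        """
        self.teacher.eval()  # disable dropout / batch-norm updates in the teacher
        self.student.train()
        device = next(self.student.parameters()).device
        total_loss = 0
        
        for batch in dataloader:
            x = batch['input'].to(device)
            labels = batch['label'].to(device)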
            
            # Forward through teacher
            with torch.no_grad():
                teacher_logits = self.teacher(x)
                
            # Forward through student
            student_logits = self.student(x)
            
            # Compute loss
            loss = self.compute_distillation_loss(
                student_logits, teacher_logits, labels
            )
            
            # Backward
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                self.student.parameters(), max_norm=1.0
            )
            self.optimizer.step()
            
            total_loss += loss.item()
            
        return total_loss / len(dataloader)
    
    def compute_distillation_loss(self, student_logits, teacher_logits, labels):
        """
        Compute combined loss.
        """
        T = self.config.temperature
        alpha = self.config.alpha
        
        # Soft loss
        soft_teacher = F.softmax(teacher_logits / T, dim=-1)
        soft_student = F.log_softmax(student_logits / T, dim=-1)
        soft_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean')
        soft_loss = soft_loss * (T ** 2)
        
        # Hard loss
        hard_loss = F.cross_entropy(student_logits, labels)
        
        # Combine
        return alpha * soft_loss + (1 - alpha) * hard_loss

Practical Applications

1. Edge Deployment

from types import SimpleNamespace

class EdgeDistiller:
    """
    Distill large models for edge devices.
    
    SmallModel is a placeholder for your own student architecture; the
    teacher is assumed to expose hidden_dim and num_layers attributes.
    """
    
    def __init__(self, large_model):
        self.teacher = large_model
        
    def create_student(self, target_params_ratio=0.1):
        """
        Create a student with roughly target_params_ratio of the
        teacher's parameters.
        """
        # Transformer-style parameter counts grow roughly as
        # num_layers * hidden_dim**2, so scaling both by s shrinks the
        # model by about s**3 (embeddings and heads are ignored here)
        scale = target_params_ratio ** (1 / 3)
        
        student = SmallModel(  # placeholder student class
            hidden_dim=max(1, int(self.teacher.hidden_dim * scale)),
            num_layers=max(1, round(self.teacher.num_layers * scale))
        )
        
        return student
    
    def distill_for_edge(self, data_loader, epochs=10):
        """
        Distill model for edge deployment.
        """
        student = self.create_student()
        
        # Example hyperparameters; DistillationTrainer reads all four fields
        config = SimpleNamespace(lr=1e-4, weight_decay=0.01, temperature=4.0, alpha=0.7)
        trainer = DistillationTrainer(self.teacher, student, config)
        
        for epoch in range(epochs):
            loss = trainer.train_epoch(data_loader)
            print(f"Epoch {epoch}: Loss = {loss:.4f}")
            
        return student
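How much a student shrinks when `hidden_dim` and `num_layers` are scaled can be estimated directly. For a stack of square layers, parameters grow roughly as num_layers × hidden_dim², so scaling both by s shrinks the model by about s³; halving both leaves roughly one eighth. This check uses plain `nn.Linear` stacks as a simplified stand-in for a transformer body (it ignores embeddings, attention structure, and normalization layers):

```python
import torch.nn as nn

def count_params(model):
    return sum(p.numel() for p in model.parameters())

def linear_stack(hidden_dim, num_layers):
    # Simplified stand-in for a transformer body: num_layers square Linear layers
    return nn.Sequential(*[nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers)])

teacher = linear_stack(hidden_dim=512, num_layers=12)
student = linear_stack(hidden_dim=256, num_layers=6)  # both dimensions halved
ratio = count_params(student) / count_params(teacher)
print(f"{ratio:.3f}")  # 0.125, i.e. roughly (1/2) ** 3

# Inverting the relation: per-dimension scale for a 10% parameter budget
print(f"{0.1 ** (1 / 3):.3f}")  # 0.464
```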

2. Model Ensemble Compression

class EnsembleDistiller:
    """
    Compress ensemble into single model.
    """
    
    def __init__(self, ensemble_models):
        self.teachers = ensemble_models
        
    def compress_ensemble(self, student_model, data_loader):
        """
        Compress ensemble of teachers into single student.
        """
        student_model.train()
        optimizer = torch.optim.Adam(student_model.parameters())
        
        for batch in data_loader:
            x = batch['input']
            labels = batch['label']
            
            # Average teacher predictions
            teacher_preds = []
            for teacher in self.teachers:
                with torch.no_grad():
                    pred = teacher(x)
                    teacher_preds.append(pred)
                    
            avg_teacher = torch.stack(teacher_preds).mean(dim=0)
            
            # Student prediction
            student_pred = student_model(x)
            
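            # Loss at temperature T: 'batchmean' gives the true KL, and T² keeps the gradient scale
            T = 3.0
            soft_teacher = F.softmax(avg_teacher / T, dim=-1)
            soft_student = F.log_softmax(student_pred / T, dim=-1)
            
            loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * T ** 2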
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
        return student_model

Best Practices

1. Temperature Selection

def select_temperature(teacher_model, data_sample):
    """
    Heuristic temperature choice based on teacher confidence.
    
    Thresholds are in nats and assume a modest number of classes;
    entropy over K classes ranges from 0 to log(K).
    """
    with torch.no_grad():
        logits = teacher_model(data_sample)
        
    # Compute probability distribution
    probs = F.softmax(logits, dim=-1)
    
    # Entropy measures how "spread out" the distribution is
    entropy = -(probs * torch.log(probs + 1e-10)).sum(dim=-1).mean()
    
    # Low entropy = confident, nearly one-hot teacher: a higher temperature
    # is needed to expose the small probabilities on non-target classes.
    # High entropy = already soft outputs: a lower temperature suffices.
    if entropy < 1.0:
        return 8.0  # Higher temperature
    elif entropy < 2.0:
        return 4.0  # Medium temperature
    else:
        return 2.0  # Lower temperature
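The entropy thresholds above are in nats and depend on the number of classes, because entropy over K classes ranges from 0 (one-hot) to log K (uniform). The sketch below computes the same mean-entropy statistic for two hand-made logit batches; dividing by log K, shown on the last line, is one way (an assumption, not part of the original heuristic) to make thresholds comparable across label sets.

```python
import math
import torch
import torch.nn.functional as F

def mean_entropy(logits):
    # Same statistic as select_temperature: mean entropy (in nats) of softmax(logits)
    probs = F.softmax(logits, dim=-1)
    return -(probs * torch.log(probs + 1e-10)).sum(dim=-1).mean().item()

num_classes = 10
confident = torch.tensor([[20.0] + [0.0] * (num_classes - 1)])  # nearly one-hot
uniform = torch.zeros(1, num_classes)  # all classes equally likely

h_conf = mean_entropy(confident)
h_unif = mean_entropy(uniform)
print(h_conf, h_unif, math.log(num_classes))  # h_unif equals log(10), about 2.3026
print(h_unif / math.log(num_classes))  # normalized entropy, about 1.0
```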

2. Layer Mapping

class LayerMapper:
    """
    Map teacher layers to student layers.
    """
    
    @staticmethod
    def optimal_mapping(teacher_num_layers, student_num_layers):
        """
        Map student layers to teacher layers with uniform spacing.
        
        A common heuristic rather than a provably optimal mapping.
        Returns list of (teacher_idx, student_idx) pairs.
        """
        if student_num_layers > teacher_num_layers:
            raise ValueError("student must not have more layers than teacher")
            
        if teacher_num_layers == student_num_layers:
            # Direct mapping
            return [(i, i) for i in range(teacher_num_layers)]
            
        # Uniform spacing: student layer i -> first teacher layer of block i
        step = teacher_num_layers / student_num_layers
        return [(int(i * step), i) for i in range(student_num_layers)]
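Uniform spacing from the start of each block is only one convention. Another common choice maps each student layer to the last teacher layer in its block, so the deepest student layer is matched to the deepest teacher layer. The sketch compares the two for a 12-layer teacher and a 6-layer student (0-indexed); `block_end_mapping` is an illustrative alternative, not part of `LayerMapper`.

```python
def uniform_mapping(teacher_layers, student_layers):
    # Same rule as LayerMapper.optimal_mapping: first teacher layer of each block
    step = teacher_layers / student_layers
    return [(int(i * step), i) for i in range(student_layers)]

def block_end_mapping(teacher_layers, student_layers):
    # Alternative: last teacher layer of each block (deepest layers aligned)
    step = teacher_layers / student_layers
    return [(int((i + 1) * step) - 1, i) for i in range(student_layers)]

print(uniform_mapping(12, 6))    # [(0, 0), (2, 1), (4, 2), (6, 3), (8, 4), (10, 5)]
print(block_end_mapping(12, 6))  # [(1, 0), (3, 1), (5, 2), (7, 3), (9, 4), (11, 5)]
```

With the uniform rule the final teacher layer is never used as a target, which matters if the last teacher layer carries the most task-specific features.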

3. Training Schedule

class DistillationSchedule:
    """
    Progressive training schedule for distillation.
    """
    
    @staticmethod
    def warmup_then_distill(config):
        """
        Start with hard labels, then add soft labels in steps.
        """
        # Epoch 0: hard labels only
        if config.current_epoch < 1:
            return {'alpha': 0.0}
            
        # Epochs 1-2: partial weight on soft labels
        elif config.current_epoch < 3:
            return {'alpha': 0.3}
            
        # Epoch 3 onward: full distillation weight
        else:
            return {'alpha': 0.7}
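The step schedule above jumps between fixed values. A smoother alternative, sketched here with hypothetical parameter names (`warmup_epochs`, `max_alpha`), ramps `alpha` linearly from 0 to a target value over a warm-up period and then holds it:

```python
def linear_alpha(current_epoch, warmup_epochs=3, max_alpha=0.7):
    """Linearly increase alpha from 0 to max_alpha over warmup_epochs, then hold."""
    if warmup_epochs <= 0:
        return max_alpha
    progress = min(current_epoch / warmup_epochs, 1.0)
    return max_alpha * progress

print([round(linear_alpha(e), 3) for e in range(5)])  # [0.0, 0.233, 0.467, 0.7, 0.7]
```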

Comparison

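| Method | When to use | Trade-offs |
|---|---|---|
| Response-based | Teacher and student share the same output space | Simplest; uses only final outputs |
| Feature-based | Intermediate representations matter, including across architectures | Needs adapters and a layer mapping |
| Relation-based | Relationships between samples matter | Pairwise similarity matrices cost O(B²) per batch |
| Self-distillation | No separate teacher is available | No teacher needed; the model supervises itself |
| Multi-teacher | Several strong models are available | Richer soft targets; one forward pass per teacher |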

Future Directions in 2026

Emerging Research

  1. Task-Aware Distillation: Customize distillation for specific tasks
  2. Continuous Distillation: Ongoing knowledge transfer
  3. Gradient-Based Selection: Which examples matter most
  4. Foundation Model Distillation: Compressing massive models


Conclusion

Knowledge distillation has become essential for deploying large AI models efficiently. By transferring not just the teacher's top prediction but its full output distribution (and, with feature-based methods, its intermediate representations), distillation lets students recover much of the teacher's performance with far fewer parameters.

The key insights (soft targets, temperature scaling, and feature matching) provide a robust toolkit for model compression. Whether you’re distilling BERT for mobile deployment or compressing a GPT model for inference, knowledge distillation offers a principled approach to efficiency without sacrificing too much capability.

As models continue to grow, distillation becomes increasingly critical. The future lies in intelligent, task-specific distillation strategies that maximize knowledge transfer while minimizing computational overhead.
