Skip to main content
⚡ Calmops

AI Model Compression: Quantization, Pruning, and Distillation

Introduction

Deploying large AI models requires significant computational resources. Model compression techniques reduce size and latency while maintaining accuracy. This guide covers three main approaches: Quantization, Pruning, and Distillation.

Understanding Model Compression

Why compress models:

  • Reduced Memory: Smaller model files
  • Faster Inference: Fewer computations
  • Lower Cost: Cheaper deployment
  • Edge Deployment: Run on limited devices
# Compression impact example: approximate on-disk size of a 7B-parameter
# model under each technique (parameter count × bytes per parameter).
# (Fix: the "×" signs were mojibake, "ร—", from a mis-decoded scrape.)
model_sizes = {
    "FP32 (baseline)": "7B params × 4 bytes = 28 GB",
    "INT8": "7B × 1 byte = 7 GB (4x smaller)",
    "INT4": "7B × 0.5 bytes = 3.5 GB (8x smaller)",
    "Pruned 50%": "3.5B params × 4 bytes = 14 GB",
    "Distilled": "3B params × 4 bytes = 12 GB"
}

Quantization

Quantization reduces precision of weights and activations.

Post-Training Quantization (PTQ)

import torch
from torch.quantization import quantize_dynamic

# Load the trained FP32 model and put it in inference mode.
model = torch.load("model.pt")
model.eval()

# Dynamic quantization: weights are stored as INT8 while activations are
# quantized on the fly at inference time, so no calibration data is needed.
quantized_model = quantize_dynamic(
    model,
    {torch.nn.Linear},  # restrict quantization to Linear layers
    dtype=torch.qint8,
)

# Persist the INT8 model for deployment.
torch.save(quantized_model, "model_int8.pt")

Static Quantization

import torch.quantization

# Prepare model for static quantization: eval mode plus a backend config
# ('fbgemm' targets x86 server CPUs).
model = Model()
model.eval()
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')

# Fuse conv+bn+relu into a single module so the sequence is quantized as one op.
model_fused = torch.quantization.fuse_modules(
    model,
    [['conv1', 'bn1', 'relu']]
)

# Insert observers that record activation ranges during calibration.
# (Bug fix: the original skipped prepare(), so convert() had no observer
# statistics to compute quantization parameters from.)
model_prepared = torch.quantization.prepare(model_fused)

# Calibration: run representative data through the prepared model so the
# observers see realistic activation distributions.
calibration_data = load_calibration_data()
model_prepared.eval()
with torch.no_grad():
    for data in calibration_data:
        model_prepared(data)

# Convert observed modules to their quantized counterparts.
model_quantized = torch.quantization.convert(model_prepared)

INT8 Quantization with ONNX

import onnx
from onnxruntime.quantization import quantize_dynamic, quantize_static
from onnxruntime.quantization import QuantType

# Dynamic quantization: weights stored as INT8, activations quantized at
# runtime. (Bug fix: the original passed
# weight_quantization_type=onnxruntime.quantization.QuantType.QInt8, but
# `onnxruntime` was never imported and the parameter is named `weight_type`.)
quantize_dynamic(
    "model.onnx",
    "model_int8.onnx",
    weight_type=QuantType.QInt8,
)

# Static quantization requires a CalibrationDataReader that yields
# representative model inputs — not a JSON file of precomputed parameters.
# `calibration_reader` must be an onnxruntime.quantization.CalibrationDataReader
# subclass instance defined by the caller.
quantize_static(
    "model.onnx",
    "model_int8_static.onnx",
    calibration_data_reader=calibration_reader,
)

GPTQ Quantization

from transformers import AutoModelForCausalLM, AutoTokenizer
from optimum.gptq import GPTQQuantizer

# Load the base model and its tokenizer.
# (Bug fix: `tokenizer` was used below but never created.)
model_id = "meta-llama/Llama-2-7b-hf"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Configure GPTQ: 4-bit weights, quantization groups of 128 columns.
# desc_act=False trades a little accuracy for faster inference.
quantizer = GPTQQuantizer(
    bits=4,
    group_size=128,
    desc_act=False
)

# Quantize — the tokenizer is used to build the calibration samples.
quantized_model, quantization_config = quantizer.quantize_model(
    model,
    tokenizer
)

# Save the quantized checkpoint.
quantized_model.save_pretrained("llama-2-7b-4bit")

AWQ Quantization

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

# Load model for AWQ quantization.
# (Bug fix: `tokenizer` was used below but never created.)
model_id = "meta-llama/Llama-2-7b-hf"
model = AutoAWQForCausalLM.from_pretrained(
    model_id,
    quantize_config={"w_bit": 4, "q_group_size": 128}
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Search for per-channel scaling factors on calibration samples.
# NOTE(review): recent autoawq versions pass the config as
# `quant_config=` to quantize() — verify against the installed version.
model.quantize(tokenizer, quant_samples=512)

# Save the AWQ checkpoint.
model.save_quantized("llama-2-7b-awq")

GGML/llama.cpp Quantization

# Using llama.cpp for quantization
# Install: pip install llama-cpp-python

from llama_cpp import Llama

# Load a GGUF model for inference.
# (Bug fix: the original passed n_quantize=8192, which is not a Llama
# constructor parameter. llama.cpp models are quantized offline with the
# `llama-quantize` tool, e.g. producing a *-q5_k_m.gguf file, and then
# loaded here as-is.)
llm = Llama(
    model_path="models/llama-7b-f16.gguf",
    n_ctx=4096,        # context window size
    n_gpu_layers=35    # number of layers offloaded to the GPU
)

# Or load a pre-quantized (Q4_K_M) model directly
llm = Llama(model_path="models/llama-7b-q4_k_m.gguf")

Pruning

Pruning removes unnecessary weights or neurons.

Weight Pruning

import torch.nn.utils.prune as prune

# L1 unstructured pruning: zero the 50% of layer1's weights with the
# smallest absolute value.
# (Bug fix: prune.l1_unstructured takes the *module* plus the parameter
# name; the original passed model.layer1.weight, a tensor.)
prune.l1_unstructured(
    model.layer1,
    name="weight",
    amount=0.5  # Remove 50% of weights
)

# Global magnitude pruning: rank weights across all listed layers jointly
# and remove the bottom 20% overall.
prune.global_unstructured(
    parameters_to_prune=[
        (model.conv1, 'weight'),
        (model.conv2, 'weight'),
    ],
    pruning_method=prune.L1Unstructured,
    amount=0.2
)

# Make pruning permanent: fold the mask into the weight tensor and remove
# the reparametrization hooks.
prune.remove(model.conv1, 'weight')

Structured Pruning

# Prune entire neurons
def prune_neurons(model, sparsity=0.5):
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            # Calculate L2 norm of each neuron
            weights = module.weight.data
            norms = torch.norm(weights, dim=1)
            
            # Find threshold
            threshold = torch.kthvalue(
                norms, 
                int(norms.shape[0] * sparsity)
            )[0]
            
            # Create mask
            mask = norms > threshold
            
            # Apply pruning
            module.weight.data *= mask.unsqueeze(1)
            
    return model

# Prune attention heads
def prune_attention_heads(attention, num_heads_to_keep):
    head_size = attention.head_dim
    num_heads = attention.num_heads
    
    # Calculate importance scores
    scores = torch.norm(attention.query.weight, dim=(0, 2))
    
    # Keep top heads
    _, keep_indices = torch.topk(scores, num_heads_to_keep)
    
    return keep_indices

Lottery Ticket Hypothesis

# Find winning tickets
def find_lottery_ticket(model, train_loader):
    # Train once to get final weights
    train_model(model, train_loader)
    
    # Get pruning mask (top 20% weights)
    mask = get_top_k_percent_mask(model, percent=20)
    
    # Reset to initial weights
    model.reset_weights()
    
    # Apply mask
    apply_mask(model, mask)
    
    # Retrain - should perform as well as full model
    train_model(model, train_loader)
    
    return model, mask

Knowledge Distillation

Train smaller model to mimic larger model.

Basic Distillation

import torch.nn.functional as F

class DistillationLoss(nn.Module):
    """Knowledge-distillation loss.

    Blends a temperature-softened KL-divergence term against the teacher's
    distribution with a standard cross-entropy term against the true labels.
    """

    def __init__(self, temperature=4.0, alpha=0.5):
        super().__init__()
        # Higher temperature -> softer (more informative) teacher distribution.
        self.temperature = temperature
        # alpha weights the soft (teacher) term; 1 - alpha weights the hard term.
        self.alpha = alpha

    def forward(self, student_logits, teacher_logits, labels):
        t = self.temperature

        # KL divergence between the temperature-softened student and teacher
        # distributions; the T^2 factor keeps gradient magnitudes comparable
        # across temperature settings.
        soft_loss = F.kl_div(
            F.log_softmax(student_logits / t, dim=-1),
            F.softmax(teacher_logits / t, dim=-1),
            reduction='batchmean'
        ) * (t ** 2)

        # Plain supervised loss on the ground-truth labels.
        hard_loss = F.cross_entropy(student_logits, labels)

        return self.alpha * soft_loss + (1 - self.alpha) * hard_loss

# Training loop
def train_with_distillation(student, teacher, train_loader):
    teacher.eval()
    student.train()
    
    criterion = DistillationLoss(temperature=4.0, alpha=0.7)
    optimizer = torch.optim.Adam(student.parameters())
    
    for batch in train_loader:
        inputs, labels = batch
        
        # Get teacher predictions (no grad)
        with torch.no_grad():
            teacher_logits = teacher(inputs)
        
        # Student predictions
        student_logits = student(inputs)
        
        # Distillation loss
        loss = criterion(student_logits, teacher_logits, labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

Feature Distillation

# Distill intermediate features
class FeatureDistillationLoss(nn.Module):
    def __init__(self, feature_dim):
        super().__init__()
        self.projection = nn.Linear(feature_dim, feature_dim)
        
    def forward(self, student_features, teacher_features):
        # Project to same dimension
        student_proj = self.projection(student_features)
        
        # MSE loss on features
        return F.mse_loss(student_proj, teacher_features)

# Use intermediate layers
class DistillationTrainer:
    def __init__(self, student, teacher):
        self.student = student
        self.teacher = teacher
        
        # Feature alignment layers
        self.feature_loss = FeatureDistillationLoss(768)
        
    def train_step(self, batch):
        inputs, labels = batch
        
        # Get intermediate features
        student_features = self.student.get_features(inputs)
        teacher_features = self.teacher.get_features(inputs)
        
        # Feature distillation loss
        feat_loss = self.feature_loss(student_features, teacher_logits)
        
        # Output distillation loss
        output_loss = self.distillation_loss(
            self.student(inputs), 
            self.teacher(inputs)
        )
        
        return feat_loss + output_loss

Self-Distillation

# Self-distillation: train model to mimic itself
def self_distillation(model, train_loader, num_stages=3):
    """Progressively train with intermediate checkpoints."""
    
    # Initial training
    model = train(model, train_loader)
    checkpoints = [copy(model)]
    
    for stage in range(1, num_stages):
        # Train new model to mimic previous stage
        prev_model = checkpoints[-1]
        
        # Use previous as teacher
        model = create_model_like(model)  # New initialization
        model = train_with_distillation(
            model, 
            prev_model, 
            train_loader
        )
        
        checkpoints.append(model)
    
    return checkpoints[-1]

Combined Compression

# Complete compression pipeline
def compress_model(model, train_loader):
    # Step 1: Prune
    model = prune_model(model, sparsity=0.5)
    
    # Step 2: Quantize
    model = quantize_model(model)
    
    # Step 3: Distill to smaller architecture
    student = create_student_model()
    student = train_with_distillation(student, model, train_loader)
    
    # Step 4: Quantize student
    student = quantize_model(student)
    
    return student

Comparison

Technique     | Size Reduction | Accuracy Impact | Speed Improvement
--------------|----------------|-----------------|------------------
FP16          | 2x             | Minimal         | ~2x
INT8          | 4x             | 1-2% loss       | ~4x
INT4          | 8x             | 3-5% loss       | ~8x
Pruning (50%) | 2x             | 1-3% loss       | ~2x
Distillation  | 2-3x           | Varies          | ~2-3x

When to Use Each

Quantization

  • Quick deployment
  • Memory constraints
  • Latency requirements

Pruning

  • Known sparse patterns
  • Custom architectures
  • Hardware that supports sparsity

Distillation

  • Different architecture
  • Maintain accuracy
  • Ensemble compression

Bad Practices

Bad Practice 1: No Accuracy Validation

# Bad: Blindly quantize
quantized = quantize(model, bits=4)
# Lost all accuracy!

# Good: Validate accuracy
# Compare against a minimum acceptable accuracy before shipping the
# quantized model; `threshold` must be defined by the caller.
acc = evaluate(quantized, test_loader)
if acc < threshold:
    print(f"Accuracy too low: {acc}")

Bad Practice 2: Ignoring Calibration

# Bad: No calibration for static quantization
quantized = quantize_static(model, bits=8)  # Wrong!

# Good: Use representative data
# Static quantization needs observed activation ranges — run a sample of
# realistic inputs through the model before converting.
calibrate(model, calibration_data)

Bad Practice 3: Over-Pruning

# Bad: Too much pruning
pruned = prune(model, amount=0.95)  # Destroyed!

# Good: increase sparsity gradually and keep the last model that still
# passes validation. (Bug fix: the original broke out of the loop with
# `pruned` already bound to the model that failed validation.)
pruned = model
for amount in [0.3, 0.5, 0.7]:
    candidate = prune(model, amount=amount)
    if validate(candidate) < threshold:
        break
    pruned = candidate

Good Practices

Quantization Best Practices

# Good: pick quantization aggressiveness by model size (parameter count).
# (Bug fix: `10B` is not valid Python syntax — use an explicit number.)
if model_size > 10e9:  # INT4 (GPTQ) for very large models
    quantize(model, bits=4, method="gptq")
else:  # INT8 dynamic quantization for smaller models
    quantize(model, bits=8, method="dynamic")

Validation Pipeline

# Good: Comprehensive validation
def validate_compression(model, test_loader):
    metrics = {}
    
    # Accuracy
    metrics['accuracy'] = evaluate_accuracy(model, test_loader)
    
    # Latency
    metrics['latency'] = measure_latency(model)
    
    # Memory
    metrics['memory'] = measure_memory(model)
    
    # Check thresholds
    if metrics['accuracy'] < 0.95 * baseline:
        return False, metrics
    
    return True, metrics

Gradual Compression

# Good: Progressive compression
def compress_gradually(model):
    results = []
    
    for method, params in [
        ("fp16", {}),
        ("int8", {}),
        ("prune", {"amount": 0.3}),
        ("prune", {"amount": 0.5}),
        ("int4", {})
    ]:
        compressed = apply_compression(model, method, params)
        accuracy = evaluate(compressed)
        results.append((method, params, accuracy))
        
        # Check if acceptable
        if accuracy > target_accuracy:
            return compressed
    
    return model

External Resources

Comments