Introduction
Real-time ML inference requires sub-100ms latency at scale. This guide covers techniques for achieving low-latency inference in production, from model compression to edge deployment.
Latency Breakdown
Where Time Goes in Inference
Total Latency = Input Processing + Model Inference + Output Processing
Example (image classification):
Input processing (resize): 10ms
Data loading/transfer: 5ms
Model inference: 50ms
Post-processing: 10ms
Network roundtrip (if remote): 20ms
────────────────────────────
Total: 95ms
Latency Requirements by Use Case
Application Target Latency Tolerance
────────────────────────────────────────────────────
Real-time recommendation < 50ms Strict
Search ranking < 100ms Moderate
Voice assistant < 100ms Strict
Video analysis < 500ms Moderate
Autonomous vehicle < 10ms Extremely strict
Mobile app < 200ms Moderate
Model Optimization Techniques
Quantization
import tensorflow as tf
# Quantization-aware training (QAT)
def create_quantized_model():
    """Build a small model prepared for quantization-aware training (QAT).

    FIX: the original called ``tf.keras.quantization.quantize_layer``,
    which does not exist in the TensorFlow API. The supported QAT entry
    point is ``tfmot.quantization.keras.quantize_model`` from the
    tensorflow_model_optimization package (already used elsewhere in
    this guide for pruning).

    Returns:
        A Keras model whose layers carry fake-quant ops, so training
        simulates INT8 quantization and the weights adapt to it.
    """
    import tensorflow_model_optimization as tfmot

    base_model = tf.keras.Sequential([
        tf.keras.layers.Dense(10, activation='relu'),
        tf.keras.layers.Dense(1),
    ])
    # quantize_model wraps every supported layer with quantize emulation.
    return tfmot.quantization.keras.quantize_model(base_model)
# Post-training quantization
def quantize_trained_model(model, representative_dataset=None):
    """Post-training INT8 quantization via the TFLite converter.

    Args:
        model: A trained tf.keras model.
        representative_dataset: Optional generator yielding sample input
            batches. Full-integer conversion calibrates activation ranges
            from it; with only TFLITE_BUILTINS_INT8 ops allowed,
            ``convert()`` fails without calibration data, so pass this
            for real INT8 deployment. Default None keeps the original
            call signature working.

    Returns:
        The quantized TFLite model as a bytes flatbuffer.
    """
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS_INT8
    ]
    if representative_dataset is not None:
        # Calibration data for full-integer activation quantization.
        converter.representative_dataset = representative_dataset
    return converter.convert()
# Impact: 4x faster, 4x smaller, <1% accuracy loss
Pruning
import tensorflow_model_optimization as tfmot
def prune_model(model, sparsity=0.5):
    """Apply magnitude pruning to *model* and fine-tune it.

    Fixes vs. the original snippet:
    - the pruning schedule referenced a garbled/nonexistent class and
      ignored the ``sparsity`` argument; use ``ConstantSparsity``.
    - ``block_size`` must be a 2-tuple ``(rows, cols)`` in the tfmot API.
    - training with pruning wrappers requires the ``UpdatePruningStep``
      callback, otherwise ``fit`` raises.

    Args:
        model: tf.keras model to prune.
        sparsity: Target fraction of weights to zero out (0.0-1.0).

    Returns:
        The fine-tuned model with pruning wrappers stripped.

    NOTE(review): reads ``train_data``/``train_labels``/``val_data``/
    ``val_labels`` from module scope — confirm they exist before calling.
    """
    pruning_params = {
        # Hold the requested sparsity constant from step 0 onward.
        'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(
            target_sparsity=sparsity, begin_step=0
        ),
        'block_size': (1, 2),
        'block_pooling_type': 'AVG',
    }
    pruned_model = tfmot.sparsity.keras.prune_low_magnitude(
        model,
        **pruning_params
    )
    pruned_model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    pruned_model.fit(
        train_data, train_labels,
        epochs=10,
        validation_data=(val_data, val_labels),
        # Required: advances the pruning schedule each training step.
        callbacks=[tfmot.sparsity.keras.UpdatePruningStep()],
    )
    # Remove pruning wrappers so the result is a plain, exportable model.
    return tfmot.sparsity.keras.strip_pruning(pruned_model)
# Impact: 10x faster, 10x smaller, 1-2% accuracy loss
Knowledge Distillation
import tensorflow as tf
def distill_model(teacher_model, student_model, x_train, y_train, temperature=3):
    """Train a lightweight student to mimic a larger teacher (distillation).

    Fixes vs. the original snippet:
    - removed the dead ``teacher_model.predict`` call whose result was
      never used (the teacher is re-run inside ``call`` anyway);
    - ``softmax_cross_entropy_with_logits`` takes *labels* (probabilities)
      and *logits* — the original passed probabilities for both;
    - student logits must also be softened by ``temperature``, and the
      soft loss scaled by ``temperature**2`` to keep gradient magnitudes
      comparable (Hinton et al., 2015).

    Args:
        teacher_model: Large, trained model (frozen during distillation).
        student_model: Small model to be trained.
        x_train, y_train: Training inputs and sparse integer labels.
        temperature: Softmax temperature for the soft targets.

    Returns:
        The trained student model.
    """
    class DistillationModel(tf.keras.Model):
        def __init__(self, teacher, student):
            super().__init__()
            self.teacher = teacher
            self.student = student

        def call(self, inputs, training=False):
            # Teacher always runs in inference mode; only the student trains.
            teacher_logits = self.teacher(inputs, training=False)
            student_logits = self.student(inputs, training=training)
            return student_logits, teacher_logits

    model = DistillationModel(teacher_model, student_model)

    def distillation_loss(y_true, y_pred):
        student_logits, teacher_logits = y_pred
        # Soft targets: softened teacher distribution vs. softened student logits.
        teacher_probs = tf.nn.softmax(teacher_logits / temperature)
        kl_loss = tf.nn.softmax_cross_entropy_with_logits(
            labels=teacher_probs,
            logits=student_logits / temperature,
        ) * (temperature ** 2)
        # Hard targets: standard cross-entropy against the true labels.
        hard_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=y_true,
            logits=student_logits,
        )
        return 0.7 * kl_loss + 0.3 * hard_loss

    model.compile(optimizer='adam', loss=distillation_loss)
    model.fit(x_train, y_train, epochs=10, batch_size=32)
    return student_model
# Impact: 50-100x faster, 1-3% accuracy loss
Batching and Inference Serving
Dynamic Batching
import queue
import threading
import time
class DynamicBatcher:
    """Collect concurrent predict() calls into batches for higher throughput.

    Callers block in ``predict`` until a background thread has gathered a
    batch (up to ``max_batch_size`` items, or whatever arrived within
    ``timeout_ms``) and run the model on it.

    FIX: if ``model.predict`` raised, the original never set the waiters'
    events, so every caller in that batch hung forever. Failures are now
    captured and re-raised in each caller's thread.
    """

    def __init__(self, model, max_batch_size=32, timeout_ms=100):
        self.model = model
        self.max_batch_size = max_batch_size
        self.timeout_ms = timeout_ms
        self.batch_queue = queue.Queue()
        # Daemon thread so the batcher never blocks interpreter shutdown.
        self.processor_thread = threading.Thread(
            target=self._process_batches,
            daemon=True,
        )
        self.processor_thread.start()

    def predict(self, data):
        """Enqueue one input and block until its prediction (or error) is ready."""
        done = threading.Event()
        request = {'data': data, 'future': done, 'result': None, 'error': None}
        self.batch_queue.put(request)
        done.wait()
        if request['error'] is not None:
            # Surface model failures to the caller instead of hanging.
            raise request['error']
        return request['result']

    def _process_batches(self):
        while True:
            batch_requests = []
            start_time = time.time()
            # Gather up to max_batch_size requests within the timeout window.
            while len(batch_requests) < self.max_batch_size:
                timeout = self.timeout_ms / 1000.0 - (time.time() - start_time)
                if timeout <= 0:
                    break
                try:
                    batch_requests.append(self.batch_queue.get(timeout=timeout))
                except queue.Empty:
                    break
            if not batch_requests:
                continue
            batch_data = [r['data'] for r in batch_requests]
            try:
                batch_predictions = self.model.predict(batch_data)
            except Exception as exc:
                # Wake every waiter with the error rather than leaving them blocked.
                for request in batch_requests:
                    request['error'] = exc
                    request['future'].set()
                continue
            for request, prediction in zip(batch_requests, batch_predictions):
                request['result'] = prediction
                request['future'].set()
# Performance: 10x throughput improvement
Continuous Batching (LLMs)
class TokenBatcher:
    """Continuous (iteration-level) batching for autoregressive LLM decoding.

    Each ``process_batch`` call advances every unfinished request by one
    token, so new requests can join the batch between decode steps.

    FIX: ``max_tokens`` was a parameter of ``add_request`` but referenced
    inside ``process_batch``, where it was an undefined name (NameError).
    It is now stored per request.
    """

    def __init__(self, model, batch_size=32):
        self.model = model
        self.batch_size = batch_size
        self.requests = {}  # req_id -> per-request decoding state
        self.token_queue = queue.PriorityQueue()

    def add_request(self, req_id, input_ids, max_tokens=100):
        """Register a new generation request with its own token budget."""
        self.requests[req_id] = {
            'input_ids': input_ids,
            'generated_ids': input_ids.copy(),
            'max_tokens': max_tokens,  # per-request budget (was a NameError)
            'finished': False,
        }

    def process_batch(self):
        """Run one decode step over all unfinished requests."""
        active_reqs = [
            (rid, req) for rid, req in self.requests.items()
            if not req['finished']
        ]
        if not active_reqs:
            return
        # Variable-length sequences; the model is assumed to handle padding.
        batch = [req['generated_ids'] for _, req in active_reqs]
        next_tokens = self.model.generate_batch(batch)
        for (rid, _), next_token in zip(active_reqs, next_tokens):
            req = self.requests[rid]
            req['generated_ids'].append(next_token)
            if len(req['generated_ids']) >= req['max_tokens']:
                req['finished'] = True
Hardware Acceleration
GPU Optimization
import torch
def optimize_for_gpu(model):
    """Apply standard GPU inference speed-ups to a PyTorch model.

    Casts weights to FP16, moves the model to CUDA, enables cuDNN
    autotuning, and compiles with ``torch.compile`` (PyTorch 2.0+).

    Returns the optimized model.
    """
    # Let cuDNN benchmark kernels and pick the fastest algorithm
    # (best when input shapes are fixed).
    torch.backends.cudnn.benchmark = True
    fp16_model = model.half().cuda()
    # "reduce-overhead" targets small-batch, latency-sensitive inference.
    return torch.compile(fp16_model, mode="reduce-overhead")
# Impact: 2x faster on modern GPUs
TensorRT Optimization (NVIDIA)
# Convert model to TensorRT
# Offline build step: compile an ONNX export into a serialized TensorRT
# engine. --fp16 enables half-precision kernels; --workspace is the
# builder's scratch-memory budget in MB; --avgRuns averages the reported
# timings over 10 runs.
trtexec --onnx=model.onnx \
    --saveEngine=model.engine \
    --fp16 \
    --workspace=4096 \
    --avgRuns=10
# Use in inference
import tensorrt as trt
# NOTE(review): load_engine is not defined in this snippet — presumably a
# helper that deserializes model.engine via trt.Runtime; confirm.
engine = load_engine("model.engine")
context = engine.create_execution_context()
# Run inference (much faster)
# execute_v2 runs synchronously; input_data must already be bound to
# device buffers in binding order — confirm at the call site.
output = context.execute_v2([input_data])
Edge Deployment
import tensorflow as tf
# Convert to TFLite for mobile/edge
def create_tflight_model(model, representative_dataset=None):
    """Convert a Keras model to a quantized TFLite flatbuffer for edge devices.

    (Function name kept as-is for caller compatibility; "tflight" is a
    typo for "tflite".)

    Args:
        model: Trained tf.keras model.
        representative_dataset: Optional generator yielding sample input
            batches. Full-integer conversion needs it to calibrate
            activation ranges; with only TFLITE_BUILTINS_INT8 ops allowed,
            ``convert()`` fails without it. Default None preserves the
            original call signature.

    Returns:
        The converted TFLite model as bytes.
    """
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    # Quantization for edge: INT8 weights and activations.
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS_INT8
    ]
    if representative_dataset is not None:
        converter.representative_dataset = representative_dataset
    return converter.convert()
# Edge runtime
# NOTE(review): `tensorflow_lite_support` does not normally expose an
# `Interpreter` at the package root — the usual entry points are
# `tf.lite.Interpreter` or `tflite_runtime.interpreter.Interpreter`;
# verify this import against the installed package.
import tensorflow_lite_support as tf_lite
interpreter = tf_lite.Interpreter(model_path="model.tflite")
# Buffers must be allocated before any invoke().
interpreter.allocate_tensors()
# Fast inference (even on ARM processors)
# get_*_details() return dicts with 'index', 'shape', 'dtype' per tensor.
input_tensor = interpreter.get_input_details()[0]
output_tensor = interpreter.get_output_details()[0]
# NOTE(review): assumes input_data already matches the model's expected
# shape and dtype — confirm at the call site.
interpreter.set_tensor(input_tensor['index'], input_data)
interpreter.invoke()
result = interpreter.get_tensor(output_tensor['index'])
Caching and Prefetching
Response Caching
import hashlib
from functools import lru_cache
class PredictionCache:
    """In-memory TTL cache for model predictions, keyed by a hash of the input.

    Entries expire ``ttl_seconds`` after insertion (dropped lazily on read),
    and the single oldest entry is evicted whenever the cache grows past
    ``max_size``.
    """

    def __init__(self, max_size=10000, ttl_seconds=3600):
        self.cache = {}  # key -> (result, insertion timestamp)
        self.max_size = max_size
        self.ttl_seconds = ttl_seconds

    def get_hash(self, data):
        """Generate cache key"""
        # NOTE: keys on str(data) — inputs whose repr is unstable across
        # runs (e.g. objects with default __repr__) will miss or collide.
        return hashlib.md5(str(data).encode()).hexdigest()

    def get(self, data):
        """Return the cached result for *data*, or None if absent/expired."""
        key = self.get_hash(data)
        entry = self.cache.get(key)
        if entry is None:
            return None
        value, stored_at = entry
        if time.time() - stored_at < self.ttl_seconds:
            return value
        # Expired: drop lazily on read.
        del self.cache[key]
        return None

    def set(self, data, result):
        """Store *result* for *data*, evicting the oldest entry if full."""
        self.cache[self.get_hash(data)] = (result, time.time())
        if len(self.cache) > self.max_size:
            # O(n) scan for the oldest timestamp — fine at these sizes.
            oldest = min(self.cache, key=lambda k: self.cache[k][1])
            del self.cache[oldest]
# Usage
cache = PredictionCache(max_size=100000)

def predict(data):
    """Serve a prediction, consulting the response cache first."""
    cached_result = cache.get(data)
    if cached_result is not None:
        return cached_result
    # Cache miss: run real inference and remember the answer.
    prediction = model.predict(data)
    cache.set(data, prediction)
    return prediction
# Impact: 10-100x faster for repeated queries
Benchmark Results
Real-World Performance
Model/Config Latency Throughput Memory
──────────────────────────────────────────────────────────
ResNet-50 (FP32) 50ms 20 req/s 500MB
ResNet-50 (FP16) 25ms 40 req/s 300MB
ResNet-50 (INT8) 12ms 80 req/s 150MB
ResNet-50 (INT8 + Batch4) 15ms 320 req/s 150MB
ResNet-50 (TFLite mobile) 100ms 10 req/s 50MB
Savings achieved:
├─ FP32 → INT8: 75% latency, 75% memory reduction
└─ INT8 → Batched: 4x throughput improvement
Monitoring Real-Time Performance
from prometheus_client import Histogram, Counter

# Metrics
inference_latency = Histogram(
    'inference_latency_ms',
    'Inference latency in milliseconds',
    buckets=[10, 25, 50, 100, 200]
)
cache_hits = Counter('cache_hits', 'Cache hits')
cache_misses = Counter('cache_misses', 'Cache misses')

# Usage
# FIX: Histogram.time() observes durations in SECONDS, but this metric's
# name and buckets are in milliseconds — every request would land in the
# lowest bucket. Measure manually and record milliseconds.
start = time.perf_counter()
result = model.predict(data)
inference_latency.observe((time.perf_counter() - start) * 1000)
Glossary
- Quantization: Reducing bit precision (FP32 → INT8)
- Pruning: Removing unimportant weights
- Distillation: Teaching smaller model from larger
- Batching: Processing multiple requests together
- Latency: Time from input to output
Comments