Introduction
Edge AI brings machine learning capabilities directly to edge devices, enabling real-time inference without cloud connectivity. This approach reduces latency, preserves privacy, lowers bandwidth costs, and enables new use cases in IoT, mobile, and embedded systems. This guide covers techniques for deploying ML models on edge devices and the frameworks that make it possible.
Why Edge AI?
Benefits of Edge Deployment
- Low Latency: Real-time inference (ms vs. network round-trip seconds)
- Privacy: Data stays on device
- Reliability: Works without internet
- Cost: Reduces cloud computation and bandwidth
- Battery Efficiency: Optimized for mobile
Edge AI vs Cloud AI
| Aspect | Cloud AI | Edge AI |
|---|---|---|
| Latency | 100-500ms | <10ms |
| Bandwidth | High | Low |
| Privacy | Data sent to cloud | Data stays local |
| Reliability | Internet required | Works offline |
| Cost | Compute + transfer | One-time device cost |
| Model Size | Unlimited | Limited by device |
Model Optimization Techniques
Model Quantization
import torch
import torch.quantization
class QuantizedModel:
    """Helpers for post-training quantization of PyTorch models."""

    @staticmethod
    def dynamic_quantize(model):
        """Quantize Linear/LSTM weights to int8; activations stay float.

        Weights are converted ahead of time while activation scales are
        computed on the fly at inference — no calibration data needed.
        """
        target_layers = {torch.nn.Linear, torch.nn.LSTM}
        return torch.quantization.quantize_dynamic(
            model, target_layers, dtype=torch.qint8
        )

    @staticmethod
    def static_quantize(model, calibration_data):
        """Quantize weights and activations using a calibration pass.

        Observers inserted by `prepare` record activation ranges while the
        calibration batches run; `convert` then bakes them into int8 ops.
        Note: `prepare` mutates `model` in place.
        """
        model.eval()
        model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
        torch.quantization.prepare(model, inplace=True)
        with torch.no_grad():
            for batch in calibration_data:
                model(batch)  # observers record activation statistics
        return torch.quantization.convert(model, inplace=False)
class QuantizationAwareTraining:
    """Insert fake-quantization modules so training sees int8 rounding."""

    @staticmethod
    def apply_qat(model):
        """Prepare `model` for quantization-aware training (in place).

        After this call, run the normal training loop; forward passes
        simulate int8 quantization so the weights adapt to it.
        """
        model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
        torch.quantization.prepare_qat(model, inplace=True)
        return model

    @staticmethod
    def convert_for_inference(model):
        """Swap fake-quant modules for real int8 kernels after training."""
        model.eval()
        converted = torch.quantization.convert(model, inplace=False)
        return converted
def compare_model_sizes(original_model, quantized_model):
    """Report serialized sizes (in MB) of a model before/after quantization.

    Serializes both models in memory with `pickle.dumps` instead of writing
    throwaway files — the previous version leaked two open file handles per
    model and left `original.pkl`/`quantized.pkl` behind on disk.

    Args:
        original_model: any picklable object (e.g. a float32 model).
        quantized_model: its quantized counterpart.

    Returns:
        (original_mb, quantized_mb) so callers can use the numbers
        programmatically in addition to the printed summary.
    """
    import pickle

    mb = 1024 * 1024
    original_size = len(pickle.dumps(original_model)) / mb
    quantized_size = len(pickle.dumps(quantized_model)) / mb

    print(f"Original: {original_size:.2f} MB")
    print(f"Quantized: {quantized_size:.2f} MB")
    if original_size > 0:  # guard divide-by-zero for degenerate inputs
        print(f"Reduction: {(1 - quantized_size/original_size) * 100:.1f}%")
    return original_size, quantized_size
# TensorFlow Lite quantization example
"""
import tensorflow as tf
def representative_dataset():
for data in tf.data.Dataset.from_tensor_slices(train_images).batch(1).take(100):
yield [tf.cast(data, tf.float32)]
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
quantized_tflite_model = converter.convert()
"""
Pruning
import torch
import torch.nn.utils.prune as prune
class ModelPruner:
    """Weight-pruning utilities built on torch.nn.utils.prune."""

    @staticmethod
    def magnitude_prune(model, amount=0.5):
        """L1 unstructured pruning: zero the smallest-magnitude weights.

        Args:
            model: network whose Linear/Conv2d layers get pruned in place.
            amount: fraction of each layer's weights to remove.
        """
        for module in model.modules():
            if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
                prune.l1_unstructured(module, name='weight', amount=amount)
        return model

    @staticmethod
    def structured_prune(model, amount=0.3):
        """Prune whole output channels (dim=0) of Conv2d layers by L2 norm."""
        for module in model.modules():
            if isinstance(module, torch.nn.Conv2d):
                prune.ln_structured(
                    module, name='weight', amount=amount, n=2, dim=0
                )
        return model

    @staticmethod
    def iterative_prune(model, train_fn, data_loader,
                        initial_amount=0.2, final_amount=0.7, steps=10):
        """Gradually prune while fine-tuning between steps.

        Bug fix: the schedule previously used ``step / steps``, which
        topped out at ``initial + (final - initial) * (steps-1)/steps`` and
        never actually applied ``final_amount``; ``(step + 1) / steps``
        reaches it exactly on the last step.
        """
        for step in range(steps):
            amount = initial_amount + \
                (final_amount - initial_amount) * (step + 1) / steps
            ModelPruner.magnitude_prune(model, amount)
            train_fn(model, data_loader)  # fine-tune to recover accuracy
            print(f"Step {step + 1}/{steps}, Amount: {amount:.1%}")
        return model
def remove_pruning_reparameterization(model):
    """Make pruning permanent: fold weight_orig * mask into a plain weight.

    Until this runs, pruned modules keep both the original tensor
    (`weight_orig`) and the pruning mask; deployment wants a single dense
    tensor with the zeros baked in.
    """
    pruned_modules = [m for m in model.modules() if hasattr(m, 'weight_orig')]
    for module in pruned_modules:
        prune.remove(module, 'weight')
    return model
Knowledge Distillation
import torch
import torch.nn as nn
import torch.nn.functional as F
class DistillationLoss(nn.Module):
    """Weighted blend of soft (teacher) and hard (label) losses.

    loss = alpha * T^2 * KL(student_T || teacher_T)
         + (1 - alpha) * cross_entropy(student, labels)
    where _T denotes logits softened by temperature T.
    """

    def __init__(self, temperature=3.0, alpha=0.5):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha

    def forward(self, student_logits, teacher_logits, labels):
        T = self.temperature
        # Distillation term: KL between temperature-softened distributions,
        # scaled by T^2 to keep gradient magnitudes comparable across T.
        kl = F.kl_div(
            F.log_softmax(student_logits / T, dim=-1),
            F.softmax(teacher_logits / T, dim=-1),
            reduction='batchmean',
        )
        distill_term = kl * (T ** 2)
        # Supervised term against the ground-truth labels.
        ce_term = F.cross_entropy(student_logits, labels)
        return self.alpha * distill_term + (1 - self.alpha) * ce_term
class Distiller:
    """Knowledge-distillation trainer: a student mimics a frozen teacher."""

    def __init__(self, teacher_model, student_model,
                 temperature=3.0, alpha=0.5):
        self.teacher = teacher_model
        # Fix: put the teacher in eval mode so dropout/batch-norm layers
        # produce deterministic soft targets; the teacher is never trained.
        self.teacher.eval()
        self.student = student_model
        self.criterion = DistillationLoss(temperature, alpha)

    def train_step(self, data, labels):
        """One forward/backward pass.

        Note: gradients are accumulated only — the caller is responsible
        for `optimizer.zero_grad()` before and `optimizer.step()` after
        (as `distill` below does).
        """
        with torch.no_grad():  # teacher only provides targets
            teacher_logits = self.teacher(data)
        student_logits = self.student(data)
        loss = self.criterion(student_logits, teacher_logits, labels)
        loss.backward()
        return loss.item()

    def distill(self, train_loader, epochs=10):
        """Train the student against teacher soft targets + true labels."""
        self.student.train()
        optimizer = torch.optim.Adam(self.student.parameters())
        for epoch in range(epochs):
            total_loss = 0
            for data, labels in train_loader:
                optimizer.zero_grad()
                loss = self.train_step(data, labels)
                optimizer.step()
                total_loss += loss
            print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}")
Edge ML Frameworks
TensorFlow Lite
# TensorFlow Lite conversion and deployment
import tensorflow as tf
def convert_to_tflite(keras_model, quantize=True):
    """Convert a Keras model to a TensorFlow Lite flatbuffer.

    When `quantize` is set, float16 weight quantization is requested on
    top of the default optimizations, roughly halving model size.
    """
    converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    if quantize:
        # float16 weights: smaller binary with minimal accuracy impact
        converter.target_spec.supported_types = [tf.float16]
    return converter.convert()
def save_tflite_model(model, filename):
    """Write a serialized TFLite flatbuffer (bytes) to `filename`.

    Bug fix: the confirmation message previously printed the literal
    text "(unknown)" instead of the destination path.
    """
    with open(filename, 'wb') as f:
        f.write(model)
    print(f"Saved to {filename}")
def load_and_run_tflite(model_path, input_data):
    """Load a .tflite file and run a single inference on `input_data`."""
    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    in_meta = interpreter.get_input_details()[0]
    out_meta = interpreter.get_output_details()[0]
    # Copy the input into the interpreter's tensor arena, then execute.
    interpreter.set_tensor(in_meta['index'], input_data)
    interpreter.invoke()
    # Read the result back out of the arena.
    return interpreter.get_tensor(out_meta['index'])
# Android deployment
"""
import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.support.common.FileUtil;
public class ImageClassifier {
private Interpreter interpreter;
public ImageClassifier(Context context) throws IOException {
MappedByteBuffer modelBuffer = FileUtil.loadMappedFile(
context, "model.tflite"
);
interpreter = new Interpreter(modelBuffer);
}
public float[] classify(Bitmap bitmap) {
// Preprocess
float[][][][] input = preprocess(bitmap);
// Run inference
float[][] output = new float[1][NUM_CLASSES];
interpreter.run(input, output);
return output[0];
}
}
"""
ONNX Runtime for Edge
import onnxruntime as ort
import numpy as np
class ONNXEdgeInference:
    """Run ONNX models on edge devices via ONNX Runtime."""

    def __init__(self, model_path, providers=None):
        """Create an inference session.

        Args:
            model_path: path to the .onnx file.
            providers: execution-provider list; defaults to CPU only.
                (Fix: the old signature used a mutable list as the default
                argument — harmless here since it was never mutated, but a
                classic Python pitfall; it now defaults to None.)
        """
        if providers is None:
            providers = ['CPUExecutionProvider']
        self.session = ort.InferenceSession(model_path, providers=providers)
        # Cache tensor names once; assumes a single-input/single-output
        # graph — TODO confirm for multi-output models.
        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name

    def predict(self, input_data):
        """Run one forward pass and return the (single) output array."""
        result = self.session.run(
            [self.output_name],
            {self.input_name: input_data},
        )
        return result[0]

    def get_model_info(self):
        """Describe graph inputs/outputs as (name, shape, type) tuples."""
        return {
            'inputs': [(i.name, i.shape, i.type)
                       for i in self.session.get_inputs()],
            'outputs': [(o.name, o.shape, o.type)
                        for o in self.session.get_outputs()],
        }
def convert_to_onnx(pytorch_model, sample_input, output_path):
"""Convert PyTorch model to ONNX."""
import torch.onnx
torch.onnx.export(
pytorch_model,
sample_input,
output_path,
export_params=True,
opset_version=14,
do_constant_folding=True,
input_names=['input'],
output_names=['output'],
dynamic_axes={
'input': {0: 'batch_size'},
'output': {0: 'batch_size'}
}
)
print(f"Model exported to {output_path}")
PyTorch Mobile
import torch
import torchvision
class PyTorchMobile:
    """Helpers for exporting/loading TorchScript models for mobile."""

    @staticmethod
    def convert_for_mobile(scripted_model, sample_input):
        """Trace a model and optimize its graph for inference.

        Bug fix: `torch.jit.optimize_for_inference` freezes the module,
        which requires eval mode — tracing a model left in training mode
        made the old version raise at the optimize step. The model is
        switched to eval first (also the correct setting for dropout /
        batch-norm at inference time).
        """
        scripted_model = scripted_model.eval()
        traced_model = torch.jit.trace(scripted_model, sample_input)
        traced_model = torch.jit.optimize_for_inference(traced_model)
        return traced_model

    @staticmethod
    def save_mobile_model(model, path):
        """Serialize a TorchScript module to `path`."""
        torch.jit.save(model, path)

    @staticmethod
    def load_mobile_model(path):
        """Load a TorchScript module saved with `save_mobile_model`."""
        return torch.jit.load(path)
def create_mobile_model(backbone='resnet18', num_classes=1000):
    """Create a torchvision backbone statically quantized for mobile.

    Fix: static quantization operates on inference-mode models, so the
    model is now put in eval mode before prepare/convert.

    NOTE(review): this still performs post-training static quantization
    with NO calibration pass and no module fusion, so the activation
    observers keep their default ranges — run representative data between
    prepare() and convert() (and fuse conv-bn-relu) before trusting the
    accuracy of the result.
    """
    model = torchvision.models.__dict__[backbone](
        pretrained=False,
        num_classes=num_classes
    )
    model.eval()  # required for static quantization
    # qnnpack is the quantized backend used on ARM / mobile devices
    model.qconfig = torch.quantization.get_default_qconfig('qnnpack')
    torch.quantization.prepare(model, inplace=True)
    torch.quantization.convert(model, inplace=True)
    return model
Edge Device Deployment
Raspberry Pi Deployment
# Raspberry Pi setup script
"""
# Install dependencies
sudo apt-get update
sudo apt-get install -y python3-pip libopenblas-base libopenmpi-dev
# Install PyTorch
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/armv7l
# Install TensorFlow Lite
pip3 install tflite-runtime
"""
import tflite_runtime.interpreter as tflite
class RaspberryPiDeploy:
    """Run TFLite inference on a Raspberry Pi."""

    def __init__(self, model_path, num_threads=4):
        """Load a TFLite model, using several CPU threads for inference."""
        self.interpreter = tflite.Interpreter(
            model_path, num_threads=num_threads
        )
        self.interpreter.allocate_tensors()
        self.input_details = self.interpreter.get_input_details()
        self.output_details = self.interpreter.get_output_details()

    def preprocess_camera(self, frame):
        """Turn a camera frame into a normalized (1, 224, 224, C) batch.

        Pixels are scaled to [0, 1], then shifted to [-1, 1]
        (mean-0.5 / std-0.5 normalization).
        """
        import cv2
        resized = cv2.resize(frame, (224, 224))
        scaled = resized.astype(np.float32) / 255.0
        normalized = (scaled - 0.5) / 0.5
        return np.expand_dims(normalized, axis=0)

    def inference(self, input_data):
        """Run one forward pass and return the raw output tensor."""
        in_index = self.input_details[0]['index']
        out_index = self.output_details[0]['index']
        self.interpreter.set_tensor(in_index, input_data)
        self.interpreter.invoke()
        return self.interpreter.get_tensor(out_index)
NVIDIA Jetson Deployment
class JetsonDeploy:
    """Deploy on NVIDIA Jetson devices with TensorRT."""

    @staticmethod
    def setup_tensorrt(model_path):
        """Build a serialized TensorRT engine from an ONNX file.

        Bug fix: the ONNX parse result was previously ignored, so a bad
        model silently produced a broken engine; parse errors now raise.

        Returns:
            The serialized engine (an IHostMemory buffer).
        """
        import tensorrt as trt
        logger = trt.Logger(trt.Logger.WARNING)
        builder = trt.Builder(logger)
        network = builder.create_network(
            1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
        )
        parser = trt.OnnxParser(network, logger)
        with open(model_path, 'rb') as f:
            if not parser.parse(f.read()):
                errors = [str(parser.get_error(i))
                          for i in range(parser.num_errors)]
                raise RuntimeError(f"ONNX parse failed: {errors}")
        config = builder.create_builder_config()
        config.set_memory_pool_limit(
            trt.MemoryPoolType.WORKSPACE,
            1 << 30  # 1GB scratch space for tactic selection
        )
        engine = builder.build_serialized_network(network, config)
        return engine

    @staticmethod
    def run_inference(engine, input_data):
        """Run TensorRT inference on a serialized engine.

        Bug fixes vs. the previous version:
        - engines are deserialized with `tensorrt.Runtime`; the old code
          called a nonexistent `pycuda.driver.Runtime`.
        - `output_data` was referenced but never defined; device buffers
          are now allocated and the result is copied back to the host.

        Args:
            engine: serialized engine bytes, as returned by `setup_tensorrt`.
            input_data: contiguous numpy array matching the input binding.
        Returns:
            numpy float32 array with the network output.
        """
        import tensorrt as trt
        import pycuda.driver as cuda
        import pycuda.autoinit  # noqa: F401 — initializes the CUDA context
        import numpy as np

        logger = trt.Logger(trt.Logger.WARNING)
        runtime = trt.Runtime(logger)
        trt_engine = runtime.deserialize_cuda_engine(bytes(engine))
        context = trt_engine.create_execution_context()
        stream = cuda.Stream()

        host_input = np.ascontiguousarray(input_data, dtype=np.float32)
        # NOTE(review): assumes binding 0 = input, binding 1 = output —
        # holds for the single-input/single-output models built above;
        # confirm for other networks.
        output_shape = tuple(context.get_binding_shape(1))
        host_output = np.empty(output_shape, dtype=np.float32)

        d_input = cuda.mem_alloc(host_input.nbytes)
        d_output = cuda.mem_alloc(host_output.nbytes)
        cuda.memcpy_htod_async(d_input, host_input, stream)
        context.execute_async_v2(
            bindings=[int(d_input), int(d_output)],
            stream_handle=stream.handle
        )
        cuda.memcpy_dtoh_async(host_output, d_output, stream)
        stream.synchronize()
        return host_output
Real-time Edge Applications
Object Detection on Edge
class EdgeObjectDetector:
    """Real-time object detection on an edge device (TFLite SSD-style model)."""

    def __init__(self, model_path, confidence=0.5):
        self.interpreter = tflite.Interpreter(model_path)
        self.interpreter.allocate_tensors()
        self.confidence = confidence
        input_details = self.interpreter.get_input_details()
        # (height, width) of the model input, kept as a plain tuple —
        # the raw value is a numpy array slice, which cv2 won't accept.
        self.input_size = tuple(input_details[0]['shape'][1:3])

    def detect(self, frame):
        """Detect objects in a frame; returns a list of detection dicts."""
        import cv2
        import numpy as np
        # Bug fix: cv2.resize takes (width, height), but input_size is
        # (height, width) from the tensor shape — the old code silently
        # built a wrongly-shaped input for non-square models.
        height, width = self.input_size
        img = cv2.resize(frame, (width, height))
        img = img.astype(np.float32) / 255.0
        img = np.expand_dims(img, axis=0)
        self.interpreter.set_tensor(
            self.interpreter.get_input_details()[0]['index'],
            img
        )
        self.interpreter.invoke()
        # NOTE(review): assumes the standard TFLite detection output order
        # (boxes, classes, scores) — verify against the model's metadata.
        output_details = self.interpreter.get_output_details()
        boxes = self.interpreter.get_tensor(output_details[0]['index'])
        classes = self.interpreter.get_tensor(output_details[1]['index'])
        scores = self.interpreter.get_tensor(output_details[2]['index'])
        # Keep only detections above the confidence threshold.
        detections = []
        for i in range(len(scores[0])):
            if scores[0][i] > self.confidence:
                detections.append({
                    'box': boxes[0][i],
                    'class': int(classes[0][i]),
                    'score': float(scores[0][i])
                })
        return detections

    def draw_detections(self, frame, detections):
        """Draw detection boxes and labels on `frame`; returns the frame."""
        import cv2
        frame_h, frame_w = frame.shape[:2]
        for det in detections:
            y1, x1, y2, x2 = det['box']
            # Bug fix: SSD-style TFLite models emit normalized [0, 1]
            # coordinates; casting those straight to int collapsed every
            # box onto the top-left corner. Scale to pixels when the box
            # looks normalized (all values <= 1).
            if max(y1, x1, y2, x2) <= 1.0:
                y1, y2 = y1 * frame_h, y2 * frame_h
                x1, x2 = x1 * frame_w, x2 * frame_w
            x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
            cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
            label = f"Class {det['class']}: {det['score']:.2f}"
            cv2.putText(frame, label, (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
        return frame
Voice Assistant on Edge
class EdgeVoiceAssistant:
    """Keyword spotting and voice recognition on edge.

    NOTE(review): ``detect_keyword`` calls ``self.extract_mfcc``, which is
    never defined in this class — a subclass (or a later edit) must supply
    it, otherwise ``detect_keyword`` raises ``AttributeError``.
    """
    def __init__(self, keyword_model, asr_model):
        # keyword_model: must expose .predict(features) -> score
        # asr_model:     must expose .transcribe(audio) -> text
        self.keyword_model = keyword_model
        self.asr_model = asr_model
    def detect_keyword(self, audio_chunk):
        """Detect wake word; True when the model's score exceeds 0.9."""
        # MFCC features — extract_mfcc is not defined in this class
        # (see class docstring); presumably provided by a subclass.
        features = self.extract_mfcc(audio_chunk)
        # Predict
        prediction = self.keyword_model.predict(features)
        return prediction > 0.9
    def transcribe(self, audio):
        """Transcribe audio — delegates entirely to the injected ASR model."""
        # Run ASR
        transcription = self.asr_model.transcribe(audio)
        return transcription
    def process_audio_stream(self, stream):
        """Process continuous audio stream.

        Generator: blocks on the stream and yields one transcription per
        detected wake word.

        NOTE(review): ``np.concatenate`` expects array-like chunks; if
        ``stream.read`` returns raw bytes this will fail — confirm the
        stream yields numpy arrays. Also, ``buffer`` below is assigned
        but never used.
        """
        import numpy as np
        buffer = []
        while True:
            chunk = stream.read(1600)  # 100ms at 16kHz (1600 samples)
            if self.detect_keyword(chunk):
                print("Wake word detected!")
                # Collect command audio after the wake word
                command_buffer = []
                for _ in range(30):  # 3 seconds max (30 x 100ms chunks)
                    command_buffer.append(stream.read(1600))
                # Transcribe
                audio = np.concatenate(command_buffer)
                text = self.transcribe(audio)
                yield text
Best Practices
Optimization Checklist
- Quantize model (start with int8)
- Prune unnecessary weights
- Fuse operations where possible
- Use hardware accelerators
- Profile and measure latency
- Test on actual device
- Monitor power consumption
Common Issues
| Issue | Solution |
|---|---|
| High latency | Quantize, prune, or use smaller model |
| Low accuracy | Use quantization-aware training |
| Memory errors | Reduce batch size, use streaming |
| Overheating | Reduce clock speed, batch processing |
Conclusion
Edge AI enables intelligent applications that work offline, preserve privacy, and respond in real-time. By applying model optimization techniques like quantization, pruning, and knowledge distillation, we can deploy powerful ML models on resource-constrained devices. The key is to profile thoroughly on target hardware and iterate on optimizations.
As edge hardware continues to improve, we’ll see increasingly sophisticated AI capabilities on mobile and IoT devices. Start with optimized pre-trained models from TensorFlow Lite and ONNX, then customize as needed for your specific use case.
Comments