⚡ Calmops

Real-time ML Model Deployment with Rust and ONNX Runtime

Building High-Performance Inference Engines in Production

Machine learning models are everywhere, but getting them into production with low latency and high reliability is a different beast. Python frameworks like TensorFlow and PyTorch dominate model development, but when it comes to deployment, you need speed, safety, and predictability. This is where Rust and ONNX Runtime shine.

In this article, we’ll explore how to deploy real-time ML models using Rust and ONNX Runtime, focusing on practical production patterns that give you sub-millisecond inference latency without sacrificing safety or maintainability.

Why Rust for ML Deployment?

You might ask: “Why not just use Python for inference too?” The answer lies in three critical factors:

Performance: Rust compiles to native code with zero-cost abstractions. No garbage collection pauses. No GIL (Global Interpreter Lock). Inference can be 5-10x faster in Rust than in Python for the same model.

Memory Safety: Rust’s borrow checker prevents data races and memory leaks at compile time. In production, this means fewer crashes and more predictable resource usage.

Concurrency: Rust’s async/await ecosystem (think Tokio) lets you handle hundreds of concurrent inference requests on a handful of threads, making it ideal for microservices and edge deployments.

Python is great for experimentation and training. Rust is great for deploying at scale.

What is ONNX Runtime?

ONNX (Open Neural Network Exchange) is an open standard for representing machine learning models. It’s a framework-agnostic interchange format, meaning you can train a model in PyTorch, export it to ONNX, and run it in production with a different runtime, completely decoupled from the training framework.

ONNX Runtime is a high-performance inference engine maintained by Microsoft that optimizes model execution across CPUs, GPUs, and specialized accelerators such as NPUs via pluggable execution providers. It supports graph optimization, quantization, and multi-platform deployment.

The Rust bindings for ONNX Runtime give you access to this power from within Rust.

Setting Up Your Rust Project

Start with a fresh Rust project:

cargo new ml_inference_service
cd ml_inference_service

Add the necessary dependencies to your Cargo.toml:

# filepath: Cargo.toml
[dependencies]
ort = "2.0"  # ONNX Runtime Rust bindings
tokio = { version = "1", features = ["full"] }
axum = "0.7"  # Web framework for serving predictions
serde = { version = "1", features = ["derive"] }
serde_json = "1"
ndarray = "0.15"  # For tensor manipulation
anyhow = "1"
tracing = "0.1"
tracing-subscriber = "0.3"

[profile.release]
opt-level = 3
lto = true
codegen-units = 1

The ort crate is the primary Rust binding for ONNX Runtime. It’s well-maintained and provides a safe, idiomatic interface.

Building Your First Inference Engine

Let’s start with a simple example: loading a pre-trained model and running a single inference.

// filepath: src/lib.rs
use anyhow::Result;
use ndarray::Array2;
use ort::{GraphOptimizationLevel, Session};
use std::sync::Arc;

pub struct InferenceEngine {
    session: Arc<Session>,
}

impl InferenceEngine {
    /// Load a model from disk
    pub fn new(model_path: &str) -> Result<Self> {
        // Build a session from the model file. ort 2.0 manages the runtime
        // environment internally, so no explicit Environment setup is needed.
        let session = Session::builder()?
            .with_optimization_level(GraphOptimizationLevel::Level3)?
            .with_intra_threads(4)?
            .commit_from_file(model_path)?;

        Ok(InferenceEngine {
            session: Arc::new(session),
        })
    }

    /// Run inference on input data
    pub fn predict(&self, input: &[f32]) -> Result<Vec<f32>> {
        // Reshape input to match model expectations (e.g., [1, 784] for MNIST)
        let input_array = Array2::from_shape_vec((1, input.len()), input.to_vec())?;

        // Run inference. The input name ("input" here) must match the name
        // baked into your ONNX model; inspect the model to confirm it.
        let outputs = self.session.run(ort::inputs!["input" => input_array.view()]?)?;

        // Extract the output tensor as a flat Vec<f32>
        let output_tensor = outputs[0].try_extract_tensor::<f32>()?;
        Ok(output_tensor.iter().copied().collect())
    }
}

Now let’s use it in a simple application:

// filepath: src/main.rs
use anyhow::Result;

// src/lib.rs is the crate's library root, so the binary imports it by
// package name rather than declaring `mod lib;`
use ml_inference_service::InferenceEngine;

#[tokio::main]
async fn main() -> Result<()> {
    // Initialize tracing for logging
    tracing_subscriber::fmt::init();

    // Load the model
    let engine = InferenceEngine::new("path/to/model.onnx")?;
    tracing::info!("Model loaded successfully");

    // Create dummy input (e.g., 784 features for MNIST)
    let input = vec![0.5; 784];

    // Run inference
    let output = engine.predict(&input)?;
    println!("Prediction output: {:?}", &output[0..10]); // Print first 10 values

    Ok(())
}
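For a classifier like MNIST, the raw output is typically a vector of logits or class scores, so printing the first ten values isn’t the end of the story: the predicted class is the index of the largest one. A minimal argmax helper (the ten-element vector below is a hypothetical output, not from a real model):

```rust
// Return the index of the largest value: the predicted class.
fn argmax(logits: &[f32]) -> usize {
    logits
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
        .map(|(i, _)| i)
        .unwrap_or(0)
}

fn main() {
    // Hypothetical 10-class output (e.g., MNIST digit scores)
    let logits = [0.1, 0.3, 2.5, 0.0, 1.1, 0.2, 0.4, 0.6, 0.9, 0.05];
    println!("predicted class: {}", argmax(&logits)); // class 2
}
```

Note the `partial_cmp` dance: `f32` is not `Ord` because of NaN, so a plain `max_by_key` won’t compile on floats.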

Scaling to Production: A Web Service

In production, you’ll typically expose your inference engine via HTTP. Let’s build a service using Axum, a modern web framework built on Tokio:

// filepath: src/service.rs
use axum::{
    extract::State,
    http::StatusCode,
    routing::post,
    Json, Router,
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use ml_inference_service::InferenceEngine;

#[derive(Deserialize)]
pub struct PredictRequest {
    pub features: Vec<f32>,
}

#[derive(Serialize)]
pub struct PredictResponse {
    pub predictions: Vec<f32>,
    pub latency_ms: u128,
}

// Axum clones the state for each request, so it must implement Clone
#[derive(Clone)]
pub struct AppState {
    pub engine: Arc<InferenceEngine>,
}

pub async fn predict_handler(
    State(state): State<AppState>,
    Json(payload): Json<PredictRequest>,
) -> Result<Json<PredictResponse>, (StatusCode, String)> {
    let start = std::time::Instant::now();

    let predictions = state
        .engine
        .predict(&payload.features)
        .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;

    let latency_ms = start.elapsed().as_millis();

    Ok(Json(PredictResponse {
        predictions,
        latency_ms,
    }))
}

pub fn create_router(engine: Arc<InferenceEngine>) -> Router {
    let state = AppState { engine };

    Router::new()
        .route("/predict", post(predict_handler))
        .with_state(state)
}

Update your main.rs to serve the web service:

// filepath: src/main.rs
use anyhow::Result;
use std::sync::Arc;

mod service;

use ml_inference_service::InferenceEngine;
use service::create_router;

#[tokio::main]
async fn main() -> Result<()> {
    tracing_subscriber::fmt::init();

    // Load model once
    let engine = Arc::new(InferenceEngine::new("path/to/model.onnx")?);
    tracing::info!("Model loaded successfully");

    let app = create_router(engine);

    // Start server
    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await?;
    tracing::info!("Server listening on http://0.0.0.0:3000");

    axum::serve(listener, app).await?;

    Ok(())
}

Test it with curl:

curl -X POST http://localhost:3000/predict \
  -H "Content-Type: application/json" \
  -d '{"features": [0.5, 0.2, 0.8, ...]}'

Production Considerations

1. Model Optimization

Before deploying, optimize your ONNX model:

let session = Session::builder()?
    .with_optimization_level(ort::GraphOptimizationLevel::Level3)?
    .commit_from_file(model_path)?;

2. Quantization

Quantization reduces model size and inference time by converting float32 weights to int8. Quantize the model offline (for example, with ONNX Runtime’s Python quantization tooling), then load the quantized file like any other model:

// Load a quantized model exactly like any other ONNX file
let session = Session::builder()?
    .commit_from_file("path/to/model_quantized.onnx")?;

3. Batching

For throughput, batch multiple requests:

pub fn predict_batch(&self, inputs: Vec<Vec<f32>>) -> Result<Vec<Vec<f32>>> {
    let batch_size = inputs.len();
    let input_len = inputs.first().map_or(0, |v| v.len());

    // Flatten all inputs into a single contiguous [batch, features] array
    let flattened: Vec<f32> = inputs.into_iter().flatten().collect();
    let batched = Array2::from_shape_vec((batch_size, input_len), flattened)?;

    let outputs = self.session.run(ort::inputs!["input" => batched.view()]?)?;

    // Split the output tensor back into one row per request
    let output_tensor = outputs[0].try_extract_tensor::<f32>()?;
    Ok(output_tensor
        .outer_iter()
        .map(|row| row.iter().copied().collect())
        .collect())
}
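The flatten-and-split shape handling is where batching bugs usually hide, and it can be exercised without a model at all. A minimal stand-in using plain `Vec`s (no ort, no ndarray) that mirrors the `[batch, features]` layout:

```rust
// Flatten equal-length inputs into one contiguous buffer, the layout a
// [batch, features] tensor expects, returning the shape alongside the data.
fn flatten_batch(inputs: &[Vec<f32>]) -> (usize, usize, Vec<f32>) {
    let batch = inputs.len();
    let features = inputs.first().map_or(0, |v| v.len());
    let flat: Vec<f32> = inputs.iter().flatten().copied().collect();
    (batch, features, flat)
}

// Split a flat output buffer back into one row per request.
fn split_batch(flat: &[f32], rows: usize) -> Vec<Vec<f32>> {
    let cols = flat.len() / rows;
    flat.chunks(cols).map(|c| c.to_vec()).collect()
}

fn main() {
    let inputs = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
    let (batch, features, flat) = flatten_batch(&inputs);
    assert_eq!((batch, features), (2, 2));
    assert_eq!(flat, vec![1.0, 2.0, 3.0, 4.0]);
    // Round-trip: splitting the flat buffer recovers the original rows
    assert_eq!(split_batch(&flat, batch), inputs);
    println!("round-trip ok");
}
```

One caveat this sketch makes visible: batching only works if every input has the same length, and if your model’s first axis was exported as a dynamic (batch) dimension.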

4. Error Handling and Monitoring

Use structured logging to track inference performance:

pub async fn predict_handler(
    State(state): State<AppState>,
    Json(payload): Json<PredictRequest>,
) -> Result<Json<PredictResponse>, (StatusCode, String)> {
    let start = std::time::Instant::now();

    let predictions = match state.engine.predict(&payload.features) {
        Ok(p) => p,
        Err(e) => {
            tracing::error!(error = %e, "Inference failed");
            return Err((StatusCode::INTERNAL_SERVER_ERROR, e.to_string()));
        }
    };

    let latency_ms = start.elapsed().as_millis();
    tracing::info!(latency_ms, "Inference completed");

    Ok(Json(PredictResponse {
        predictions,
        latency_ms,
    }))
}
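Input validation belongs at the same boundary: rejecting malformed feature vectors before they reach the session gives clients a clear 400 instead of an opaque runtime error. A sketch, assuming the model’s expected input length is known up front (`EXPECTED_FEATURES` below is a hypothetical constant; in practice, read it from the session’s input metadata at startup):

```rust
// Hypothetical expected input width (e.g., 784 for MNIST); in a real
// service, derive this from the loaded model's input metadata.
const EXPECTED_FEATURES: usize = 784;

// Check a feature vector before handing it to the inference engine.
fn validate_features(features: &[f32]) -> Result<(), String> {
    if features.len() != EXPECTED_FEATURES {
        return Err(format!(
            "expected {} features, got {}",
            EXPECTED_FEATURES,
            features.len()
        ));
    }
    if features.iter().any(|x| !x.is_finite()) {
        return Err("features must be finite (no NaN or infinity)".into());
    }
    Ok(())
}

fn main() {
    assert!(validate_features(&vec![0.5; 784]).is_ok());
    assert!(validate_features(&[0.1, 0.2]).is_err());
    assert!(validate_features(&vec![f32::NAN; 784]).is_err());
    println!("validation checks pass");
}
```

In the handler, a failed check would map to `StatusCode::BAD_REQUEST` with the error string, keeping 500s reserved for genuine engine failures.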

5. GPU Support

ONNX Runtime can use CUDA for GPU acceleration. Enable it in your Cargo.toml:

[dependencies]
ort = { version = "2.0", features = ["cuda"] }

Then configure the session:

use ort::CUDAExecutionProvider;

// Falls back to CPU execution if CUDA is unavailable at runtime
let session = Session::builder()?
    .with_execution_providers([CUDAExecutionProvider::default().build()])?
    .commit_from_file(model_path)?;

Benchmarking

Create a simple benchmark to measure inference latency:

// filepath: examples/benchmark.rs
use ml_inference_service::InferenceEngine;
use std::time::Instant;

fn main() {
    // Nothing here is async, so a plain main (no tokio) is enough
    let engine = InferenceEngine::new("model.onnx").unwrap();
    let input = vec![0.5; 784];

    // Warm up once so one-time initialization doesn't skew the numbers
    let _ = engine.predict(&input).unwrap();

    let iterations = 1000;
    let start = Instant::now();

    for _ in 0..iterations {
        let _ = engine.predict(&input).unwrap();
    }

    let elapsed = start.elapsed();
    let avg_latency = elapsed.as_millis() as f64 / iterations as f64;

    println!("Average latency: {:.3}ms", avg_latency);
    println!("Throughput: {:.0} req/s", 1000.0 / avg_latency);
}

Run it with:

cargo run --example benchmark --release
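Averages hide tail latency, and in production, the p99 usually matters more than the mean. A percentile variant of the same loop, shown here with a dummy CPU workload standing in for `engine.predict` so the sketch is self-contained:

```rust
use std::time::Instant;

// Nearest-rank percentile over an already-sorted slice of latencies.
fn percentile(sorted_us: &[u128], q: f64) -> u128 {
    let idx = ((sorted_us.len() - 1) as f64 * q).round() as usize;
    sorted_us[idx]
}

fn main() {
    let iterations = 1000;
    let mut latencies_us: Vec<u128> = Vec::with_capacity(iterations);

    for i in 0..iterations {
        let start = Instant::now();
        // Dummy workload standing in for engine.predict(&input)
        let _: f32 = (0..784).map(|x| (x as f32).sin() * (i as f32)).sum();
        latencies_us.push(start.elapsed().as_micros());
    }

    latencies_us.sort_unstable();
    println!(
        "p50: {}us  p95: {}us  p99: {}us",
        percentile(&latencies_us, 0.50),
        percentile(&latencies_us, 0.95),
        percentile(&latencies_us, 0.99)
    );
}
```

Recording per-call latencies instead of one total also surfaces warm-up effects and periodic stalls that an average flattens away.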

Summary

Rust + ONNX Runtime delivers a compelling stack for production ML inference:

  • Predictable latency: No garbage collection pauses; sub-millisecond inference is realistic.
  • Safety: Memory safety and concurrency safety by design.
  • Flexibility: ONNX works across frameworks; train in PyTorch, deploy in Rust.
  • Scale: Async/await lets you handle thousands of concurrent requests efficiently.

Whether you’re building a real-time recommendation engine, a fraud detection service, or an edge AI application, Rust and ONNX Runtime give you the performance, safety, and reliability production demands.
