
AI in Frontend: Browser AI and WebGPU Revolution

Created: February 23, 2026 · Larry Qu · 8 min read

Introduction

The browser is undergoing a revolution. With WebGPU now available in all major browsers and frameworks like ONNX Runtime Web and TensorFlow.js maturing, developers can run sophisticated machine learning models directly in the browser—without server-side inference.

This guide explores how to bring AI to your frontend applications in 2025-2026.


What Is Browser AI?

The Basic Concept

Browser AI runs machine learning models directly in client browsers using WebGPU (or WebGL as fallback). This enables features like image recognition, natural language processing, and generative AI without sending data to external servers.

Key Terms

  • WebGPU: Next-generation GPU API for the browser
  • WebGL: Older graphics API, used as a fallback when WebGPU is unavailable
  • ONNX Runtime Web: Runs ONNX models in browsers
  • TensorFlow.js: TensorFlow for JavaScript
  • WebNN: Neural network API for browsers
  • WASM: WebAssembly, for near-native CPU performance
  • Model Quantization: Reducing model size for the web (sketched below)
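
Of these, model quantization is the one that most directly decides whether a model is practical to ship over the network. Here is a toy sketch of the linear int8 scheme; real pipelines quantize offline with dedicated tooling (for example ONNX Runtime's quantization utilities), and the helper names below are purely illustrative:

// Toy linear int8 quantization: store one scale per tensor and round
// each float32 weight to the nearest of 255 int8 steps (4x smaller)
function quantizeInt8(weights) {
  const maxAbs = weights.reduce((m, w) => Math.max(m, Math.abs(w)), 0);
  const scale = maxAbs / 127 || 1; // avoid divide-by-zero on all-zero tensors
  const q = Int8Array.from(weights, w => Math.round(w / scale));
  return { q, scale };
}

function dequantizeInt8({ q, scale }) {
  // Recover approximate float32 values at inference time
  return Float32Array.from(q, v => v * scale);
}

const { q, scale } = quantizeInt8(new Float32Array([0.12, -0.87, 0.5]));
console.log(dequantizeInt8({ q, scale })); // values close to the originals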

Why Browser AI Matters in 2025-2026

Approach          Latency     Privacy     Cost       Offline
Server-side API   High        Low         High       No
Browser AI        Ultra-low   Ultra-high  Near-zero  Yes

WebGPU Fundamentals

Checking Support

// Check WebGPU support
async function checkWebGPUSupport() {
  if (!navigator.gpu) {
    console.warn('WebGPU not supported');
    return false;
  }
  
  const adapter = await navigator.gpu.requestAdapter();
  if (!adapter) {
    console.warn('No GPU adapter available');
    return false;
  }
  
  const device = await adapter.requestDevice();
  console.log('WebGPU supported!');
  
  return { adapter, device };
}

// Run on page load
checkWebGPUSupport().then((support) => {
  // support is false when WebGPU is unavailable; guard before using it
  if (support) {
    initializeAI(support.device);
  }
});

Basic WebGPU Compute

// Simple matrix multiplication on GPU
const shaderCode = `
  @group(0) @binding(0) var<storage, read> matrixA : array<f32>;
  @group(0) @binding(1) var<storage, read> matrixB : array<f32>;
  @group(0) @binding(2) var<storage, read_write> output : array<f32>;
  
  @compute @workgroup_size(8, 8)
  fn main(@builtin(global_invocation_id) id : vec3<u32>) {
    let row = id.y;
    let col = id.x;
    let N = 64u; // Matrix size
    
    var sum = 0.0;
    for (var i = 0u; i < N; i++) {
      sum += matrixA[row * N + i] * matrixB[i * N + col];
    }
    output[row * N + col] = sum;
  }
`;

async function runGPUCompute() {
  const support = await checkWebGPUSupport();
  if (!support) return;
  const { device } = support;
  
  // Create shader module
  const shaderModule = device.createShaderModule({ code: shaderCode });
  
  // Helper: create a storage buffer of the given byte size
  const createBuffer = (size) => device.createBuffer({
    size,
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST
  });
  
  // Create buffers (64x64 f32 matrices = 64 * 64 * 4 bytes)
  const bufferA = createBuffer(64 * 64 * 4);
  const bufferB = createBuffer(64 * 64 * 4);
  const bufferOutput = device.createBuffer({
    size: 64 * 64 * 4,
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST
  });
  
  // Create compute pipeline (compute-only: there is no vertex stage here)
  const pipeline = device.createComputePipeline({
    layout: 'auto',
    compute: { module: shaderModule, entryPoint: 'main' }
  });
  
  // ... (submit work)
}
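
The elided submission step binds the three buffers and dispatches the shader. A sketch using the names from the snippet above (reading the result back would additionally need a MAP_READ staging buffer, omitted here):

// Bind the three storage buffers to the pipeline's auto-generated layout
const bindGroup = device.createBindGroup({
  layout: pipeline.getBindGroupLayout(0),
  entries: [
    { binding: 0, resource: { buffer: bufferA } },
    { binding: 1, resource: { buffer: bufferB } },
    { binding: 2, resource: { buffer: bufferOutput } }
  ]
});

// Encode and submit one compute pass: 64x64 output with 8x8 workgroups
const encoder = device.createCommandEncoder();
const pass = encoder.beginComputePass();
pass.setPipeline(pipeline);
pass.setBindGroup(0, bindGroup);
pass.dispatchWorkgroups(64 / 8, 64 / 8);
pass.end();
device.queue.submit([encoder.finish()]);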

Using ONNX Runtime Web

Loading a Model

<!-- Option 1: load from a CDN -->
<script src="https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort.min.js"></script>

// Option 2: import as an ES module (when using a bundler)
import * as ort from 'onnxruntime-web';

class ONNXModel {
  constructor() {
    this.session = null;
    // Example: an image model expecting a 1x3x224x224 float input
    this.modelPath = '/models/mobilenetv2-7.onnx';
  }
  
  async load() {
    try {
      this.session = await ort.InferenceSession.create(this.modelPath, {
        executionProviders: ['webgpu'], // Use WebGPU
        graphOptimizationLevel: 'all'
      });
      console.log('Model loaded successfully');
    } catch (error) {
      console.error('Failed to load model:', error);
      // Fallback to WebGL
      this.session = await ort.InferenceSession.create(this.modelPath, {
        executionProviders: ['webgl']
      });
    }
  }
  
  async predict(inputData) {
    if (!this.session) {
      throw new Error('Model not loaded');
    }
    
    const inputTensor = new ort.Tensor('float32', inputData, [1, 3, 224, 224]);
    
    const feeds = { 'input': inputTensor };
    const results = await this.session.run(feeds);
    
    return results.output.data;
  }
}

// Usage
const model = new ONNXModel();
await model.load();
const result = await model.predict(inputData);

Image Classification Example

import * as ort from 'onnxruntime-web';

class ImageClassifier {
  constructor() {
    this.session = null;
    this.classLabels = null;
  }
  
  async initialize() {
    // Load MobileNet model
    this.session = await ort.InferenceSession.create(
      '/models/mobilenetv2-7.onnx',
      { executionProviders: ['webgpu'] }
    );
    
    // Load class labels
    const response = await fetch('/models/imagenet-classes.json');
    this.classLabels = await response.json();
  }
  
  async classifyImage(imageElement) {
    // Preprocess image
    const tensor = await this.preprocessImage(imageElement);
    
    // Run inference
    const feeds = { 'input': tensor };
    const results = await this.session.run(feeds);
    
    // Get top predictions
    const output = results.output.data;
    const top5 = this.getTopPredictions(output, 5);
    
    return top5;
  }
  
  async preprocessImage(image) {
    // Resize to 224x224
    const canvas = document.createElement('canvas');
    canvas.width = 224;
    canvas.height = 224;
    const ctx = canvas.getContext('2d');
    ctx.drawImage(image, 0, 0, 224, 224);
    
    // Get pixel data and normalize
    const imageData = ctx.getImageData(0, 0, 224, 224);
    const floatData = new Float32Array(3 * 224 * 224);
    
    // Normalize to [-1, 1]
    for (let i = 0; i < 224 * 224; i++) {
      floatData[i] = (imageData.data[i * 4] / 255.0 - 0.5) / 0.5;         // R
      floatData[224 * 224 + i] = (imageData.data[i * 4 + 1] / 255.0 - 0.5) / 0.5; // G
      floatData[2 * 224 * 224 + i] = (imageData.data[i * 4 + 2] / 255.0 - 0.5) / 0.5; // B
    }
    
    return new ort.Tensor('float32', floatData, [1, 3, 224, 224]);
  }
  
  getTopPredictions(output, count) {
    // Proper softmax: exponentiate each logit, then normalize by the sum
    const exps = Array.from(output).map(Math.exp);
    const sum = exps.reduce((a, b) => a + b, 0);
    
    const predictions = exps.map((exp, index) => ({
      class: this.classLabels[index],
      score: exp / sum
    }));
    
    predictions.sort((a, b) => b.score - a.score);
    return predictions.slice(0, count);
  }
}

TensorFlow.js Integration

Basic Usage

<!-- Option 1: load from a CDN -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@4/dist/tf.min.js"></script>

// Option 2: import as an ES module (when using a bundler)
import * as tf from '@tensorflow/tfjs';

// Load a pre-trained model
async function loadModel() {
  const model = await tf.loadLayersModel('/models/model.json');
  model.summary(); // prints a layer-by-layer summary; it has no return value to log
  return model;
}

// Make predictions (in a real app, load the model once and reuse it)
async function predict(inputTensor) {
  const model = await loadModel();
  
  // Run inference
  const prediction = model.predict(inputTensor);
  const result = await prediction.data();
  
  // Cleanup
  inputTensor.dispose();
  prediction.dispose();
  
  return result;
}

// Image preprocessing
function preprocessImage(imageElement) {
  return tf.tidy(() => {
    let tensor = tf.browser.fromPixels(imageElement);
    
    // Resize to model input size
    tensor = tf.image.resizeBilinear(tensor, [224, 224]);
    
    // Normalize to [-1, 1]
    tensor = tensor.div(127.5).sub(1);
    
    // Add batch dimension
    tensor = tensor.expandDims(0);
    
    return tensor;
  });
}

Using WebGPU Backend

import * as tf from '@tensorflow/tfjs';
import '@tensorflow/tfjs-backend-webgpu';

async function initializeWebGPU() {
  // Set WebGPU as the active backend (setBackend returns a promise)
  await tf.setBackend('webgpu');
  await tf.ready();
  
  console.log('TensorFlow.js ready, active backend:', tf.getBackend());
}

// Use WebGPU for operations
async function runInference() {
  const model = await tf.loadLayersModel('/models/mobilenet/webgpu/model.json');
  
  const input = tf.randomNormal([1, 224, 224, 3]);
  
  // Run on GPU
  const result = model.predict(input);
  console.log('Result:', await result.array());
  
  // Cleanup
  input.dispose();
  result.dispose();
}

Real-World Applications

1. Object Detection

import * as tf from '@tensorflow/tfjs';
import * as cocoSsd from '@tensorflow-models/coco-ssd';

async function setupObjectDetection() {
  // Load COCO-SSD model
  const model = await cocoSsd.load({
    base: 'lite_mobilenet_v2' // Smaller, faster model
  });
  
  // Detect objects in image/video
  async function detectObjects(videoElement) {
    const predictions = await model.detect(videoElement);
    
    return predictions.map(pred => ({
      class: pred.class,
      score: pred.score.toFixed(2),
      bbox: pred.bbox
    }));
  }
  
  return { model, detectObjects };
}

// Usage with webcam
async function startCamera() {
  const video = document.getElementById('video');
  const stream = await navigator.mediaDevices.getUserMedia({ 
    video: { width: 640, height: 480 } 
  });
  video.srcObject = stream;
  await video.play();
  
  const { detectObjects } = await setupObjectDetection();
  
  // Continuous detection: schedule the next frame only after the current
  // inference finishes, so slow frames never pile up overlapping calls
  async function detectLoop() {
    const predictions = await detectObjects(video);
    renderPredictions(predictions);
    requestAnimationFrame(detectLoop);
  }
  detectLoop();
}
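
The renderPredictions callback is referenced above but never defined. One minimal version draws the boxes on a same-sized canvas layered over the video; the overlay canvas id and markup are assumptions:

// Draw bounding boxes and labels on an absolutely positioned
// <canvas id="overlay"> matching the video dimensions (assumed markup)
function renderPredictions(predictions) {
  const canvas = document.getElementById('overlay');
  const ctx = canvas.getContext('2d');
  ctx.clearRect(0, 0, canvas.width, canvas.height);
  
  for (const { class: label, score, bbox } of predictions) {
    const [x, y, width, height] = bbox;
    ctx.strokeStyle = 'lime';
    ctx.lineWidth = 2;
    ctx.strokeRect(x, y, width, height);
    ctx.fillStyle = 'lime';
    ctx.font = '14px sans-serif';
    ctx.fillText(`${label} (${score})`, x + 4, y + 16);
  }
}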

2. Pose Estimation

import * as poseDetection from '@tensorflow-models/pose-detection';

async function setupPoseDetection() {
  const detector = await poseDetection.createDetector(
    poseDetection.SupportedModels.MoveNet,
    { modelType: poseDetection.movenet.modelType.SINGLEPOSE_LIGHTNING }
  );
  
  async function detectPose(imageElement) {
    const poses = await detector.estimatePoses(imageElement);
    
    return poses[0]?.keypoints || [];
  }
  
  return { detector, detectPose };
}

// Draw skeleton
function drawSkeleton(keypoints) {
  const connections = [
    ['nose', 'left_eye'], ['nose', 'right_eye'],
    ['left_shoulder', 'right_shoulder'],
    // ... more connections
  ];
  
  connections.forEach(([from, to]) => {
    const fromPoint = keypoints.find(k => k.name === from);
    const toPoint = keypoints.find(k => k.name === to);
    
    if (fromPoint?.score > 0.3 && toPoint?.score > 0.3) {
      // Draw line between points
      drawLine(fromPoint, toPoint);
    }
  });
}
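
drawLine is likewise left undefined. A minimal helper, assuming a 2D context ctx for the overlay canvas is in scope:

// Connect two keypoints on a 2D canvas context (ctx assumed in scope);
// MoveNet keypoints carry x, y, score, and name properties
function drawLine(fromPoint, toPoint) {
  ctx.beginPath();
  ctx.moveTo(fromPoint.x, fromPoint.y);
  ctx.lineTo(toPoint.x, toPoint.y);
  ctx.strokeStyle = 'red';
  ctx.lineWidth = 2;
  ctx.stroke();
}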

3. Text Classification

import * as ort from 'onnxruntime-web';

class SentimentAnalyzer {
  constructor() {
    this.session = null;
    this.tokenizer = null;
  }
  
  async load() {
    this.session = await ort.InferenceSession.create(
      '/models/distilbert-sentiment.onnx',
      { executionProviders: ['webgpu'] }
    );
    
    // Load tokenizer (simplified)
    this.tokenizer = await loadTokenizer();
  }
  
  async analyze(text) {
    // Tokenize, then pad/truncate to the model's fixed length of 128
    const tokens = this.tokenizer.encode(text);
    const ids = new BigInt64Array(128); // int64 tensors require BigInt64Array
    tokens.slice(0, 128).forEach((t, i) => { ids[i] = BigInt(t); });
    const inputIds = new ort.Tensor('int64', ids, [1, 128]);
    
    const feeds = { 'input_ids': inputIds };
    const results = await this.session.run(feeds);
    
    const logits = results.logits.data;
    const sentiment = logits[0] > logits[1] ? 'negative' : 'positive';
    // Softmax over the two logits turns the winner into a probability
    const expNeg = Math.exp(Number(logits[0]));
    const expPos = Math.exp(Number(logits[1]));
    const confidence = Math.max(expNeg, expPos) / (expNeg + expPos);
    
    return { sentiment, confidence };
  }
}

// Usage
const analyzer = new SentimentAnalyzer();
await analyzer.load();

const result = await analyzer.analyze('This product is amazing!');
console.log(result); // { sentiment: 'positive', confidence: 0.98 }
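
The loadTokenizer call above is glossed over. Real DistilBERT inputs come from a WordPiece tokenizer; as a stand-in only, here is a toy whitespace tokenizer backed by a vocabulary lookup. The vocab path, its JSON word-to-id format, and the [UNK] fallback id are all assumptions:

// Toy tokenizer: whitespace split plus vocab lookup. NOT WordPiece;
// use a real tokenizer port in production. Path and format assumed.
async function loadTokenizer() {
  const response = await fetch('/models/vocab.json');
  const vocab = await response.json(); // assumed shape: { "word": id, ... }
  const UNK = vocab['[UNK]'] ?? 100;   // assumed unknown-token id
  
  return {
    encode(text) {
      return text.toLowerCase().split(/\s+/).map(word => vocab[word] ?? UNK);
    }
  };
}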

Best Practices

1. Model Optimization

// Quantize model for faster loading
// Use smaller models (MobileNet, EfficientNet-Lite)
// Lazy load models when needed
const loadModel = () => import('./model').then(m => m.load());
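
One refinement of the lazy-load line above: cache the in-flight promise so concurrent callers share a single download and compile (the './model' module is the same hypothetical one as above):

// Cache the load promise so the model is fetched and compiled once,
// even if several components request it at the same time
let modelPromise = null;

function getModel() {
  if (!modelPromise) {
    modelPromise = import('./model').then(m => m.load());
  }
  return modelPromise;
}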

2. Handle WebGPU Fallback

async function loadWithFallback() {
  const backends = ['webgpu', 'webgl', 'wasm'];
  
  for (const backend of backends) {
    try {
      // setBackend resolves to false if the backend can't initialize
      if (await tf.setBackend(backend)) {
        await tf.ready();
        console.log(`Using backend: ${backend}`);
        return backend;
      }
    } catch (e) {
      console.warn(`${backend} failed, trying next...`);
    }
  }
  
  throw new Error('No suitable backend found');
}

3. Memory Management

function processInference(model, image) {
  // Option 1: dispose every tensor you create, manually
  const tensor = tf.browser.fromPixels(image);
  const result = model.predict(tensor);
  
  tensor.dispose();
  result.dispose();
  
  // Option 2: tf.tidy() disposes intermediates automatically and
  // keeps only the returned tensor; dispose that one when done
  const output = tf.tidy(() => {
    const pixels = tf.browser.fromPixels(image);
    return model.predict(pixels);
  });
  output.dispose();
}
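
To confirm nothing leaks, tf.memory() reports the number of live tensors; the count should be stable across repeated calls:

// Leak check: numTensors should return to its previous value
console.log('Tensors before:', tf.memory().numTensors);
processInference(model, image);
console.log('Tensors after:', tf.memory().numTensors);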

Browser Support

Feature   Chrome   Firefox       Safari   Edge
WebGPU    113+     145+          17+      113+
WebGL2    56+      51+           15+      79+
WASM      57+      52+           11+      16+
WebNN     118+     Behind flag   No       118+


Key Takeaways

  • WebGPU enables near-native GPU performance in browsers
  • ONNX Runtime Web runs ONNX models in browsers with WebGPU
  • TensorFlow.js provides familiar APIs with WebGPU backend
  • Use cases: image classification, object detection, pose estimation, NLP
  • Best practices: model optimization, fallback handling, memory management
  • Browser support: WebGPU now ships in all major browsers

Next Steps: Explore React 2025: Server Components and Next.js 15 for the latest in React development.
