Introduction
WebSocket provides full-duplex communication between clients and servers, enabling real-time, bidirectional data transfer. Unlike traditional HTTP request-response patterns, WebSocket connections remain open, allowing servers to push data to clients instantly. This guide covers WebSocket API design patterns, implementation considerations, and best practices for building scalable real-time applications.
WebSocket Fundamentals
How WebSocket Works
The WebSocket protocol starts as an HTTP upgrade handshake:
Client → Server:
GET /ws/chat HTTP/1.1
Host: api.example.com
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
Sec-WebSocket-Version: 13
Server → Client:
HTTP/1.1 101 Switching Protocols
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
After the handshake, the connection remains open for bidirectional messaging.
When to Use WebSocket
| Use Case | WebSocket | HTTP Polling | Server-Sent Events |
|---|---|---|---|
| Chat applications | ✓ | ✗ | ✗ |
| Real-time dashboards | ✓ | ✗ | ✓ |
| Live notifications | ✓ | ✗ | ✓ |
| Collaborative editing | ✓ | ✗ | ✗ |
| IoT device updates | ✓ | ✗ | ✓ |
| Gaming | ✓ | ✗ | ✗ |
| Stock tickers | ✓ | ✗ | ✓ |
WebSocket API Architecture
Server Implementation
import asyncio
import json
import secrets
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Dict, Optional, Set

import websockets
@dataclass
class WebSocketMessage:
    """Envelope for a single message sent over the WebSocket.

    Attributes:
        type: Message type identifier (e.g. ``"chat_message"``).
        payload: Type-specific message body.
        timestamp: ISO 8601 creation time. Filled in with the current UTC
            time (including the ``+00:00`` offset) when omitted.
        message_id: Identifier for the message. Filled in with 16 random
            hex characters when omitted.
    """

    type: str
    payload: dict
    timestamp: Optional[str] = None
    message_id: Optional[str] = None

    def __post_init__(self):
        """Populate ``timestamp`` and ``message_id`` when not supplied."""
        if self.timestamp is None:
            # Timezone-aware so the serialized value carries an explicit UTC
            # offset; offset-less date-times are parsed as local time by
            # JavaScript clients.
            self.timestamp = datetime.now(timezone.utc).isoformat()
        if self.message_id is None:
            self.message_id = secrets.token_hex(8)
class WebSocketServer:
"""Production WebSocket server with room support."""
def __init__(self, host: str = "0.0.0.0", port: int = 8765):
self.host = host
self.port = port
self.clients: Dict[str, Set[websockets.WebSocketServerProtocol]] = {}
self.user_sockets: Dict[str, websockets.WebSocketServerProtocol] = {}
self.server = None
async def handle_connection(self, websocket: websockets.WebSocketServerProtocol, path: str):
"""Handle new WebSocket connection."""
client_id = None
try:
# Authentication handshake
client_id = await self.authenticate(websocket)
if not client_id:
await websocket.close(4001, "Authentication required")
return
# Register client
self.user_sockets[client_id] = websocket
print(f"Client {client_id} connected")
# Send connection confirmation
await self.send_message(websocket, WebSocketMessage(
type="connection_established",
payload={"client_id": client_id, "server_time": datetime.utcnow().isoformat()}
))
# Handle incoming messages
async for raw_message in websocket:
await self.process_message(client_id, raw_message)
except websockets.exceptions.ConnectionClosed:
print(f"Client {client_id} disconnected normally")
except Exception as e:
print(f"Error handling client {client_id}: {e}")
finally:
if client_id:
await self.disconnect(client_id)
async def authenticate(self, websocket) -> Optional[str]:
"""Authenticate WebSocket connection."""
try:
# Wait for auth message
auth_message = await asyncio.wait_for(websocket.recv(), timeout=10)
auth_data = json.loads(auth_message)
if auth_data.get("type") != "auth":
return None
token = auth_data.get("payload", {}).get("token")
# Validate token (implement your own logic)
client_id = await self.validate_token(token)
return client_id
except asyncio.TimeoutError:
return None
async def validate_token(self, token: str) -> Optional[str]:
"""Validate authentication token."""
# Implement token validation
return token # Simplified
async def process_message(self, client_id: str, raw_message: str):
"""Process incoming message."""
try:
data = json.loads(raw_message)
message_type = data.get("type")
payload = data.get("payload", {})
handlers = {
"join_room": self.handle_join_room,
"leave_room": self.handle_leave_room,
"message": self.handle_broadcast,
"ping": self.handle_ping,
}
handler = handlers.get(message_type)
if handler:
await handler(client_id, payload)
except json.JSONDecodeError:
await self.send_error(client_id, "Invalid JSON")
async def handle_join_room(self, client_id: str, payload: dict):
"""Add client to a room."""
room_id = payload.get("room_id")
if room_id not in self.clients:
self.clients[room_id] = set()
self.clients[room_id].add(self.user_sockets[client_id])
await self.send_message(self.user_sockets[client_id], WebSocketMessage(
type="room_joined",
payload={"room_id": room_id}
))
async def handle_leave_room(self, client_id: str, payload: dict):
"""Remove client from a room."""
room_id = payload.get("room_id")
if room_id in self.clients:
self.clients[room_id].discard(self.user_sockets[client_id])
async def handle_broadcast(self, client_id: str, payload: dict):
"""Broadcast message to room or all clients."""
room_id = payload.get("room_id")
message = payload.get("message")
message_obj = WebSocketMessage(
type="chat_message",
payload={
"sender_id": client_id,
"message": message,
"room_id": room_id
}
)
if room_id and room_id in self.clients:
await self.broadcast_to_room(room_id, message_obj)
else:
await self.broadcast_to_all(message_obj)
async def broadcast_to_room(self, room_id: str, message: WebSocketMessage):
"""Send message to all clients in a room."""
if room_id not in self.clients:
return
message_str = json.dumps({
"type": message.type,
"payload": message.payload,
"timestamp": message.timestamp,
"message_id": message.message_id
})
# Send to all clients in room
await asyncio.gather(
*[ws.send(message_str) for ws in self.clients[room_id] if ws.open],
return_exceptions=True
)
async def broadcast_to_all(self, message: WebSocketMessage):
"""Broadcast to all connected clients."""
message_str = json.dumps({
"type": message.type,
"payload": message.payload,
"timestamp": message.timestamp
})
await asyncio.gather(
*[ws.send(message_str) for ws in self.user_sockets.values() if ws.open],
return_exceptions=True
)
async def send_message(self, websocket, message: WebSocketMessage):
"""Send message to specific client."""
await websocket.send(json.dumps({
"type": message.type,
"payload": message.payload,
"timestamp": message.timestamp,
"message_id": message.message_id
}))
async def send_error(self, client_id: str, error: str):
"""Send error message to client."""
if client_id in self.user_sockets:
await self.send_message(self.user_sockets[client_id], WebSocketMessage(
type="error",
payload={"message": error}
))
async def disconnect(self, client_id: str):
"""Clean up on disconnect."""
# Remove from all rooms
for room_id in self.clients:
self.clients[room_id].discard(self.user_sockets.get(client_id))
# Remove from user sockets
self.user_sockets.pop(client_id, None)
async def start(self):
"""Start WebSocket server."""
self.server = await websockets.serve(
self.handle_connection,
self.host,
self.port
)
print(f"WebSocket server started on {self.host}:{self.port}")
async def stop(self):
"""Stop WebSocket server."""
if self.server:
self.server.close()
await self.server.wait_closed()
Client Implementation
/**
 * WebSocket client with token authentication, automatic reconnection,
 * heartbeat pings, an offline message queue and per-type message handlers.
 */
class WebSocketClient {
  /**
   * @param {string} url - WebSocket endpoint URL.
   * @param {object} [options]
   * @param {number} [options.reconnectAttempts=5] - Max reconnect attempts (0 disables reconnection).
   * @param {number} [options.reconnectDelay=1000] - Base delay in ms; multiplied by the attempt number.
   * @param {number} [options.heartbeatInterval=30000] - Interval in ms between ping messages.
   */
  constructor(url, options = {}) {
    this.url = url;
    // Explicit undefined checks so that a caller-supplied 0 is honored.
    this.reconnectAttempts =
      options.reconnectAttempts !== undefined ? options.reconnectAttempts : 5;
    this.reconnectDelay =
      options.reconnectDelay !== undefined ? options.reconnectDelay : 1000;
    this.heartbeatInterval =
      options.heartbeatInterval !== undefined ? options.heartbeatInterval : 30000;
    this.messageQueue = [];
    this.handlers = new Map();
    this.clientId = null;
    this.ws = null;
    this.connected = false;
    this.reconnectCount = 0;
    this.heartbeatTimer = null;
    this.reconnectTimer = null;
    // Remembered so automatic reconnects can authenticate again.
    this.token = null;
    // Set by disconnect() to suppress automatic reconnection.
    this.manualClose = false;
  }

  /**
   * Open the connection and authenticate.
   * @param {string} [token] - Auth token; when omitted, the last token is reused.
   * @returns {Promise<void>} Resolves when the socket opens; rejects on a socket error.
   */
  connect(token) {
    if (token !== undefined) {
      this.token = token;
    }
    this.manualClose = false;
    return new Promise((resolve, reject) => {
      this.ws = new WebSocket(this.url);
      this.ws.onopen = () => {
        console.log('WebSocket connected');
        this.connected = true;
        // The auth message must be the first one the server receives.
        this.authenticate(this.token);
        this.startHeartbeat();
        this.flushQueue();
        resolve();
      };
      this.ws.onmessage = (event) => {
        let data;
        try {
          data = JSON.parse(event.data);
        } catch (error) {
          console.error('Ignoring non-JSON WebSocket message:', error);
          return;
        }
        this.handleMessage(data);
      };
      this.ws.onclose = (event) => {
        console.log('WebSocket closed', event.code, event.reason);
        this.connected = false;
        this.stopHeartbeat();
        this.handleDisconnect();
      };
      this.ws.onerror = (error) => {
        console.error('WebSocket error:', error);
        reject(error);
      };
    });
  }

  /** Send the auth message carrying `token`. */
  authenticate(token) {
    this.send({
      type: 'auth',
      payload: { token }
    });
  }

  /**
   * Send a message, or queue it until the connection is open.
   * A timestamp and the current client id are added to the message.
   */
  send(message) {
    const messageObj = {
      ...message,
      timestamp: new Date().toISOString(),
      clientId: this.clientId
    };
    if (this.connected) {
      this.ws.send(JSON.stringify(messageObj));
    } else {
      this.messageQueue.push(messageObj);
    }
  }

  /** Register `handler` for messages of the given `type`. */
  on(type, handler) {
    if (!this.handlers.has(type)) {
      this.handlers.set(type, []);
    }
    this.handlers.get(type).push(handler);
  }

  /** Remove a handler previously registered with `on`. */
  off(type, handler) {
    if (this.handlers.has(type)) {
      const handlers = this.handlers.get(type);
      const index = handlers.indexOf(handler);
      if (index > -1) {
        handlers.splice(index, 1);
      }
    }
  }

  /** Dispatch a decoded message to the handlers registered for its type. */
  handleMessage(data) {
    if (!data || typeof data !== 'object') {
      return;
    }
    const { type, payload } = data;
    if (type === 'connection_established') {
      this.clientId = payload.client_id;
      // Authenticated successfully: later outages get a fresh retry budget.
      this.reconnectCount = 0;
    }
    // Iterate over a copy so handlers may call off() while being dispatched.
    const handlers = (this.handlers.get(type) || []).slice();
    handlers.forEach(handler => handler(payload));
  }

  /** Schedule a reconnect with linear backoff unless the close was requested. */
  handleDisconnect() {
    if (this.manualClose) {
      return;
    }
    if (this.reconnectCount < this.reconnectAttempts) {
      this.reconnectCount++;
      console.log(`Reconnecting... attempt ${this.reconnectCount}`);
      this.reconnectTimer = setTimeout(() => {
        this.reconnectTimer = null;
        if (!this.manualClose) {
          // connect() reuses the stored token, so the new socket authenticates.
          this.connect().catch(console.error);
        }
      }, this.reconnectDelay * this.reconnectCount);
    }
  }

  /** Start sending periodic ping messages, replacing any running timer. */
  startHeartbeat() {
    this.stopHeartbeat();
    this.heartbeatTimer = setInterval(() => {
      this.send({ type: 'ping', payload: {} });
    }, this.heartbeatInterval);
  }

  /** Stop the heartbeat timer if one is running. */
  stopHeartbeat() {
    if (this.heartbeatTimer) {
      clearInterval(this.heartbeatTimer);
      this.heartbeatTimer = null;
    }
  }

  /** Send every queued message in the order it was queued. */
  flushQueue() {
    while (this.messageQueue.length > 0) {
      const message = this.messageQueue.shift();
      this.ws.send(JSON.stringify(message));
    }
  }

  /** Close the connection and prevent automatic reconnection. */
  disconnect() {
    this.manualClose = true; // Prevent reconnection
    if (this.reconnectTimer) {
      clearTimeout(this.reconnectTimer);
      this.reconnectTimer = null;
    }
    if (this.ws) {
      this.ws.close(1000, 'Client disconnect');
    }
  }
}
// Usage
const ws = new WebSocketClient('wss://api.example.com/ws');

// Handlers are registered before connecting so no early message is missed.
ws.on('chat_message', (payload) => {
  console.log('New message:', payload.message);
});
ws.on('notification', (payload) => {
  showNotification(payload.title, payload.body);
});

ws.connect('auth-token')
  .then(() => {
    ws.send({
      type: 'join_room',
      payload: { room_id: 'general' }
    });
  })
  .catch((error) => {
    // connect() rejects on a socket error; handle it instead of leaving
    // an unhandled promise rejection.
    console.error('WebSocket connection failed:', error);
  });
Message Protocol Design
Message Format Standards
{
"type": "message_type",
"payload": { },
"metadata": {
"message_id": "uuid",
"timestamp": "ISO8601",
"correlation_id": "uuid",
"client_id": "string"
}
}
Common Message Types
# Catalogue of message type identifiers and their human-readable meaning,
# assembled from one group per protocol area (insertion order is preserved).
MESSAGE_TYPES = {
    # Connection lifecycle
    **{
        "auth": "Authentication request",
        "auth_response": "Authentication response",
        "connection_established": "Connection confirmed",
    },
    # Room membership
    **{
        "join_room": "Join a room/channel",
        "leave_room": "Leave a room/channel",
        "room_joined": "Successfully joined",
        "room_left": "Successfully left",
        "room_members": "List of room members",
    },
    # Chat messaging
    **{
        "message": "Chat message",
        "message_sent": "Message delivered confirmation",
        "typing_start": "User started typing",
        "typing_stop": "User stopped typing",
        "read_receipt": "Message read receipt",
    },
    # Presence
    **{
        "presence_update": "Online/offline status",
        "presence_query": "Query presence status",
    },
    # System / keep-alive
    **{
        "ping": "Keep-alive ping",
        "pong": "Keep-alive pong",
        "error": "Error notification",
    },
}
Scaling WebSocket Servers
Horizontal Scaling with Redis
import aioredis
import json
from typing import List
class RedisPubSubManager:
    """Manage WebSocket scaling with Redis pub/sub.

    Written against the aioredis 1.x API (``create_redis_pool``): one handler
    coroutine is kept per channel and a background task per subscription
    feeds decoded JSON messages to it.
    """

    def __init__(self, redis_url: str):
        """Remember the Redis URL; no connection is opened until ``connect``."""
        self.redis = None
        self.redis_url = redis_url
        self.pubsub = None
        # channel name -> async handler called with each decoded message
        self.channels = {}
        # Strong references to listener tasks: the event loop keeps only weak
        # references, so an unreferenced task may be garbage-collected mid-run.
        self._listener_tasks = []

    async def connect(self):
        """Open the Redis connection pool."""
        self.redis = await aioredis.create_redis_pool(self.redis_url)

    async def subscribe(self, channel: str, handler):
        """Subscribe to a Redis channel and start a background listener.

        Args:
            channel: Channel name.
            handler: Coroutine function called with each decoded JSON message.
        """
        # subscribe() must be awaited; it yields one Channel per name given.
        subscribed = await self.redis.subscribe(channel)
        pubsub = subscribed[0]
        self.channels[channel] = handler
        # Start listening
        task = asyncio.create_task(self._listen(channel, pubsub))
        self._listener_tasks.append(task)

    async def publish(self, channel: str, message: dict):
        """Publish ``message`` to ``channel`` as JSON."""
        await self.redis.publish(channel, json.dumps(message))

    async def _listen(self, channel: str, pubsub):
        """Forward messages from ``pubsub`` to the channel's handler until it closes.

        Errors while decoding or handling a single message are logged and do
        not stop the listener.
        """
        # wait_message() resolves to False once the channel is closed.
        while await pubsub.wait_message():
            try:
                message = await pubsub.get()
                if message:
                    data = json.loads(message.decode())
                    handler = self.channels.get(channel)
                    if handler:
                        await handler(data)
            except Exception as e:
                # Best-effort: keep listening after a bad message or handler failure.
                print(f"Error in pubsub listener: {e}")
class ScaledWebSocketServer(WebSocketServer):
    """WebSocket server with horizontal scaling.

    Room broadcasts are delivered to local clients and published on the
    ``ws:broadcast`` Redis channel so other server instances can deliver them
    to their own clients.
    """

    def __init__(self, redis_url: str, *args, **kwargs):
        """Create the server and its Redis pub/sub manager.

        Args:
            redis_url: Redis connection URL for cross-server messaging.
            *args, **kwargs: Forwarded to ``WebSocketServer``.
        """
        super().__init__(*args, **kwargs)
        self.pubsub = RedisPubSubManager(redis_url)
        # Identifies this instance so it can ignore its own publications.
        self.server_id = secrets.token_hex(8)

    async def start(self):
        """Connect to Redis, subscribe to cross-server broadcasts, then start serving."""
        await self.pubsub.connect()
        # Subscribe to cross-server communication
        await self.pubsub.subscribe('ws:broadcast', self.handle_remote_message)
        await super().start()

    async def handle_remote_message(self, data: dict):
        """Deliver a broadcast published by another server to local clients only."""
        if data.get('origin') == self.server_id:
            # Our own publication: local clients already received it.
            return
        # Use the base-class broadcast: the override below would publish the
        # message to Redis again and make it echo between servers forever.
        await super().broadcast_to_room(data['room_id'], WebSocketMessage(
            type=data['type'],
            payload=data['payload'],
            timestamp=data.get('timestamp'),
            message_id=data.get('message_id')
        ))

    async def broadcast_to_room(self, room_id: str, message: WebSocketMessage):
        """Broadcast to room, including remote servers."""
        # Send to local clients
        await super().broadcast_to_room(room_id, message)
        # Publish to Redis for other servers
        await self.pubsub.publish('ws:broadcast', {
            'room_id': room_id,
            'type': message.type,
            'payload': message.payload,
            'timestamp': message.timestamp,
            'message_id': message.message_id,
            'origin': self.server_id
        })
Connection Management with Nginx
# Nginx WebSocket configuration
# Derive the value of the Connection header sent to the backend from the
# client's Upgrade header: any non-empty Upgrade value maps to "upgrade",
# an absent/empty one maps to "close".
map $http_upgrade $connection_upgrade {
    default upgrade;
    '' close;
}
# Pool of WebSocket application servers that /ws/ requests are proxied to.
# No balancing method is configured here, so nginx's default applies.
upstream websocket_backend {
    server ws1.example.com:8765;
    server ws2.example.com:8766;
    server ws3.example.com:8767;
}
# TLS-terminating virtual host that proxies /ws/ to the WebSocket backends.
server {
    listen 443 ssl http2;
    server_name api.example.com;
    ssl_certificate /etc/nginx/ssl/cert.pem;
    ssl_certificate_key /etc/nginx/ssl/key.pem;
    location /ws/ {
        proxy_pass http://websocket_backend;
        # The Upgrade mechanism requires HTTP/1.1 towards the backend.
        proxy_http_version 1.1;
        # Upgrade/Connection are hop-by-hop headers, so they must be
        # forwarded explicitly for the handshake to reach the backend.
        proxy_set_header Upgrade $http_upgrade;
        proxy_set_header Connection $connection_upgrade;
        # Preserve the original host and client address for the backend.
        proxy_set_header Host $host;
        proxy_set_header X-Real-IP $remote_addr;
        # Timeouts
        # Allow up to one hour without traffic to/from the backend before
        # nginx closes the proxied connection.
        proxy_read_timeout 3600s;
        proxy_send_timeout 3600s;
        # Buffering
        # Pass data through immediately instead of buffering it in nginx.
        proxy_buffering off;
        proxy_request_buffering off;
    }
}
Security Considerations
Connection Security
import ssl
import secrets
from hashlib import sha256
import hmac
class SecureWebSocketServer(WebSocketServer):
    """WebSocket server with security features.

    Adds signed-token validation and a rate-limiting hook on top of
    ``WebSocketServer``.
    """

    def __init__(self, *args, secret_key: str, **kwargs):
        """Create the server.

        Args:
            secret_key: Key used to verify token signatures (keyword-only).
            *args, **kwargs: Forwarded to ``WebSocketServer``.
        """
        super().__init__(*args, **kwargs)
        self.secret_key = secret_key

    async def validate_token(self, token: str) -> Optional[str]:
        """Validate a signed ``header.payload.signature`` token.

        Simplified scheme, not a full JWT implementation: the signature is
        the hex HMAC-SHA256 of ``"header.payload"`` and claims such as expiry
        are not checked.

        Returns:
            The ``user_id`` from the payload, or ``None`` when the token is
            malformed, the signature does not match, or the payload is not a
            base64-encoded JSON object.
        """
        import base64

        if not isinstance(token, str):
            return None
        parts = token.split('.')
        if len(parts) != 3:
            return None
        # Verify signature
        header, payload, signature = parts
        expected_sig = hmac.new(
            self.secret_key.encode(),
            f"{header}.{payload}".encode(),
            sha256
        ).hexdigest()
        # Compare as bytes: compare_digest rejects non-ASCII str arguments.
        if not hmac.compare_digest(signature.encode(), expected_sig.encode()):
            return None
        # Extract user ID from payload
        try:
            # Token payloads use the URL-safe alphabet without padding, so
            # restore exactly the padding the length requires.
            padded = payload + '=' * (-len(payload) % 4)
            payload_data = json.loads(base64.urlsafe_b64decode(padded))
        except ValueError:
            # Covers invalid base64, invalid UTF-8 and invalid JSON.
            return None
        if not isinstance(payload_data, dict):
            return None
        return payload_data.get('user_id')

    async def process_message(self, client_id: str, raw_message: str):
        """Process a message unless the client has exceeded its rate limit."""
        # Check rate limit
        if not await self.check_rate_limit(client_id):
            await self.send_error(client_id, "Rate limit exceeded")
            return
        await super().process_message(client_id, raw_message)

    async def check_rate_limit(self, client_id: str) -> bool:
        """Return whether ``client_id`` may send another message.

        Placeholder that always allows the message; override with a Redis or
        in-memory counter to enforce a limit.
        """
        return True  # Simplified
Origin Validation
# Origins whose pages may open a WebSocket connection: the deployed
# front-ends plus the local development server.
ALLOWED_ORIGINS = {
    "https://app.example.com",
    "https://admin.example.com",
} | {"http://localhost:3000"}
async def handle_connection(self, websocket, path: str):
    """Reject connections from disallowed origins, otherwise handle normally.

    A request without an ``Origin`` header is let through; a request whose
    ``Origin`` is not in ``ALLOWED_ORIGINS`` is closed with code 4003.
    """
    request_origin = websocket.request_headers.get('Origin')
    origin_permitted = not request_origin or request_origin in ALLOWED_ORIGINS
    if origin_permitted:
        # Hand over to the regular connection handling
        await super().handle_connection(websocket, path)
    else:
        await websocket.close(4003, "Origin not allowed")
Best Practices
Design Guidelines
- Always authenticate: Use token-based auth, not IP-based
- Validate messages: Sanitize all incoming data
- Implement heartbeats: Detect stale connections
- Use message acknowledgment: Ensure delivery
- Handle reconnection: Client-side auto-reconnect
- Log everything: Monitor connection health
Error Handling
# Application-level close/error codes, numbered consecutively from 4000 in
# the order listed below.
ERROR_CODES = dict(enumerate(
    (
        "Unknown error",
        "Authentication required",
        "Authentication failed",
        "Origin not allowed",
        "Rate limit exceeded",
        "Message too large",
        "Invalid message format",
        "Room not found",
        "Room full",
        "User banned",
    ),
    start=4000,
))
async def send_error(self, client_id: str, error: str, code: int = 4000):
    """Send a structured error to a client.

    Args:
        client_id: Recipient; nothing is sent when this client is not
            registered (e.g. it disconnected in the meantime).
        error: Human-readable error description.
        code: Application error code, 4000 when not given.
    """
    websocket = self.user_sockets.get(client_id)
    if websocket is None:
        # Unknown or already disconnected client: nothing to notify.
        return
    await self.send_message(websocket, WebSocketMessage(
        type="error",
        payload={
            "message": error,
            "code": code
        }
    ))
Conclusion
WebSocket APIs enable powerful real-time features but require careful design for production use. Key takeaways:
- Use established libraries rather than raw WebSocket implementation
- Implement proper authentication and security from the start
- Design a clear message protocol with versioning
- Plan for horizontal scaling from the beginning
- Monitor connection health and implement reconnection logic
With proper implementation, WebSocket APIs can handle millions of concurrent connections, enabling responsive real-time experiences.
Resources
- WebSocket API MDN Documentation
- WebSocket Protocol RFC 6455
- ws: WebSocket library for Node.js
- Autobahn: WebSocket for Python
Comments