Introduction
WebSocket enables bidirectional, persistent communication between clients and servers. Originally designed for real-time web applications, WebSocket now powers chat systems, live updates, collaborative tools, IoT devices, and trading platforms. However, this persistent connection model introduces security challenges that differ from traditional HTTP.
This comprehensive guide covers WebSocket security from the fundamentals to advanced protections. You’ll learn about authentication strategies, authorization patterns, common vulnerabilities, and implementation best practices. Each concept includes code examples demonstrating secure patterns you can apply immediately.
Whether you’re building a chat application, real-time dashboard, or IoT backend, securing your WebSocket connections is critical. A single vulnerability can expose sensitive data or enable attacks affecting all connected clients.
WebSocket Security Fundamentals
Understanding WebSocket Security
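WebSocket connections begin with an HTTP handshake. The client sends an HTTP request with an Upgrade: websocket header, the server responds with 101 Switching Protocols, and the same TCP connection then carries WebSocket frames for as long as it stays open. The handshake is the only point where ordinary HTTP mechanisms apply: cookies and headers such as Origin and Authorization are sent once, so any check the server makes there has to hold for the connection's entire lifetime.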
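To make the upgrade concrete: the server answers the client's random Sec-WebSocket-Key header by appending a fixed GUID defined in RFC 6455, hashing the result with SHA-1, and returning it base64-encoded as Sec-WebSocket-Accept. The sketch below computes that value with Python's standard library (the function name is illustrative) and checks it against the example key from the RFC. This value only proves that the server understood the upgrade; it authenticates nobody, which is why the checks described in the rest of this guide are still needed.

```python
import base64
import hashlib

# Fixed GUID from RFC 6455, section 1.3
WEBSOCKET_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"


def compute_accept(sec_websocket_key: str) -> str:
    """Return the Sec-WebSocket-Accept value for a client's Sec-WebSocket-Key."""
    digest = hashlib.sha1((sec_websocket_key + WEBSOCKET_GUID).encode("ascii")).digest()
    return base64.b64encode(digest).decode("ascii")


# Example key from RFC 6455
print(compute_accept("dGhlIHNhbXBsZSBub25jZQ=="))  # s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
```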
Key security considerations:
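- Connection Persistence: Unlike typical HTTP requests, connections can stay open for hours or days, so authentication and authorization state must persist (and remain revocable) for the connection's lifetime
- Bidirectional Communication: Both client and server can send messages at any time, so every incoming message must be validated, not just the handshake
- No Same-Origin Policy: Browsers do not apply the same-origin policy or CORS to WebSocket connections; any page can open a connection to your server, so the server must check the Origin header itself
- Stateful Protocol: Each connection carries state (user identity, subscriptions, rate-limit counters) that the server must manage and clean up, unlike stateless HTTP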
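```javascript
// Client-side WebSocket connection (insecure vs secure)

// Insecure: plain WebSocket (ws://) sends everything unencrypted
const insecureSocket = new WebSocket('ws://example.com/socket');

// Secure: WebSocket Secure (wss://) runs over TLS
const secureSocket = new WebSocket('wss://example.com/socket');

// With custom headers for authentication (Node.js only).
// The browser WebSocket API cannot set custom headers; this
// three-argument form is provided by the `ws` package.
const NodeWebSocket = require('ws');
const authSocket = new NodeWebSocket('wss://example.com/socket', [], {
  headers: {
    'Authorization': `Bearer ${token}` // token obtained elsewhere
  }
});
```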
The WebSocket Threat Model
Understanding threats helps prioritize defenses:
Man-in-the-Middle Attacks: Unencrypted connections (ws://) let attackers on the network read and modify messages. Solution: Always use WSS.
Cross-Site WebSocket Hijacking (CSWSH): A malicious page opens a WebSocket to the target site, and the browser attaches the victim’s cookies to the handshake. Solution: Validate the Origin header and require a CSRF token.
Denial of Service: Attackers flood servers with connections or messages. Solution: Rate limiting, connection limits, and message size limits.
Injection Attacks: Message contents reach queries, commands, or HTML without validation, letting attackers run code or read data. Solution: Input validation, parameterized queries, and output encoding.
Authentication Bypass: Missing or stale authentication lets unauthorized users connect, or keep a connection open after their access ends. Solution: Token validation with expiry, and re-authentication for sensitive operations.
```python
# Common WebSocket attack vectors and defenses
ATTACKS = {
    "MITM": {
        "description": "Man-in-the-Middle interception",
        "defense": "Use WSS (WebSocket Secure) with TLS 1.3"
    },
    "CSWSH": {
        "description": "Cross-Site WebSocket Hijacking",
        "defense": "Validate Origin header, implement CSRF tokens"
    },
    "DoS": {
        "description": "Denial of Service via connection flooding",
        "defense": "Rate limiting, connection limits, message size limits"
    },
    "Injection": {
        "description": "SQL/NoSQL/Command injection via messages",
        "defense": "Input validation, parameterized queries, output encoding"
    },
    "AuthBypass": {
        "description": "Exploiting weak or missing authentication",
        "defense": "Token-based auth, session validation, re-auth for sensitive ops"
    }
}
```
Authentication Strategies
Token-Based Authentication
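WebSocket connections should be authenticated before the server processes any application messages. JSON Web Tokens (JWTs) are a common approach. In the server below, the client sends the token as its first message in the form auth:<token>, and the server closes the connection if no valid token arrives within five seconds: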
```python
# Python WebSocket server with JWT authentication
import asyncio
import json
import os
from typing import Dict, Set

import jwt  # PyJWT
import websockets

# Load the signing key from the environment; never hard-code it.
# JWT_SECRET_KEY is an example variable name.
SECRET_KEY = os.environ["JWT_SECRET_KEY"]
ALGORITHM = "HS256"


class AuthenticatedWebSocketServer:
    def __init__(self):
        self.clients: Set = set()
        self.user_sessions: Dict = {}  # websocket -> user_id

    async def authenticate(self, websocket, token: str) -> bool:
        """Validate the JWT and record the user for this connection."""
        try:
            # jwt.decode verifies the signature and the 'exp' claim;
            # 'require' rejects tokens that omit 'exp' or 'sub'
            payload = jwt.decode(
                token,
                SECRET_KEY,
                algorithms=[ALGORITHM],
                options={"require": ["exp", "sub"]},
            )
        except jwt.ExpiredSignatureError:
            print("Token expired")
            return False
        except jwt.InvalidTokenError:
            print("Invalid token")
            return False

        # Store user session
        self.user_sessions[websocket] = payload["sub"]
        return True

    async def handler(self, websocket):
        """Handle a connection: authenticate first, then process messages."""
        try:
            # Expect the token in the first message as "auth:<token>"
            message = await asyncio.wait_for(websocket.recv(), timeout=5)
        except asyncio.TimeoutError:
            await websocket.close(1008, "Authentication timeout")
            return

        if not (isinstance(message, str) and message.startswith("auth:")):
            await websocket.close(1008, "Authentication required")
            return

        token = message[len("auth:"):]
        if not await self.authenticate(websocket, token):
            await websocket.send(json.dumps({"error": "Authentication failed"}))
            await websocket.close(1008, "Authentication failed")
            return

        self.clients.add(websocket)
        print(f"Client authenticated: {self.user_sessions[websocket]}")
        await self.handle_messages(websocket)

    async def handle_messages(self, websocket):
        """Process messages from an authenticated client."""
        try:
            async for message in websocket:
                # Stop if the session was revoked elsewhere
                if websocket not in self.user_sessions:
                    break
                await self.process_message(websocket, message)
        except websockets.exceptions.ConnectionClosed:
            pass
        finally:
            self.clients.discard(websocket)
            self.user_sessions.pop(websocket, None)

    async def process_message(self, websocket, message):
        """Process an individual message with the user's identity."""
        user_id = self.user_sessions[websocket]
        print(f"Message from {user_id}: {message}")
        # Acknowledge receipt
        await websocket.send(json.dumps({"ack": "received"}))
```
Session-Based Authentication
For applications that already use cookie-based sessions, the server can validate the session cookie that the browser sends with the handshake:
```python
# Session-based authentication with cookies (aiohttp)
from aiohttp import WSMsgType, web


class SessionWebSocketHandler:
    def __init__(self, session_store):
        # session_store: any object with an async get(session_id) -> dict or None
        self.sessions = session_store

    async def handle(self, request):
        # Validate the session before upgrading, so unauthenticated clients
        # never get a WebSocket. Cookie auth is exactly what CSWSH abuses,
        # so combine this with Origin validation (see below).
        session_id = request.cookies.get('session_id')
        if not session_id:
            raise web.HTTPUnauthorized(text='No session')

        session = await self.sessions.get(session_id)
        if not session or not session.get('authenticated'):
            raise web.HTTPUnauthorized(text='Not authenticated')

        ws = web.WebSocketResponse()
        await ws.prepare(request)

        # Store user in request context
        request['user_id'] = session.get('user_id')

        # Handle WebSocket messages
        try:
            async for msg in ws:
                if msg.type == WSMsgType.TEXT:
                    await self.process_message(request, msg.data)
                elif msg.type == WSMsgType.ERROR:
                    print(f'WebSocket error: {ws.exception()}')
        finally:
            print('WebSocket connection closed')
        return ws

    async def process_message(self, request, message):
        """Process message with session context."""
        user_id = request['user_id']
        # Process message...
```
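The handler above only assumes a session_store object with an async get(session_id) method that returns a dict or None. A minimal in-memory version of that interface might look like the sketch below; the class and its create/delete methods are hypothetical, sessions expire after a fixed lifetime, and a multi-process deployment would typically need a shared store instead so all workers see the same sessions.

```python
import asyncio
import secrets
import time
from typing import Callable, Dict, Optional, Tuple


class InMemorySessionStore:
    """Hypothetical session store with the async get() interface used above."""

    def __init__(self, ttl_seconds: float = 1800, clock: Callable[[], float] = time.monotonic):
        self.ttl = ttl_seconds
        self.clock = clock
        # session_id -> (expires_at, session data)
        self._sessions: Dict[str, Tuple[float, dict]] = {}

    async def create(self, user_id: str) -> str:
        """Create an authenticated session and return its unguessable id."""
        session_id = secrets.token_urlsafe(32)
        data = {"authenticated": True, "user_id": user_id}
        self._sessions[session_id] = (self.clock() + self.ttl, data)
        return session_id

    async def get(self, session_id: str) -> Optional[dict]:
        """Return session data, or None if the id is unknown or expired."""
        entry = self._sessions.get(session_id)
        if entry is None:
            return None
        expires_at, data = entry
        if self.clock() >= expires_at:
            del self._sessions[session_id]  # Expired sessions stop working
            return None
        return data

    async def delete(self, session_id: str) -> None:
        """Log out: new handshakes with this id are rejected (open connections are unaffected)."""
        self._sessions.pop(session_id, None)


async def demo() -> None:
    store = InMemorySessionStore(ttl_seconds=60)
    session_id = await store.create("user123")
    print(await store.get(session_id))  # {'authenticated': True, 'user_id': 'user123'}


if __name__ == "__main__":
    asyncio.run(demo())
```

Because delete() only affects future handshakes, revoking access for a connection that is already open still requires closing that WebSocket on the server.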
OAuth 2.0 for WebSocket
With OAuth 2.0, the client obtains an access token from the authorization server and presents it during the handshake:
```javascript
// Client: authenticate with an OAuth 2.0 access token.
// Runs in Node.js with the `ws` package: the browser WebSocket API cannot
// set custom headers, and a client_secret must never be shipped to browsers.
const WebSocket = require('ws');

class SecureWebSocketClient {
  constructor(url, oauthConfig) {
    this.url = url;
    this.oauthConfig = oauthConfig;
    this.socket = null;
  }

  async connect() {
    // Get a fresh access token
    const token = await this.getOAuthToken();

    // Send the token in the handshake's Authorization header
    this.socket = new WebSocket(this.url, [], {
      headers: {
        'Authorization': `Bearer ${token}`
      }
    });
    this.socket.onclose = (event) => this.handleConnectionClose(event);

    return new Promise((resolve, reject) => {
      this.socket.onopen = () => resolve(this.socket);
      this.socket.onerror = reject;
    });
  }

  async getOAuthToken() {
    // Exchange the refresh token for an access token
    const response = await fetch(this.oauthConfig.tokenUrl, {
      method: 'POST',
      headers: {
        'Content-Type': 'application/x-www-form-urlencoded'
      },
      body: new URLSearchParams({
        grant_type: 'refresh_token',
        refresh_token: this.oauthConfig.refreshToken,
        client_id: this.oauthConfig.clientId,
        client_secret: this.oauthConfig.clientSecret
      })
    });
    if (!response.ok) {
      throw new Error(`Token request failed: ${response.status}`);
    }
    const data = await response.json();
    return data.access_token;
  }

  // Reconnect with a new token when the server reports expiry
  async handleConnectionClose(event) {
    if (event.code === 4001) { // Application-defined code: token expired
      await this.connect();
    }
  }
}
```
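On the server, validating the access token is only half of the check; the token should also grant the scope your WebSocket endpoint needs. OAuth access tokens commonly carry their scopes as a single space-delimited scope string, the format RFC 6749 uses for the scope parameter. The sketch below assumes the token has already been verified and decoded into a claims dict; the scope names ws:connect and ws:write are hypothetical, and where scopes appear in your tokens depends on your authorization server.

```python
from typing import Iterable, Mapping


def has_scopes(claims: Mapping, required: Iterable[str]) -> bool:
    """Return True if the space-delimited 'scope' claim grants every required scope."""
    scope = claims.get("scope", "")
    if not isinstance(scope, str):
        return False
    granted = set(scope.split())
    return set(required) <= granted


# Claims as they might look after the access token has been verified
claims = {"sub": "user123", "scope": "openid ws:connect ws:read"}
print(has_scopes(claims, ["ws:connect"]))              # True
print(has_scopes(claims, ["ws:connect", "ws:write"]))  # False
```

Rejecting the handshake when the scope is missing keeps a token issued for some other API from opening a WebSocket.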
Authorization and Access Control
Channel-Based Authorization
Control which channels/topics users can access:
```python
# Channel authorization with permissions
from enum import Enum
from typing import Dict, Set


class Permission(Enum):
    READ = "read"
    WRITE = "write"
    ADMIN = "admin"


class ChannelAuthorization:
    """Manage channel permissions."""

    def __init__(self):
        # Channel -> permissions that may be exercised on it
        self.channel_permissions: Dict[str, Set[Permission]] = {}
        # User -> permissions granted to that user
        self.user_permissions: Dict[str, Set[Permission]] = {}

    def define_channel(self, channel: str, required_permissions: Set[Permission]):
        """Define a channel and the permissions it supports."""
        self.channel_permissions[channel] = required_permissions

    def grant_user_permission(self, user_id: str, permission: Permission):
        """Grant a permission to a user."""
        self.user_permissions.setdefault(user_id, set()).add(permission)

    def can_access(self, user_id: str, channel: str, required_permission: Permission) -> bool:
        """Check if user can access channel with required permission."""
        # Deny by default: unknown channels are not accessible
        channel_perms = self.channel_permissions.get(channel)
        if channel_perms is None:
            return False
        # The channel must support the requested operation...
        if required_permission not in channel_perms:
            return False
        # ...and the user must hold the permission
        return required_permission in self.user_permissions.get(user_id, set())

    def authorize_message(self, user_id: str, channel: str, message: str) -> bool:
        """Authorize a message send attempt."""
        return self.can_access(user_id, channel, Permission.WRITE)


# Usage example
auth = ChannelAuthorization()

# Define channels with permissions. Names such as "private:room-{room_id}"
# are stored literally here; real templated channels need pattern matching.
auth.define_channel("public:chat", {Permission.READ, Permission.WRITE})
auth.define_channel("private:room-{room_id}", {Permission.READ, Permission.WRITE})
auth.define_channel("admin:logs", {Permission.READ, Permission.ADMIN})
auth.define_channel("user:{user_id}:messages", {Permission.READ, Permission.WRITE})

# Grant permissions to users
auth.grant_user_permission("user123", Permission.READ)
auth.grant_user_permission("user123", Permission.WRITE)
auth.grant_user_permission("admin456", Permission.ADMIN)

# Check access
print(auth.can_access("user123", "public:chat", Permission.READ))    # True
print(auth.can_access("user123", "admin:logs", Permission.ADMIN))    # False
print(auth.can_access("admin456", "admin:logs", Permission.ADMIN))   # True
```
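Channel names such as user:{user_id}:messages are templates: a concrete request arrives as, say, user:alice:messages, and a plain dictionary lookup will not match it. For per-user channels the essential rule is ownership, meaning the user segment must equal the authenticated user. A minimal sketch using Python's re module; the channel format and the function name are illustrative assumptions.

```python
import re

# Matches channels of the form "user:<id>:messages"
USER_CHANNEL = re.compile(r"user:(?P<user_id>[A-Za-z0-9_-]{1,64}):messages")


def can_subscribe_user_channel(user_id: str, channel: str) -> bool:
    """Allow a per-user channel only for the user it belongs to."""
    match = USER_CHANNEL.fullmatch(channel)
    if match is None:
        # Not a well-formed per-user channel: deny here and let the
        # general channel rules decide about other channel types
        return False
    return match.group("user_id") == user_id


print(can_subscribe_user_channel("user123", "user:user123:messages"))   # True
print(can_subscribe_user_channel("user123", "user:admin456:messages"))  # False
```

Using fullmatch rather than a prefix test means a name like user:user123:messages:extra is rejected instead of being treated as user123's channel.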
Role-Based Access Control (RBAC)
Implement RBAC for WebSocket:
```python
# RBAC implementation for WebSocket
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class Role:
    name: str
    permissions: List[str]


@dataclass
class User:
    user_id: str
    roles: List[str]
    metadata: Dict


class WebSocketRBAC:
    """Role-Based Access Control for WebSocket."""

    def __init__(self):
        self.roles: Dict[str, Role] = {}
        self.user_roles: Dict[str, List[str]] = {}
        self.channel_acl: Dict[str, Dict[str, List[str]]] = {}

    def define_role(self, role: Role):
        """Define a role with permissions."""
        self.roles[role.name] = role

    def assign_role(self, user_id: str, role_name: str):
        """Assign a role to a user."""
        if role_name not in self.roles:
            raise ValueError(f"Unknown role: {role_name}")
        if user_id not in self.user_roles:
            self.user_roles[user_id] = []
        if role_name not in self.user_roles[user_id]:
            self.user_roles[user_id].append(role_name)

    def set_channel_acl(self, channel: str, allow: Optional[List[str]] = None,
                        deny: Optional[List[str]] = None):
        """Set channel access control list."""
        self.channel_acl[channel] = {
            'allow': allow or [],
            'deny': deny or []
        }

    def check_permission(self, user_id: str, channel: str, action: str) -> bool:
        """Check if user can perform action on channel."""
        # Get user roles
        user_roles = self.user_roles.get(user_id, [])

        # Get user permissions from roles
        user_permissions = set()
        for role_name in user_roles:
            role = self.roles.get(role_name)
            if role:
                user_permissions.update(role.permissions)

        # Check channel ACL (channels without an ACL are open to any role)
        if channel in self.channel_acl:
            acl = self.channel_acl[channel]
            # Check deny first
            for role_name in user_roles:
                if role_name in acl['deny']:
                    return False
            # Check allow
            if acl['allow']:
                allowed = any(role in acl['allow'] for role in user_roles)
                if not allowed:
                    return False

        # Check if user has action permission
        return action in user_permissions


# Example setup
rbac = WebSocketRBAC()

# Define roles
rbac.define_role(Role("guest", ["read"]))
rbac.define_role(Role("user", ["read", "write"]))
rbac.define_role(Role("moderator", ["read", "write", "moderate"]))
rbac.define_role(Role("admin", ["read", "write", "moderate", "admin"]))

# Assign roles
rbac.assign_role("user123", "user")
rbac.assign_role("mod456", "moderator")
rbac.assign_role("admin789", "admin")

# Set channel permissions
rbac.set_channel_acl("chat:general", allow=["user", "moderator", "admin"])
rbac.set_channel_acl("chat:mod-only", allow=["moderator", "admin"])
rbac.set_channel_acl("admin:stats", allow=["admin"])

# Check permissions
print(rbac.check_permission("user123", "chat:general", "write"))  # True
print(rbac.check_permission("user123", "admin:stats", "read"))    # False
print(rbac.check_permission("admin789", "admin:stats", "admin"))  # True
```
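Because an open connection can send any message type at any moment, the permission check has to run on every incoming message, not only when the client subscribes. One way to wire that in is a table mapping each message type to the action it requires, consulted before dispatch, with unknown types denied. In the sketch below, REQUIRED_ACTION and authorize are hypothetical names, and check stands in for a function with the same shape as check_permission above.

```python
from typing import Callable, Dict

# Hypothetical mapping from message type to the RBAC action it requires
REQUIRED_ACTION: Dict[str, str] = {
    "subscribe": "read",
    "publish": "write",
    "delete_message": "moderate",
    "kick_user": "admin",
}


def authorize(check: Callable[[str, str, str], bool], user_id: str, message: dict) -> bool:
    """Deny unknown message types; otherwise defer to the RBAC check."""
    msg_type = message.get("type")
    if not isinstance(msg_type, str) or msg_type not in REQUIRED_ACTION:
        return False
    channel = message.get("channel")
    if not isinstance(channel, str):
        return False
    return check(user_id, channel, REQUIRED_ACTION[msg_type])


# Stand-in for rbac.check_permission: "user123" may only read and write
def fake_check(user_id: str, channel: str, action: str) -> bool:
    return user_id == "user123" and action in {"read", "write"}


print(authorize(fake_check, "user123", {"type": "publish", "channel": "chat:general"}))     # True
print(authorize(fake_check, "user123", {"type": "kick_user", "channel": "chat:general"}))   # False
print(authorize(fake_check, "user123", {"type": "drop_table", "channel": "chat:general"}))  # False
```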
Protecting Against Vulnerabilities
Cross-Site WebSocket Hijacking (CSWSH)
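This vulnerability occurs when a malicious site makes a user’s browser open a WebSocket connection to a target site. Because the browser attaches the user’s cookies to the handshake, the connection runs with the user’s privileges unless the server checks where the request came from: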
```python
# Server-side CSWSH protection (aiohttp)
import asyncio
import hmac
import re
import secrets
from typing import Dict, List, Optional

from aiohttp import web


class CSWSHProtection:
    """Protect against Cross-Site WebSocket Hijacking."""

    def __init__(self, allowed_origins: List[str]):
        # Exact origins ("https://example.com") or one-label wildcards
        # ("https://*.example.com")
        self.allowed_origins = set(allowed_origins)
        self.csrf_tokens: Dict[str, str] = {}

    def generate_csrf_token(self, session_id: str) -> str:
        """Generate and store a CSRF token for a session.

        Call this when serving the authenticated page over normal HTTP and
        embed the token in the page, so only same-origin scripts can read it.
        """
        token = secrets.token_urlsafe(32)
        self.csrf_tokens[session_id] = token
        return token

    def validate_origin(self, origin: Optional[str]) -> bool:
        """Validate Origin header against allowed list."""
        if not origin:
            # Browsers always send Origin on WebSocket handshakes. Rejecting a
            # missing header also blocks non-browser clients; relax this only
            # on endpoints that must serve them.
            return False
        # Check exact match
        if origin in self.allowed_origins:
            return True
        # Check wildcard patterns: '*' matches exactly one DNS label,
        # and the whole origin must match
        for pattern in self.allowed_origins:
            if '*' in pattern:
                regex = re.escape(pattern).replace(r'\*', r'[a-z0-9-]+')
                if re.fullmatch(regex, origin, re.IGNORECASE):
                    return True
        return False

    def validate_csrf(self, session_id: str, token: str) -> bool:
        """Validate CSRF token with a timing-safe comparison."""
        expected = self.csrf_tokens.get(session_id)
        if not expected or not token:
            return False
        return hmac.compare_digest(expected.encode(), token.encode())

    async def handle_websocket(self, request):
        """Apply CSWSH protection before and right after the upgrade."""
        # Validate origin before upgrading
        if not self.validate_origin(request.headers.get('Origin')):
            raise web.HTTPForbidden(text="Origin not allowed")

        session_id = request.cookies.get('session_id')
        if not session_id:
            raise web.HTTPUnauthorized(text="No session")

        ws = web.WebSocketResponse()
        await ws.prepare(request)

        # Require the CSRF token (issued earlier over HTTP) as the first message
        try:
            csrf_token = await ws.receive_str(timeout=10)
        except (TypeError, asyncio.TimeoutError):  # Non-text message or timeout
            csrf_token = ""
        if not self.validate_csrf(session_id, csrf_token):
            await ws.close(code=1008, message=b'Invalid CSRF token')
            return ws

        # ...continue with the application's message loop on ws
        return ws
```
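A common mistake is checking the Origin header with a prefix or substring test, which an origin like https://example.com.evil.net passes. An alternative to regex matching is to parse the header and compare the normalized (scheme, host, port) triple exactly. A sketch with urllib.parse; the allowlist contents and function names are illustrative assumptions, and default ports are filled in so https://app.example.com and https://app.example.com:443 compare equal.

```python
from typing import Optional, Tuple
from urllib.parse import urlsplit

DEFAULT_PORTS = {"http": 80, "https": 443}


def normalize_origin(origin: str) -> Optional[Tuple[str, str, int]]:
    """Return (scheme, host, port) for a serialized origin, or None if malformed."""
    try:
        parts = urlsplit(origin)
        port = parts.port  # Raises ValueError for an invalid port
    except ValueError:
        return None
    if parts.scheme not in DEFAULT_PORTS or not parts.hostname:
        return None
    # A serialized origin has no path, query, fragment, or credentials
    if parts.path or parts.query or parts.fragment or parts.username:
        return None
    return parts.scheme, parts.hostname, port or DEFAULT_PORTS[parts.scheme]


ALLOWED = {normalize_origin(o) for o in ["https://example.com", "https://app.example.com"]}


def is_allowed_origin(origin: Optional[str]) -> bool:
    if not origin:
        return False
    normalized = normalize_origin(origin)
    return normalized is not None and normalized in ALLOWED


print(is_allowed_origin("https://app.example.com"))       # True
print(is_allowed_origin("https://APP.example.com:443"))   # True
print(is_allowed_origin("https://example.com.evil.net"))  # False
print(is_allowed_origin("http://example.com"))            # False
```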
Input Validation and Sanitization
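Prevent injection attacks by treating every message as untrusted: parse it, check it against an expected schema, and encode it for wherever it is used: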
```python
# Input validation for WebSocket messages
import html
import json
import re
from typing import Any, Dict, Tuple, Union


class MessageValidator:
    """Validate and sanitize WebSocket messages."""

    # Message schemas (a small subset of JSON Schema keywords)
    SCHEMAS = {
        "chat_message": {
            "type": "object",
            "required": ["type", "content"],
            "properties": {
                "type": {"type": "string", "enum": ["chat_message"]},
                "content": {"type": "string", "maxLength": 5000},
                "room_id": {"type": "string", "pattern": "^[a-zA-Z0-9_-]{1,50}$"}
            }
        },
        "action": {
            "type": "object",
            "required": ["type", "action"],
            "properties": {
                "type": {"type": "string", "enum": ["action"]},
                "action": {"type": "string", "enum": ["join", "leave", "typing"]},
                "target": {"type": "string", "maxLength": 100}
            }
        }
    }

    @staticmethod
    def validate_message(message: Union[str, Dict[str, Any]]) -> Tuple[bool, Dict[str, Any], str]:
        """
        Validate incoming message against schemas.
        Returns: (is_valid, parsed_data, error_message)
        """
        # Parse if string
        if isinstance(message, str):
            try:
                data = json.loads(message)
            except json.JSONDecodeError as e:
                return False, {}, f"Invalid JSON: {e}"
        else:
            data = message

        # The top-level value must be a JSON object
        if not isinstance(data, dict):
            return False, {}, "Message must be a JSON object"

        # Check message type
        msg_type = data.get('type')
        if not isinstance(msg_type, str) or not msg_type:
            return False, {}, "Missing 'type' field"

        # Get schema
        schema = MessageValidator.SCHEMAS.get(msg_type)
        if not schema:
            return False, {}, f"Unknown message type: {msg_type}"

        # Validate required fields
        for field in schema.get('required', []):
            if field not in data:
                return False, {}, f"Missing required field: {field}"

        # Validate field types and values
        for field, spec in schema.get('properties', {}).items():
            if field not in data:
                continue
            value = data[field]

            # Type check
            if spec.get('type') == 'string' and not isinstance(value, str):
                return False, {}, f"Field '{field}' must be string"

            # String validations
            if isinstance(value, str):
                max_len = spec.get('maxLength')
                if max_len and len(value) > max_len:
                    return False, {}, f"Field '{field}' exceeds max length {max_len}"
                pattern = spec.get('pattern')
                # fullmatch: with re.match, '$' would also accept a trailing newline
                if pattern and not re.fullmatch(pattern, value):
                    return False, {}, f"Field '{field}' has invalid format"

            # Enum check
            enum = spec.get('enum')
            if enum and value not in enum:
                return False, {}, f"Field '{field}' has invalid value"

        return True, data, ""

    @staticmethod
    def sanitize_html(content: str) -> str:
        """Escape HTML special characters before the content is rendered."""
        # Escaping is more reliable than stripping tags with a regex
        return html.escape(content)

    # SQL: do not try to strip "dangerous" substrings from input. Pass values
    # to the database as parameters (parameterized queries) instead.


# Usage in WebSocket handler
async def handle_message(websocket, message):
    # Validate message
    is_valid, data, error = MessageValidator.validate_message(message)
    if not is_valid:
        await websocket.send(json.dumps({
            "type": "error",
            "message": error
        }))
        return

    # Process validated message
    if data['type'] == 'chat_message':
        # Escape content before it is broadcast or rendered
        content = MessageValidator.sanitize_html(data['content'])
        # Now safe to process...
```
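Stripping "dangerous" substrings from input does not reliably prevent SQL injection; the dependable defense is to pass message values to the database as bound parameters, so they are never parsed as SQL. A sketch using Python's built-in sqlite3 module with an in-memory database; the messages table and the save_chat_message helper are hypothetical, and other database drivers use their own placeholder syntax.

```python
import sqlite3


def save_chat_message(conn: sqlite3.Connection, room_id: str, user_id: str, content: str) -> None:
    """Insert a chat message; '?' placeholders keep values out of the SQL text."""
    conn.execute(
        "INSERT INTO messages (room_id, user_id, content) VALUES (?, ?, ?)",
        (room_id, user_id, content),
    )
    conn.commit()


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE messages (room_id TEXT, user_id TEXT, content TEXT)")

# A classic injection payload is stored as ordinary text
save_chat_message(conn, "lobby", "user123", "'); DROP TABLE messages; --")
print(conn.execute("SELECT content FROM messages").fetchall())
# [("'); DROP TABLE messages; --",)]
```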
Rate Limiting and DoS Protection
```python
# Rate limiting for WebSocket connections and messages
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class RateLimitConfig:
    max_connections: int = 10000          # Server-wide
    max_connections_per_ip: int = 100
    max_messages_per_minute: int = 60
    max_messages_per_hour: int = 1000
    max_message_size: int = 1024 * 1024   # 1 MB (characters, for str messages)


class RateLimiter:
    """Sliding-window rate limiting for WebSocket clients, keyed by IP."""

    def __init__(self, config: RateLimitConfig, trust_proxy: bool = False):
        self.config = config
        self.trust_proxy = trust_proxy
        self.connections: Dict[str, int] = defaultdict(int)              # IP -> open connections
        self.message_times: Dict[str, List[float]] = defaultdict(list)  # IP -> timestamps (last hour)

    def get_client_ip(self, request) -> str:
        """Get the client IP from an aiohttp request."""
        if self.trust_proxy:
            # X-Forwarded-For is client-controlled; only the right-most entry,
            # appended by your own proxy, can be trusted
            forwarded = request.headers.get('X-Forwarded-For')
            if forwarded:
                return forwarded.split(',')[-1].strip()
        return request.remote

    def check_connection(self, ip: str) -> Tuple[bool, str]:
        """Check whether a new connection is allowed; record it if so."""
        if sum(self.connections.values()) >= self.config.max_connections:
            return False, "Server at capacity"
        if self.connections[ip] >= self.config.max_connections_per_ip:
            return False, "Too many connections from your IP"
        self.connections[ip] += 1
        return True, ""

    def release_connection(self, ip: str) -> None:
        """Call when a connection closes."""
        self.connections[ip] -= 1
        if self.connections[ip] <= 0:
            del self.connections[ip]

    def check_message(self, ip: str, message: str) -> Tuple[bool, str]:
        """Check whether a message is allowed; record it if so."""
        now = time.monotonic()
        if len(message) > self.config.max_message_size:
            return False, "Message too large"

        # Keep one hour of history so both windows can be checked
        times = [t for t in self.message_times[ip] if now - t < 3600]
        self.message_times[ip] = times

        if sum(1 for t in times if now - t < 60) >= self.config.max_messages_per_minute:
            return False, "Rate limit exceeded (per minute)"
        if len(times) >= self.config.max_messages_per_hour:
            return False, "Rate limit exceeded (per hour)"

        times.append(now)
        return True, ""

    def cleanup(self) -> None:
        """Drop IPs with no messages in the last hour (call periodically)."""
        now = time.monotonic()
        for ip in list(self.message_times):
            self.message_times[ip] = [t for t in self.message_times[ip] if now - t < 3600]
            if not self.message_times[ip]:
                del self.message_times[ip]
```
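Sliding windows store one timestamp per message. A token bucket is a common alternative (a different technique from the class above) that keeps just two numbers per client, enforces an average rate, and still tolerates short bursts. In the sketch below, TokenBucket is a hypothetical name, and the clock is injectable so its behavior can be tested deterministically.

```python
import time
from typing import Callable


class TokenBucket:
    """Allow `rate` messages per second on average, with bursts up to `capacity`."""

    def __init__(self, rate: float, capacity: float, clock: Callable[[], float] = time.monotonic):
        self.rate = rate
        self.capacity = capacity
        self.clock = clock
        self.tokens = capacity  # Start full
        self.updated = clock()

    def allow(self) -> bool:
        """Consume one token if available; return False when rate limited."""
        now = self.clock()
        # Refill in proportion to elapsed time, never above capacity
        self.tokens = min(self.capacity, self.tokens + (now - self.updated) * self.rate)
        self.updated = now
        if self.tokens >= 1:
            self.tokens -= 1
            return True
        return False


# 60 messages per minute on average, bursts of up to 10
bucket = TokenBucket(rate=1.0, capacity=10)
print(bucket.allow())  # True
```

One bucket per connection (or per user) keeps memory constant regardless of message volume, at the cost of not reporting exact per-hour counts.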
TLS and Encryption
Configuring WSS (WebSocket Secure)
Always use TLS-encrypted WebSocket connections:
```python
# Python: configure a secure WebSocket server with TLS
import asyncio
import ssl

import websockets

# Create the server-side SSL context
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)

# Load certificate chain and private key
ssl_context.load_cert_chain(
    certfile="/path/to/certificate.crt",
    keyfile="/path/to/private.key"
)

# Require TLS 1.3. TLS 1.3 cipher suites cannot be restricted with
# set_ciphers(); if you must also accept TLS 1.2, lower the minimum to
# TLSv1_2 and restrict the TLS 1.2 ciphers, for example:
# ssl_context.set_ciphers('ECDHE+AESGCM:ECDHE+CHACHA20:DHE+AESGCM:DHE+CHACHA20:!aNULL:!MD5:!DSS')
ssl_context.minimum_version = ssl.TLSVersion.TLSv1_3


async def handler(websocket):
    """Handle a secure WebSocket connection."""
    # Every connection accepted by this server arrived over TLS (wss://)
    async for message in websocket:
        await websocket.send(f"Echo: {message}")


# Start server with TLS
async def main():
    async with websockets.serve(
        handler,
        host="0.0.0.0",
        port=8765,
        ssl=ssl_context
    ):
        await asyncio.Future()  # Run forever


if __name__ == "__main__":
    asyncio.run(main())
```
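Server-side TLS is only half of WSS: the client must also verify the server's certificate and hostname, or an attacker can still intercept the connection with a certificate of their own. In Python, ssl.create_default_context() enables both checks. The sketch below builds such a client context and also requires TLS 1.3; the helper name and the optional cafile (for a private CA) are assumptions, and the resulting context would be passed to your WebSocket client library's TLS option.

```python
import ssl
from typing import Optional


def make_client_ssl_context(cafile: Optional[str] = None) -> ssl.SSLContext:
    """Client-side TLS context that verifies the server certificate and hostname."""
    # create_default_context() sets verify_mode=CERT_REQUIRED and check_hostname=True
    context = ssl.create_default_context(cafile=cafile)
    context.minimum_version = ssl.TLSVersion.TLSv1_3
    # Never set check_hostname = False or verify_mode = CERT_NONE in production:
    # that would accept any certificate, including an attacker's
    return context


context = make_client_ssl_context()
print(context.verify_mode == ssl.CERT_REQUIRED, context.check_hostname)  # True True
```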
Client-Side Certificate Authentication
For mutual TLS (mTLS), the server also requires the client to present a certificate signed by a trusted CA:
```python
# Server: require client certificates (mutual TLS)
import ssl

ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
ssl_context.load_cert_chain("server.crt", "server.key")

# Require a client certificate signed by this CA; the TLS handshake
# fails for clients that do not present one
ssl_context.verify_mode = ssl.CERT_REQUIRED
ssl_context.load_verify_locations("ca.crt")  # CA that signed client certs

ALLOWED_CLIENTS = {"client-app-1", "client-app-2"}


# Check client certificate in handler
async def handler(websocket):
    # Certificate presented by the client (dict in getpeercert() format)
    cert = websocket.transport.get_extra_info('peercert')
    if not cert:
        await websocket.close(code=4001, reason="Certificate required")
        return

    # Extract the subject's common name
    subject = dict(x[0] for x in cert['subject'])
    common_name = subject.get('commonName')

    # Check against allowed clients
    if common_name not in ALLOWED_CLIENTS:
        await websocket.close(code=4003, reason="Unauthorized client")
        return

    # Process authenticated connection...
```
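The peercert dict stores the subject as a tuple of relative distinguished names, each itself a tuple of (field, value) pairs, and many certificates carry the client identity in the subjectAltName extension rather than (or as well as) the common name. The sketch below collects both, iterating every attribute instead of only the first one in each name. The helper name is hypothetical, and the sample dict is hand-written in the format Python's ssl module returns.

```python
from typing import Set


def peer_identities(cert: dict) -> Set[str]:
    """Collect DNS/URI subjectAltName entries and commonName values from a getpeercert() dict."""
    names = {
        value
        for kind, value in cert.get("subjectAltName", ())
        if kind in ("DNS", "URI")
    }
    # 'subject' is a tuple of RDNs; each RDN is a tuple of (field, value) pairs
    for rdn in cert.get("subject", ()):
        for field, value in rdn:
            if field == "commonName":
                names.add(value)
    return names


sample_cert = {
    "subject": ((("organizationName", "Example Corp"),), (("commonName", "client-app-1"),)),
    "subjectAltName": (("DNS", "client-app-1.internal.example.com"),),
}
ALLOWED_CLIENTS = {"client-app-1", "client-app-2"}
print(bool(peer_identities(sample_cert) & ALLOWED_CLIENTS))  # True
```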
Best Practices Summary
Security Checklist
Use this checklist when implementing WebSocket security:
- Use WSS (WebSocket Secure) with TLS 1.3
- Validate Origin header on server
- Implement authentication at connection time
- Use secure token validation (JWT with expiration)
- Implement authorization per channel/topic
- Validate and sanitize all incoming messages
- Implement rate limiting (connection and message)
- Set maximum message size limits
- Implement connection timeouts
- Log security events for monitoring
- Use mTLS for high-security requirements
- Implement heartbeat/ping-pong for connection health
- Handle connection close gracefully
- Use secure random for tokens/secrets
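For the item on logging security events, structured log lines are easier to search and alert on than free-form text. A sketch of a helper that emits one JSON object per event through Python's standard logging module; the logger name, event names, and fields are hypothetical, and tokens or message bodies should never be included.

```python
import json
import logging
import time

security_log = logging.getLogger("websocket.security")


def log_security_event(event: str, **fields) -> str:
    """Log one security event as a JSON line and return the line."""
    record = {"ts": round(time.time(), 3), "event": event, **fields}
    line = json.dumps(record, sort_keys=True)
    security_log.warning(line)
    return line


line = log_security_event("auth_failed", user_id=None, ip="203.0.113.7", reason="token expired")
```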
Implementation Example
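Here’s a complete WebSocket server that combines authentication, origin checks, message size limits, and rate limiting. As written it listens on plain ws://; pass an ssl context as in the TLS section, or run it behind a TLS-terminating proxy, to serve wss://: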
```python
# Complete secure WebSocket server example (websockets library)
import asyncio
import json
import logging
import os
import time
from typing import Dict, List, Tuple

import jwt  # PyJWT
import websockets

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load the signing key from the environment (JWT_SECRET_KEY is an example
# name). A key generated at startup would invalidate every token on restart
# and could not be shared with the service that issues tokens.
SECRET_KEY = os.environ["JWT_SECRET_KEY"]
ALLOWED_ORIGINS = ["https://example.com", "https://app.example.com"]
MAX_MESSAGE_SIZE = 1024 * 100  # 100 KB
MAX_CONNECTIONS = 10000
RATE_LIMIT_MESSAGES = 60  # per minute


class SecureWebSocketServer:
    def __init__(self):
        self.clients: Dict = {}  # websocket -> user_id
        self.rate_limits: Dict[str, List[float]] = {}  # user_id -> message timestamps

    def authenticate(self, token) -> Tuple[bool, str]:
        """Validate a JWT. Returns (True, user_id) or (False, error message)."""
        if not isinstance(token, str):
            return False, "Missing token"
        try:
            # Verifies the signature and 'exp'; rejects tokens without 'exp' or 'sub'
            payload = jwt.decode(
                token, SECRET_KEY, algorithms=["HS256"],
                options={"require": ["exp", "sub"]},
            )
        except jwt.ExpiredSignatureError:
            return False, "Token expired"
        except jwt.InvalidTokenError as e:
            logger.warning(f"Authentication failed: {e}")
            return False, "Authentication failed"
        return True, str(payload["sub"])

    def rate_limit_check(self, user_id: str) -> bool:
        """Allow at most RATE_LIMIT_MESSAGES per user in any 60-second window."""
        now = time.monotonic()
        recent = [t for t in self.rate_limits.get(user_id, []) if now - t < 60]
        allowed = len(recent) < RATE_LIMIT_MESSAGES
        if allowed:
            recent.append(now)
        self.rate_limits[user_id] = recent
        return allowed

    async def handle(self, websocket):
        """Main handler. Origin and message size are enforced by serve() below."""
        if len(self.clients) >= MAX_CONNECTIONS:
            await websocket.close(1008, "Server at capacity")
            return

        # Authenticate: the first message must be {"type": "auth", "token": "..."}
        try:
            auth_msg = await asyncio.wait_for(websocket.recv(), timeout=10)
            auth_data = json.loads(auth_msg)
        except asyncio.TimeoutError:
            await websocket.close(1008, "Authentication timeout")
            return
        except ValueError:  # Invalid JSON or invalid UTF-8
            await websocket.close(1008, "Invalid message format")
            return

        if not isinstance(auth_data, dict) or auth_data.get('type') != 'auth':
            await websocket.close(1008, "First message must be auth")
            return

        valid, result = self.authenticate(auth_data.get('token'))
        if not valid:
            await websocket.close(1008, result)
            return

        user_id = result
        self.clients[websocket] = user_id
        logger.info(f"User {user_id} authenticated")

        # Handle messages
        try:
            async for message in websocket:
                if not self.rate_limit_check(user_id):
                    await websocket.send(json.dumps({
                        'type': 'error',
                        'message': 'Rate limit exceeded'
                    }))
                    continue
                await self.process_message(user_id, message, websocket)
        except websockets.exceptions.ConnectionClosed:
            logger.info(f"Connection closed for user {user_id}")
        finally:
            # Cleanup
            self.clients.pop(websocket, None)

    async def process_message(self, user_id: str, message, websocket):
        """Process an authenticated message."""
        try:
            data = json.loads(message)
        except ValueError:  # Invalid JSON or invalid UTF-8
            await websocket.send(json.dumps({'type': 'error', 'message': 'Invalid JSON'}))
            return

        msg_type = data.get('type') if isinstance(data, dict) else None
        if msg_type == 'ping':
            await websocket.send(json.dumps({'type': 'pong'}))
        elif msg_type == 'message':
            content = data.get('content', '')
            if not isinstance(content, str):
                await websocket.send(json.dumps({'type': 'error', 'message': 'content must be a string'}))
                return
            content = content[:5000]  # Cap content length
            # Process chat/feature message here...
            await websocket.send(json.dumps({
                'type': 'ack',
                'message': 'Message received'
            }))
        else:
            await websocket.send(json.dumps({
                'type': 'error',
                'message': f'Unknown message type: {msg_type}'
            }))

    async def start(self, host="0.0.0.0", port=8765):
        """Start the server."""
        async with websockets.serve(
            self.handle, host, port,
            origins=ALLOWED_ORIGINS,    # Reject other Origins, including a missing one (CSWSH)
            max_size=MAX_MESSAGE_SIZE,  # Close connections that send larger messages
        ):
            logger.info(f"WebSocket server started on {host}:{port}")
            await asyncio.Future()  # Run forever


# Run server
if __name__ == "__main__":
    server = SecureWebSocketServer()
    asyncio.run(server.start())
```
Conclusion
WebSocket security requires a defense-in-depth approach. Authentication, authorization, input validation, rate limiting, and encryption work together to protect your real-time applications.
Key takeaways:
- Always use WSS with TLS 1.3 for encrypted connections
- Authenticate at connection time with validated tokens
- Implement granular authorization per channel
- Validate and sanitize all incoming messages
- Rate limit to prevent DoS attacks
- Log security events for monitoring and incident response
By following the patterns in this guide, you can build secure WebSocket applications that protect your users and infrastructure.