API Security Beyond JWT: OAuth2, Rate Limiting, and Attack Prevention
JWT is only one piece of API security. This guide covers the other layers: OAuth2 login, rate limiting, CORS, CSRF protection, and defenses against injection attacks.
OAuth2 Implementation
Authorization Code Flow
import json
import logging
import os

import httpx
from authlib.integrations.starlette_client import OAuth
from fastapi import FastAPI, HTTPException, Depends
from fastapi import Cookie, Request
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from starlette.responses import RedirectResponse
app = FastAPI()
oauth = OAuth()

# Configure OAuth2 providers.
# Client credentials are read from the environment so secrets are not committed
# to source control; the defaults are the original placeholders.
oauth.register(
    name='google',
    client_id=os.environ.get('GOOGLE_CLIENT_ID', 'YOUR_CLIENT_ID'),
    client_secret=os.environ.get('GOOGLE_CLIENT_SECRET', 'YOUR_CLIENT_SECRET'),
    # OpenID Connect discovery document supplies the authorize/token endpoints.
    server_metadata_url='https://accounts.google.com/.well-known/openid-configuration',
    client_kwargs={'scope': 'openid email profile'},
)
oauth.register(
    name='github',
    client_id=os.environ.get('GITHUB_CLIENT_ID', 'YOUR_GITHUB_CLIENT_ID'),
    client_secret=os.environ.get('GITHUB_CLIENT_SECRET', 'YOUR_GITHUB_CLIENT_SECRET'),
    # Relative API paths such as 'user' are resolved against this base URL.
    api_base_url='https://api.github.com/',
    access_token_url='https://github.com/login/oauth/access_token',
    access_token_params=None,
    authorize_url='https://github.com/login/oauth/authorize',
    authorize_params=None,
    client_kwargs={'scope': 'user:email'},
)
@app.get('/login/{provider}')
async def login(request: Request, provider: str):
    """Start the OAuth2 authorization-code flow by redirecting to `provider`.

    `request` must be annotated as `Request`; without the annotation FastAPI
    would treat it as a required query parameter.

    Raises:
        HTTPException(404): if `provider` is not a registered OAuth client.

    NOTE(review): Authlib's Starlette client keeps the OAuth `state` in
    `request.session`, so SessionMiddleware must be installed on the app.
    """
    client = oauth.create_client(provider)
    if client is None:
        # create_client returns None for names that were never registered.
        raise HTTPException(status_code=404, detail="Unknown OAuth provider")
    redirect_uri = request.url_for('auth_callback', provider=provider)
    # authorize_redirect takes the request first (it stores state on it).
    return await client.authorize_redirect(request, str(redirect_uri))
@app.get('/auth/{provider}/callback')
async def auth_callback(request: Request, provider: str):
    """Handle the OAuth2 callback: exchange the code, load the user, start a session.

    The session JWT is delivered in an HttpOnly cookie rather than the redirect
    URL, because query strings end up in browser history, server logs and
    Referer headers.

    Raises:
        HTTPException(404): unknown provider.
        HTTPException(400): the code exchange or user lookup failed.

    NOTE(review): the cookie is scoped to this API's host; confirm the
    dashboard at example.com is served from the same site.
    """
    client = oauth.create_client(provider)
    if client is None:
        raise HTTPException(status_code=404, detail="Unknown OAuth provider")

    try:
        # Exchange authorization code for token
        token = await client.authorize_access_token(request)
        # OpenID Connect providers (Google): Authlib (>= 1.0) places the
        # verified ID-token claims in token['userinfo']. Plain OAuth2
        # providers (GitHub) need an API call to the provider's user endpoint.
        user_data = token.get('userinfo')
        if not user_data:
            resp = await client.get('user', token=token)
            resp.raise_for_status()
            user_data = resp.json()
        session_token = create_session_token(dict(user_data))
    except Exception as e:
        # Boundary handler: log the cause, but return a generic message to the client.
        logging.getLogger(__name__).warning(
            "OAuth callback for provider %s failed: %s", provider, e
        )
        raise HTTPException(status_code=400, detail="Authentication failed") from e

    response = RedirectResponse(url="https://example.com/dashboard")
    response.set_cookie(
        key="session_token",
        value=session_token,
        max_age=24 * 3600,  # matches the 24h `exp` set by create_session_token
        httponly=True,      # not readable from JavaScript
        secure=True,        # HTTPS only
        samesite="lax",     # still sent on top-level navigations
    )
    return response
def create_session_token(user_data: dict) -> str:
    """Create a signed session JWT after the OAuth2 provider verified the user.

    Args:
        user_data: user profile from the provider. GitHub's ``/user`` API
            identifies users by a numeric ``id``; OpenID Connect providers
            such as Google use ``sub``. ``email`` and ``name`` may be absent
            (e.g. a GitHub account with a private email).

    Returns:
        An HS256-signed JWT with ``sub``, ``email``, ``name``, ``iat``,
        ``exp`` (24 hours later) and a random ``jti``.

    Raises:
        KeyError: if ``user_data`` contains neither ``id`` nor ``sub``.

    NOTE(review): ``SECRET_KEY`` is expected to be a module-level secret
    defined elsewhere in the application.
    """
    import secrets
    from datetime import datetime, timedelta, timezone

    import jwt

    user_id = user_data.get('id')
    if user_id is None:
        user_id = user_data['sub']  # OpenID Connect identifier

    now = datetime.now(timezone.utc)
    payload = {
        # RFC 7519 defines `sub` as a string, and recent PyJWT versions reject
        # non-string `sub` on decode, so numeric ids are stringified.
        'sub': str(user_id),
        'email': user_data.get('email'),
        'name': user_data.get('name'),
        'iat': now,
        'exp': now + timedelta(hours=24),
        'jti': secrets.token_urlsafe(32),  # JWT ID for revocation
    }
    return jwt.encode(payload, SECRET_KEY, algorithm="HS256")
Refresh Token Rotation
from datetime import datetime, timedelta
import jwt
import secrets
from typing import Optional
class RefreshTokenManager:
    """Issue access/refresh token pairs and rotate refresh tokens on use.

    Each refresh token is persisted as a SHA-256 hash together with a
    "family" id. All tokens descended from one login share that id, so a
    replayed refresh token can be traced back to its family.

    NOTE(review): `db` is an external store; the methods used here
    (save_refresh_token, get_refresh_token, revoke_refresh_token,
    revoke_all_tokens, check_token_jti_used) are assumed to exist with the
    shown signatures, and stored rows to be indexable by key. Confirm this
    against the real implementation. Timestamps are naive UTC
    (datetime.utcnow) to match what is written to `db`.
    """

    def __init__(self, db):
        self.db = db

    def issue_token_pair(self, user_id: str, family_id: Optional[str] = None) -> dict:
        """Issue a 15-minute access token and a 7-day refresh token.

        Args:
            user_id: subject of both tokens.
            family_id: rotation family to attach the refresh token to; a new
                family is started when omitted (i.e. on a fresh login).

        Returns:
            dict with access_token, refresh_token, token_type and expires_in
            (seconds until the access token expires).
        """
        # Access token: short-lived (15 minutes)
        access_token = self._create_access_token(user_id)
        # Refresh token: longer-lived (7 days)
        refresh_token = self._create_refresh_token(user_id)
        # Store only the hash of the refresh token (for revocation)
        self.db.save_refresh_token(
            user_id=user_id,
            token_hash=hash_token(refresh_token),
            expires_at=datetime.utcnow() + timedelta(days=7),
            # Rotated tokens stay in their family; a new login starts one.
            family_id=family_id or secrets.token_urlsafe(16),
        )
        return {
            'access_token': access_token,
            'refresh_token': refresh_token,
            'token_type': 'Bearer',
            'expires_in': 900,  # 15 minutes
        }

    def refresh_token_pair(self, refresh_token: str) -> dict:
        """Exchange a valid refresh token for a new pair, revoking the old one.

        Raises:
            HTTPException(401): invalid, expired, unknown or reused token.
                Reuse also revokes every token of the user.
        """
        payload = self._verify_refresh_token(refresh_token)
        user_id = payload['sub']
        token_hash = hash_token(refresh_token)

        # Check for reuse BEFORE the stored-token lookup. A replayed token has
        # already been rotated away, so a lookup-first order would reject it as
        # merely "invalid" and never trigger the breach response.
        if self._detect_token_reuse(user_id, payload['jti']):
            # Invalidate all tokens for user (breach suspected)
            self.db.revoke_all_tokens(user_id)
            raise HTTPException(status_code=401, detail="Token reuse detected")

        stored_token = self.db.get_refresh_token(user_id, token_hash)
        if not stored_token:
            raise HTTPException(status_code=401, detail="Invalid refresh token")
        if stored_token['expires_at'] < datetime.utcnow():
            raise HTTPException(status_code=401, detail="Refresh token expired")

        try:
            family_id = stored_token['family_id']
        except (KeyError, IndexError):
            family_id = None

        # Revoke before issuing, which narrows the window in which a second
        # concurrent use of the same token could also obtain a fresh pair.
        self.db.revoke_refresh_token(user_id, token_hash)
        return self.issue_token_pair(user_id, family_id=family_id)

    def _verify_refresh_token(self, refresh_token: str) -> dict:
        """Decode and validate a refresh JWT and return its claims.

        Raises:
            HTTPException(401): bad signature, expired, missing claims, or not
                a token of type 'refresh' (e.g. an access token).
        """
        try:
            payload = jwt.decode(
                refresh_token,
                SECRET_KEY,
                algorithms=["HS256"],  # pin the algorithm; never trust the header
                options={"require": ["exp", "sub", "jti"]},
            )
        except jwt.ExpiredSignatureError as e:
            raise HTTPException(status_code=401, detail="Refresh token expired") from e
        except jwt.InvalidTokenError as e:
            raise HTTPException(status_code=401, detail="Invalid refresh token") from e
        if payload.get('type') != 'refresh':
            raise HTTPException(status_code=401, detail="Invalid refresh token")
        return payload

    def _create_access_token(self, user_id: str) -> str:
        """Create a short-lived (15 minute) access token."""
        now = datetime.utcnow()
        payload = {
            'sub': user_id,
            'type': 'access',
            'iat': now,
            'exp': now + timedelta(minutes=15),
            'jti': secrets.token_urlsafe(16),
        }
        return jwt.encode(payload, SECRET_KEY, algorithm="HS256")

    def _create_refresh_token(self, user_id: str) -> str:
        """Create a longer-lived (7 day) refresh token."""
        now = datetime.utcnow()
        payload = {
            'sub': user_id,
            'type': 'refresh',
            'iat': now,
            'exp': now + timedelta(days=7),
            'jti': secrets.token_urlsafe(16),
        }
        return jwt.encode(payload, SECRET_KEY, algorithm="HS256")

    def _detect_token_reuse(self, user_id: str, jti: str) -> bool:
        """Return True if `db` reports this refresh token's jti as already used."""
        return self.db.check_token_jti_used(user_id, jti)
def hash_token(token: str) -> str:
    """Return the hex-encoded SHA-256 digest of `token`.

    Refresh tokens are persisted as this digest rather than in plaintext.
    """
    from hashlib import sha256

    digest = sha256(token.encode())
    return digest.hexdigest()
Rate Limiting and Throttling
Token Bucket Algorithm
import time
from collections import defaultdict
from typing import Optional
class TokenBucketRateLimiter:
    """In-memory token-bucket rate limiter keyed by user id.

    Each user gets a bucket holding up to ``capacity`` tokens that refills
    continuously at ``refill_rate`` tokens per second. A request is allowed
    when enough tokens are available, and it consumes them.

    Buckets live in a per-process dict and are never evicted, so memory grows
    with the number of distinct user ids.
    """

    def __init__(self, capacity: int = 100, refill_rate: float = 10):
        """
        capacity: max tokens in bucket
        refill_rate: tokens added per second
        """
        self.capacity = capacity
        self.refill_rate = refill_rate
        # Elapsed time is measured with time.monotonic(): wall-clock
        # time.time() can jump (NTP, manual changes), which would drain
        # buckets (backwards step) or refill them instantly (forwards step).
        self.buckets: dict = defaultdict(lambda: {
            'tokens': capacity,
            'last_refill': time.monotonic(),
        })

    def is_allowed(self, user_id: str, tokens_required: int = 1) -> tuple[bool, dict]:
        """Check whether `user_id` may spend `tokens_required` tokens now.

        Returns:
            (True, {'remaining', 'reset_at'}) when allowed; reset_at is a Unix
            (wall-clock) timestamp at which the bucket will be full again.
            (False, {'retry_after', 'remaining': 0}) when denied; retry_after
            is whole seconds until enough tokens will have accumulated.
        """
        bucket = self.buckets[user_id]
        now = time.monotonic()

        # Refill for the time elapsed since the last call. The clamp at zero
        # guarantees a refill can never remove tokens.
        elapsed = max(0.0, now - bucket['last_refill'])
        bucket['tokens'] = min(self.capacity, bucket['tokens'] + elapsed * self.refill_rate)
        bucket['last_refill'] = now

        if bucket['tokens'] >= tokens_required:
            bucket['tokens'] -= tokens_required
            seconds_to_full = (self.capacity - bucket['tokens']) / self.refill_rate
            return True, {
                'remaining': int(bucket['tokens']),
                # Reported to clients as wall-clock time.
                'reset_at': time.time() + seconds_to_full,
            }

        # Calculate when the request would be allowed
        time_to_retry = (tokens_required - bucket['tokens']) / self.refill_rate
        return False, {
            'retry_after': int(time_to_retry) + 1,
            'remaining': 0,
        }
class AdaptiveRateLimiter:
    """Route rate-limit checks to a token-bucket limiter chosen by user tier.

    Tiers: 'free' (100 burst, 10/s), 'pro' (1000 burst, 100/s) and
    'enterprise' (10000 burst, 1000/s). Any unrecognised tier gets the
    'free' limits.
    """

    def __init__(self):
        tier_settings = {
            'free': (100, 10),
            'pro': (1000, 100),
            'enterprise': (10000, 1000),
        }
        self.limiters = {
            tier: TokenBucketRateLimiter(capacity=burst, refill_rate=rate)
            for tier, (burst, rate) in tier_settings.items()
        }

    def check_rate_limit(self, user_id: str, user_tier: str) -> tuple[bool, dict]:
        """Return `is_allowed(user_id)` from the limiter for `user_tier`."""
        try:
            chosen = self.limiters[user_tier]
        except KeyError:
            chosen = self.limiters['free']
        return chosen.is_allowed(user_id)
# FastAPI integration
from fastapi import Header, HTTPException
# One limiter shared by every request. A limiter created inside the dependency
# would be fresh (full buckets) on each call, so nothing would ever be limited.
_adaptive_limiter = AdaptiveRateLimiter()


async def rate_limit_check(
    x_user_id: str = Header(...),
    x_user_tier: str = Header(default="free"),
):
    """FastAPI dependency enforcing the per-tier rate limit.

    Reads the `X-User-Id` and `X-User-Tier` request headers.

    Returns:
        The limiter info dict ('remaining', 'reset_at') for allowed requests.

    Raises:
        HTTPException(429): limit exceeded; includes a Retry-After header.

    SECURITY: both headers are client-controlled, so a caller can claim the
    'enterprise' tier or rotate user ids. In production, derive both from the
    authenticated principal.
    """
    allowed, info = _adaptive_limiter.check_rate_limit(x_user_id, x_user_tier)
    if not allowed:
        raise HTTPException(
            status_code=429,
            detail="Rate limit exceeded",
            headers={"Retry-After": str(info['retry_after'])},
        )
    return info
@app.get("/api/resource")
async def get_resource(rate_info: dict = Depends(rate_limit_check)):
    """Protected endpoint with rate limiting.

    `rate_limit_check` is declared once, as a parameter dependency, so it runs
    once per request and its info dict is available here.
    """
    return {
        'data': 'sensitive data',
        'remaining_requests': rate_info['remaining'],
    }
Distributed Rate Limiting with Redis
import redis
import time
import json
class DistributedRateLimiter:
    """Sliding-window rate limiting shared across servers via a Redis sorted set.

    Each request is stored as a member scored by its timestamp under
    ``rate_limit:{user_id}``. Timestamps come from each server's own clock, so
    server clocks are assumed to be roughly synchronized.
    """

    def __init__(self, redis_host='localhost', redis_port=6379):
        self.redis = redis.Redis(host=redis_host, port=redis_port, decode_responses=True)

    def check_rate_limit(self, user_id: str, limit: int, window: int) -> tuple[bool, dict]:
        """
        Check rate limit using sliding window with Redis
        limit: max requests
        window: time window in seconds

        Returns (True, {'remaining', 'reset_at'}) or
        (False, {'retry_after', 'reset_at'}); reset_at is the Unix time at
        which the oldest request in the window expires.
        """
        import secrets

        key = f"rate_limit:{user_id}"
        now = time.time()
        window_start = now - window
        # A unique member: bare timestamps can collide (same float from two
        # requests/servers), silently merging two requests into one.
        member = f"{now}:{secrets.token_hex(8)}"

        # MULTI/EXEC pipeline: trim, add and count run atomically relative to
        # other servers. Separate "count, then add" round trips let concurrent
        # requests all see "under limit" and overshoot it.
        pipe = self.redis.pipeline()
        pipe.zremrangebyscore(key, 0, window_start)
        pipe.zadd(key, {member: now})
        pipe.zcard(key)
        pipe.zrange(key, 0, 0, withscores=True)
        pipe.expire(key, window + 1)
        _, _, count, oldest, _ = pipe.execute()

        oldest_ts = float(oldest[0][1]) if oldest else now
        reset_at = oldest_ts + window

        # `count` includes this request.
        if count <= limit:
            return True, {
                'remaining': limit - count,
                'reset_at': reset_at,
            }

        # Over the limit: withdraw this request so rejected attempts do not
        # consume quota.
        self.redis.zrem(key, member)
        return False, {
            'retry_after': int(reset_at - now) + 1,
            'reset_at': reset_at,
        }
# Distributed limiter shared by the endpoint below; connects to Redis on
# localhost:6379.
limiter = DistributedRateLimiter(redis_host='localhost', redis_port=6379)
@app.get("/api/endpoint")
async def protected_endpoint(user_id: str = Header(...)):
    """Return OK unless the caller exceeded 100 requests within the last hour.

    The quota is shared across servers through the Redis-backed `limiter`.
    NOTE(review): `user_id` comes straight from a client header, so a caller
    can change it to obtain a fresh quota; derive it from authentication in
    production.
    """
    allowed, info = limiter.check_rate_limit(
        user_id=user_id,
        limit=100,     # requests ...
        window=3600,   # ... per hour
    )
    if allowed:
        return {"status": "ok", "remaining": info['remaining']}
    raise HTTPException(
        status_code=429,
        detail="Rate limit exceeded",
        headers={"Retry-After": str(info['retry_after'])},
    )
CORS and CSRF Protection
Secure CORS Configuration
from fastapi.middleware.cors import CORSMiddleware
app.add_middleware(
    CORSMiddleware,
    # Only allow specific, trusted origins (not *)
    allow_origins=[
        "https://example.com",
        "https://www.example.com",
        "https://app.example.com",
        # The admin UI calls /api/admin/settings. Without this entry, browsers
        # refuse to let that origin read the endpoint's responses.
        "https://admin.example.com",
    ],
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
    allow_headers=[
        "Content-Type",
        "Authorization",
        "X-CSRF-Token",
        "X-Requested-With",
    ],
    expose_headers=["X-Total-Count"],
    max_age=3600,  # cache preflight results for 1 hour
)
# Stricter configuration for sensitive endpoints
@app.get("/api/admin/settings")
async def get_admin_settings(origin: Optional[str] = Header(None)):
    """Serve admin settings only to requests whose Origin is the admin UI.

    This is defense in depth, not access control. Browsers set the Origin
    header, but any non-browser client can send an arbitrary value. The
    endpoint must still be protected by real authentication/authorization.
    Requests without an Origin header are rejected with 403.
    """
    allowed_origins = {"https://admin.example.com"}
    if origin not in allowed_origins:
        raise HTTPException(status_code=403, detail="CORS policy violation")
    return {"settings": "admin only"}
CSRF Token Implementation
import secrets
import hashlib
from datetime import datetime, timedelta
class CSRFProtection:
    """Synchronizer-token CSRF protection backed by a session store.

    A random token is issued per session. Only its SHA-256 hash is stored
    (via `db`), and submitted tokens are checked against that hash with a
    constant-time comparison.

    NOTE(review): `db` must provide store_csrf_token(session_id, token_hash,
    created_at, expires_at) and get_csrf_token(session_id). The latter should
    return a mapping with 'token_hash' and 'expires_at' (naive UTC, as written
    here).
    """

    def __init__(self, db):
        self.db = db

    def generate_csrf_token(self, session_id: str) -> str:
        """Create a CSRF token for `session_id`, store its hash, and return it.

        The token expires 2 hours after issue.
        """
        token = secrets.token_urlsafe(32)
        issued_at = datetime.utcnow()
        # Store only the hash, associated with the session
        self.db.store_csrf_token(
            session_id=session_id,
            token_hash=hashlib.sha256(token.encode()).hexdigest(),
            created_at=issued_at,
            expires_at=issued_at + timedelta(hours=2),
        )
        return token

    def validate_csrf_token(self, session_id: Optional[str], token: Optional[str]) -> bool:
        """Return True only if `token` is the unexpired token stored for `session_id`.

        A missing session or token (FastAPI passes None for an absent cookie or
        header) fails closed instead of raising.
        """
        if not session_id or not token:
            return False

        stored_token = self.db.get_csrf_token(session_id)
        if not stored_token:
            return False
        if stored_token['expires_at'] < datetime.utcnow():
            return False

        token_hash = hashlib.sha256(token.encode()).hexdigest()
        # Constant-time comparison prevents timing attacks
        return secrets.compare_digest(stored_token['token_hash'], token_hash)
# FastAPI integration: a single CSRFProtection instance shared by the form routes.
# NOTE(review): `db` is expected to be defined earlier in the application.
csrf = CSRFProtection(db=db)
@app.post("/api/form/submit")
async def submit_form(
    # Read the same header name the CORS config allows cross-origin.
    csrf_token: Optional[str] = Header(None, alias="X-CSRF-Token"),
    session_id: Optional[str] = Cookie(None),
):
    """Accept a form submission only with a valid CSRF token for the session.

    Raises:
        HTTPException(403): missing session cookie, missing token, or token
            invalid/expired for this session.
    """
    # Fail closed on missing inputs rather than passing None to the validator.
    if not session_id or not csrf_token:
        raise HTTPException(status_code=403, detail="CSRF validation failed")
    if not csrf.validate_csrf_token(session_id, csrf_token):
        raise HTTPException(status_code=403, detail="CSRF validation failed")
    # Process form...
    return {"status": "success"}
@app.get("/api/form")
async def get_form(session_id: Optional[str] = Cookie(None)):
    """Return the form definition plus a CSRF token bound to the caller's session.

    Raises:
        HTTPException(401): no session cookie. A token bound to "no session"
            would be shared by every anonymous caller.
    """
    if not session_id:
        raise HTTPException(status_code=401, detail="Session required")
    csrf_token = csrf.generate_csrf_token(session_id)
    return {
        # Placeholder: fill in the real form fields. (A literal `...` is not
        # JSON-serializable and would make the response fail.)
        "form_fields": [],
        "csrf_token": csrf_token,
    }
Injection Attack Prevention
SQL Injection Prevention
from sqlalchemy import text
from sqlalchemy.orm import Session
class SafeQueryBuilder:
    """Build SQL queries safely using parameterized queries.

    NOTE(review): `User` is the ORM model defined elsewhere in the project.
    """

    @staticmethod
    def _escape_like(value: str, escape: str = "/") -> str:
        """Escape LIKE wildcards so `value` is matched literally.

        Parameter binding prevents SQL injection but not *wildcard* injection:
        a search for "%" or "_" would otherwise match every row. "/" is used
        as the escape character because, unlike backslash, it has no special
        meaning in string literals on any common SQL dialect.
        """
        return (
            value.replace(escape, escape + escape)
            .replace("%", escape + "%")
            .replace("_", escape + "_")
        )

    @staticmethod
    def search_users(db: Session, search_term: str):
        """Case-insensitive literal substring search on users' email."""
        # WRONG: Vulnerable to SQL injection
        # query = f"SELECT * FROM users WHERE email LIKE '%{search_term}%'"
        # RIGHT: Use parameterized queries (and escape LIKE wildcards)
        pattern = f"%{SafeQueryBuilder._escape_like(search_term)}%"
        query = db.query(User).filter(User.email.ilike(pattern, escape="/"))
        return query.all()

    @staticmethod
    def complex_search(db: Session, filters: dict):
        """Filter users by any of: exact 'email', substring 'name', 'created_after'."""
        query = db.query(User)
        # Build query safely with parameterized filters
        if 'email' in filters:
            query = query.filter(User.email == filters['email'])
        if 'name' in filters:
            pattern = f"%{SafeQueryBuilder._escape_like(filters['name'])}%"
            query = query.filter(User.name.ilike(pattern, escape="/"))
        if 'created_after' in filters:
            query = query.filter(User.created_at > filters['created_after'])
        return query.all()

    @staticmethod
    def raw_sql_with_params(db: Session, user_id: int):
        """Raw SQL with proper parameterization"""
        # Using text() and bound params prevents SQL injection
        result = db.execute(
            text("SELECT * FROM users WHERE id = :user_id"),
            {"user_id": user_id},
        )
        return result.fetchall()
# Input validation
from pydantic import BaseModel, validator
class SearchQuery(BaseModel):
    """Search request whose `term` is at most 100 characters and contains no NUL byte."""

    term: str

    @validator('term')
    def validate_search_term(cls, v):
        # Ordered (predicate, message) pairs; the first failing rule wins.
        rules = (
            (lambda s: len(s) > 100, 'Search term too long'),
            (lambda s: '\x00' in s, 'Invalid characters'),
        )
        for violated, message in rules:
            if violated(v):
                raise ValueError(message)
        return v
Command Injection Prevention
import subprocess
import shlex
from pathlib import Path
class SafeProcessRunner:
    """Run system commands and read uploaded files without shell or path injection."""

    @staticmethod
    def run_backup(database_name: str):
        """Dump `database_name` with mysqldump and return the dump as bytes.

        Returns None (after printing a message) if mysqldump times out or
        exits with a non-zero status.

        Raises:
            ValueError: if `database_name` is not a plain identifier, e.g. it
                starts with '-' and would be parsed by mysqldump as an option.
        """
        import re

        # WRONG: Vulnerable to command injection
        # os.system(f"mysqldump {database_name} > backup.sql")
        # RIGHT: pass an argument list (no shell) AND validate the value. A
        # list alone does not stop *argument* injection such as
        # database_name="--result-file=/etc/passwd".
        if not re.fullmatch(r"[A-Za-z0-9_$][A-Za-z0-9_$-]{0,63}", database_name):
            raise ValueError("Invalid database name")
        try:
            result = subprocess.run(
                [
                    'mysqldump',
                    # SECURITY: a password in argv is visible to other local
                    # users (e.g. via `ps`). Prefer an option file passed with
                    # --defaults-extra-file=/path/to/backup.cnf.
                    '--user=backup_user',
                    '--password=secure_password',
                    database_name,
                ],
                capture_output=True,
                check=True,
                timeout=300,  # 5 minutes timeout
            )
            return result.stdout
        except subprocess.TimeoutExpired:
            print("Backup timeout")
        except subprocess.CalledProcessError as e:
            print(f"Backup failed: {e.stderr}")
        return None

    @staticmethod
    def process_file(filename: str):
        """Return the text of `uploads/<filename>`.

        Raises:
            ValueError: not a .txt name, or the path resolves outside uploads/.
            FileNotFoundError: no regular file with that name exists.
        """
        # Validate filename
        if not filename.endswith('.txt'):
            raise ValueError("Only .txt files allowed")

        base_dir = Path('uploads').resolve()
        # Resolve once ('..', symlinks) and use the resolved path for both the
        # containment check and the open, so they refer to the same file.
        filepath = (base_dir / filename).resolve()
        if not filepath.is_relative_to(base_dir):
            raise ValueError("Path traversal attempt detected")

        # is_file() also rejects directories that happen to end in ".txt".
        if not filepath.is_file():
            raise FileNotFoundError(f"File not found: {filename}")

        with open(filepath, 'r', encoding='utf-8') as f:
            return f.read()
# Safe grep: arguments go to grep as an argv list (no shell), so shell
# metacharacters in `pattern` or `filename` are never interpreted.
def safe_grep(pattern: str, filename: str):
    """Return the lines of `filename` that match `pattern` (grep regex syntax).

    Args:
        pattern: regular expression; truncated to 100 characters.
        filename: path passed to grep as-is.
            NOTE(review): not restricted to any directory. If it comes from
            users, validate it the way SafeProcessRunner.process_file does.

    Returns:
        Matching lines as text ('' when nothing matches). Undecodable bytes
        are replaced rather than raising.

    Raises:
        ValueError: grep ran longer than 10 s, or reported an error (exit
            status >= 2, e.g. unreadable file or invalid pattern).
    """
    pattern = pattern[:100]  # Limit pattern length
    try:
        result = subprocess.run(
            # '--' ends option parsing, so a pattern like '-f/etc/passwd' is
            # treated as a pattern rather than an option.
            ['grep', '--', pattern, filename],
            capture_output=True,
            check=False,  # exit status 1 only means "no match"
            timeout=10,
        )
    except subprocess.TimeoutExpired as e:
        raise ValueError("Pattern matching timeout") from e
    if result.returncode > 1:
        detail = result.stderr.decode(errors='replace').strip()
        raise ValueError(f"grep failed: {detail}")
    return result.stdout.decode(errors='replace')
API Gateway Patterns
from typing import Callable, Any
class APIGateway:
    """Centralized API gateway: rate limiting, request validation and
    security response headers applied around a handler coroutine.

    Attributes (configure after construction):
        rate_limiters: maps a tier name to an object exposing
            ``is_allowed(user_id) -> (bool, dict)`` (e.g. TokenBucketRateLimiter).
            The ``'default'`` entry is used; with no entry, requests are not
            rate-limited.
        request_validators: callables ``check(request) -> bool``; any False
            result rejects the request with status 400.
    """

    # Added to every response the gateway returns, including error responses.
    SECURITY_HEADERS = {
        'X-Content-Type-Options': 'nosniff',
        'X-Frame-Options': 'DENY',
        # '0' turns off the legacy browser XSS auditor. '1; mode=block' is
        # deprecated and the auditor itself could be abused; rely on
        # Content-Security-Policy instead.
        'X-XSS-Protection': '0',
        'Strict-Transport-Security': 'max-age=31536000; includeSubDomains',
        'Content-Security-Policy': "default-src 'self'",
        'Referrer-Policy': 'strict-origin-when-cross-origin',
    }

    def __init__(self):
        self.rate_limiters = {}
        self.request_validators = []

    def check_rate_limit(self, user_id: str, tier: str = 'default') -> tuple[bool, dict]:
        """Ask the limiter registered for `tier`; allow if none is registered."""
        limiter = self.rate_limiters.get(tier)
        if limiter is None:
            return True, {}
        return limiter.is_allowed(user_id)

    async def handle_request(self,
                             request,
                             handler: Callable,
                             user_id: Optional[str] = None) -> Any:
        """Run security checks, await `handler(request)`, and add security headers.

        Returns a response dict ('status', 'body', 'headers'). Rejections are
        {'status': 429, ...} with Retry-After, or {'status': 400, ...}.
        """
        # 1. Rate limiting (only for identified callers)
        if user_id:
            allowed, info = self.check_rate_limit(user_id)
            if not allowed:
                return self._with_security_headers({
                    'status': 429,
                    'body': 'Rate limit exceeded',
                    'headers': {'Retry-After': str(info['retry_after'])},
                })

        # 2. Request validation
        for is_valid in self.request_validators:
            if not is_valid(request):
                return self._with_security_headers({'status': 400, 'body': 'Invalid request'})

        # 3. Handler + security headers
        response = await handler(request)
        return self._with_security_headers(response)

    def _with_security_headers(self, response: dict) -> dict:
        """Merge SECURITY_HEADERS into response['headers'] (gateway values win)."""
        response['headers'] = {**response.get('headers', {}), **self.SECURITY_HEADERS}
        return response
Glossary
- OAuth2: Authorization framework for delegated access (RFC 6749)
- CSRF: Cross-Site Request Forgery
- CORS: Cross-Origin Resource Sharing
- JTI: JWT ID (unique token identifier)
- Token Rotation: Issuing a new refresh token on every use and invalidating the old one
- Rate Limiting: Controlling request frequency
- Parameterized Query: SQL query with separated parameters
Comments