Skip to main content
⚡ Calmops

API Security Beyond JWT: OAuth2, Rate Limiting, and Prevention

API Security Beyond JWT: OAuth2, Rate Limiting, and Prevention

JWT is only one piece of API security. This guide covers the rest: OAuth2 login flows, refresh-token rotation, rate limiting, CORS, CSRF protection, and injection-attack prevention.


OAuth2 Implementation

Authorization Code Flow

import json
import os

import httpx
from authlib.integrations.starlette_client import OAuth
from fastapi import Cookie, Depends, FastAPI, HTTPException, Request
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from starlette.responses import RedirectResponse

app = FastAPI()
oauth = OAuth()

# Configure OAuth2 providers.
# Client credentials are read from the environment so secrets are not kept in
# source control; the placeholder strings are only a fallback.
oauth.register(
    name='google',
    client_id=os.environ.get('GOOGLE_CLIENT_ID', 'YOUR_CLIENT_ID'),
    client_secret=os.environ.get('GOOGLE_CLIENT_SECRET', 'YOUR_CLIENT_SECRET'),
    # OpenID Connect discovery document: supplies the provider's endpoints.
    server_metadata_url='https://accounts.google.com/.well-known/openid-configuration',
    client_kwargs={'scope': 'openid email profile'},
)

oauth.register(
    name='github',
    client_id=os.environ.get('GITHUB_CLIENT_ID', 'YOUR_GITHUB_CLIENT_ID'),
    client_secret=os.environ.get('GITHUB_CLIENT_SECRET', 'YOUR_GITHUB_CLIENT_SECRET'),
    api_base_url='https://api.github.com/',
    access_token_url='https://github.com/login/oauth/access_token',
    access_token_params=None,
    authorize_url='https://github.com/login/oauth/authorize',
    authorize_params=None,
    client_kwargs={'scope': 'user:email'},
)

@app.get('/login/{provider}')
async def login(request: Request, provider: str):
    """Initiate the OAuth2 authorization-code flow for ``provider``.

    Redirects the browser to the provider's authorization page, with
    ``auth_callback`` as the redirect target.

    NOTE(review): Authlib keeps the OAuth ``state`` in ``request.session``,
    which presumably requires Starlette's SessionMiddleware — confirm it is
    installed; it is not shown in this snippet.

    Raises:
        HTTPException(404): ``provider`` is not a registered OAuth client.
    """
    client = oauth.create_client(provider)
    if client is None:
        # create_client() returns None for unregistered names.
        raise HTTPException(status_code=404, detail="Unknown OAuth provider")

    redirect_uri = request.url_for('auth_callback', provider=provider)
    # authorize_redirect needs the request itself (to persist state), then the URI.
    return await client.authorize_redirect(request, str(redirect_uri))

@app.get('/auth/{provider}/callback')
async def auth_callback(request: Request, provider: str):
    """Handle the OAuth2 callback: exchange the code, load the profile, start a session.

    The session JWT is delivered in an HttpOnly cookie rather than in the
    redirect URL. Tokens in query strings leak into browser history, server
    logs and Referer headers.

    Raises:
        HTTPException(404): ``provider`` is not a registered OAuth client.
        HTTPException(400): the code exchange or profile lookup failed.
    """
    client = oauth.create_client(provider)
    if client is None:
        raise HTTPException(status_code=404, detail="Unknown OAuth provider")

    try:
        # Exchange authorization code for token
        token = await client.authorize_access_token(request)

        # OpenID Connect providers (e.g. Google) return the ID-token claims
        # under 'userinfo'; plain OAuth2 providers (e.g. GitHub) need an API
        # call to fetch the profile.
        user_data = token.get('userinfo')
        if not user_data:
            resp = await client.get('user', token=token)
            resp.raise_for_status()
            user_data = resp.json()

        session_token = create_session_token(dict(user_data))
    except Exception as exc:
        # Boundary handler: never expose provider errors to the client.
        raise HTTPException(status_code=400, detail="Authentication failed") from exc

    response = RedirectResponse(url="https://example.com/dashboard")
    # Named 'session_token' so it does not collide with SessionMiddleware's
    # default 'session' cookie.
    response.set_cookie(
        key="session_token",
        value=session_token,
        max_age=24 * 3600,  # matches the 24h 'exp' of the session JWT
        httponly=True,
        secure=True,
        samesite="lax",
    )
    return response

def create_session_token(user_data: dict) -> str:
    """Create a signed 24-hour session JWT after OAuth2 verification.

    Args:
        user_data: Provider profile. The user id is taken from ``'id'``
            (GitHub-style profiles) or, failing that, ``'sub'`` (OpenID
            Connect userinfo). ``'email'`` and ``'name'`` may be absent.

    Returns:
        HS256-signed JWT string.

    Raises:
        KeyError: the profile contains neither ``'id'`` nor ``'sub'``.

    NOTE(review): ``SECRET_KEY`` is not defined in this snippet; it is
    assumed to be the module-level signing key.
    """
    import secrets
    from datetime import datetime, timedelta, timezone

    import jwt

    subject = user_data.get('id', user_data.get('sub'))
    if subject is None:
        raise KeyError("user profile has no 'id' or 'sub'")

    now = datetime.now(timezone.utc)
    payload = {
        # JWT 'sub' is a string claim; provider ids may be integers.
        'sub': str(subject),
        'email': user_data.get('email'),
        'name': user_data.get('name'),
        'iat': now,
        'exp': now + timedelta(hours=24),
        'jti': secrets.token_urlsafe(32),  # JWT ID for revocation
    }

    return jwt.encode(payload, SECRET_KEY, algorithm="HS256")

Refresh Token Rotation

from datetime import datetime, timedelta
import jwt
import secrets
from typing import Optional

class RefreshTokenManager:
    """Issue access/refresh JWT pairs and rotate refresh tokens on each use.

    Refresh tokens are persisted only as hashes (via ``hash_token``) so they
    can be looked up and revoked. ``db`` must provide ``save_refresh_token``,
    ``get_refresh_token``, ``revoke_refresh_token``, ``revoke_all_tokens`` and
    ``check_token_jti_used``; their storage semantics live outside this snippet.

    NOTE(review): ``SECRET_KEY``, ``HTTPException`` and ``hash_token`` are
    taken from module scope.
    """

    # Token lifetimes. 'expires_in' in the token response is derived from
    # ACCESS_TOKEN_TTL so the advertised and real lifetimes cannot drift.
    ACCESS_TOKEN_TTL = timedelta(minutes=15)
    REFRESH_TOKEN_TTL = timedelta(days=7)

    def __init__(self, db):
        self.db = db

    def issue_token_pair(self, user_id: str, family_id: Optional[str] = None) -> dict:
        """Issue an access token plus a refresh token for ``user_id``.

        Args:
            user_id: Subject embedded in both tokens.
            family_id: Rotation family for the refresh token. ``None`` starts
                a new family (fresh login). Rotations pass the family of the
                token being replaced, so the chain stays linked.

        Returns:
            Token response dict (access_token, refresh_token, token_type, expires_in).
        """
        access_token = self._create_access_token(user_id)
        refresh_token = self._create_refresh_token(user_id)

        # Store only the hash of the refresh token (for lookup/revocation).
        self.db.save_refresh_token(
            user_id=user_id,
            token_hash=hash_token(refresh_token),
            expires_at=datetime.utcnow() + self.REFRESH_TOKEN_TTL,
            family_id=family_id or secrets.token_urlsafe(16),
        )

        return {
            'access_token': access_token,
            'refresh_token': refresh_token,
            'token_type': 'Bearer',
            'expires_in': int(self.ACCESS_TOKEN_TTL.total_seconds()),
        }

    def refresh_token_pair(self, refresh_token: str) -> dict:
        """Exchange a valid refresh token for a new pair, revoking the old one.

        Raises:
            HTTPException(401): the token is malformed, has a bad signature,
                is expired, is unknown, or has already been used. Reuse
                revokes all of the user's tokens.
        """
        payload = self._verify_refresh_token(refresh_token)
        user_id = payload['sub']
        token_hash = hash_token(refresh_token)

        # Reuse must be checked before the "is it stored" lookup. An already
        # rotated token has been revoked, so if the lookup ran first it would
        # be rejected as merely invalid and the breach response never fire.
        # NOTE(review): nothing here records a jti as used; this relies on the
        # DB (e.g. revoke_refresh_token) doing so — confirm.
        if self._detect_token_reuse(user_id, payload['jti']):
            self.db.revoke_all_tokens(user_id)
            raise HTTPException(status_code=401, detail="Token reuse detected")

        stored_token = self.db.get_refresh_token(user_id, token_hash)
        if not stored_token:
            raise HTTPException(status_code=401, detail="Invalid refresh token")

        if stored_token['expires_at'] < datetime.utcnow():
            raise HTTPException(status_code=401, detail="Refresh token expired")

        # Revoke first so the old token is never valid alongside its replacement.
        self.db.revoke_refresh_token(user_id, token_hash)

        # Keep the new token in the same rotation family (assumes the stored
        # row exposes the family_id saved in issue_token_pair).
        return self.issue_token_pair(user_id, family_id=stored_token.get('family_id'))

    def _verify_refresh_token(self, refresh_token: str) -> dict:
        """Decode a refresh JWT, checking signature, expiry and token type.

        Raises:
            HTTPException(401): invalid, expired, or not a refresh token.
        """
        try:
            payload = jwt.decode(
                refresh_token,
                SECRET_KEY,
                algorithms=["HS256"],
                options={"require": ["exp", "sub", "jti"]},
            )
        except jwt.InvalidTokenError as exc:
            raise HTTPException(status_code=401, detail="Invalid refresh token") from exc

        # An access token must not be accepted in place of a refresh token.
        if payload.get('type') != 'refresh':
            raise HTTPException(status_code=401, detail="Invalid refresh token")
        return payload

    def _create_access_token(self, user_id: str) -> str:
        """Create a short-lived access token."""
        now = datetime.utcnow()
        payload = {
            'sub': user_id,
            'type': 'access',
            'iat': now,
            'exp': now + self.ACCESS_TOKEN_TTL,
            'jti': secrets.token_urlsafe(16),
        }
        return jwt.encode(payload, SECRET_KEY, algorithm="HS256")

    def _create_refresh_token(self, user_id: str) -> str:
        """Create a longer-lived refresh token."""
        now = datetime.utcnow()
        payload = {
            'sub': user_id,
            'type': 'refresh',
            'iat': now,
            'exp': now + self.REFRESH_TOKEN_TTL,
            'jti': secrets.token_urlsafe(16),
        }
        return jwt.encode(payload, SECRET_KEY, algorithm="HS256")

    def _detect_token_reuse(self, user_id: str, jti: str) -> bool:
        """Return True if the DB reports this refresh-token jti as already used."""
        return self.db.check_token_jti_used(user_id, jti)

def hash_token(token: str) -> str:
    """Return the hex-encoded SHA-256 digest of ``token``.

    Used so that only a digest, never the raw token, is written to storage.
    """
    import hashlib

    hasher = hashlib.sha256()
    hasher.update(token.encode())
    return hasher.hexdigest()

Rate Limiting and Throttling

Token Bucket Algorithm

import time
from collections import defaultdict
from typing import Optional

class TokenBucketRateLimiter:
    """In-process token-bucket rate limiter with one bucket per user id.

    Each bucket holds at most ``capacity`` tokens and refills continuously at
    ``refill_rate`` tokens per second. State is kept in this object only.
    """

    def __init__(self, capacity: int = 100, refill_rate: float = 10):
        """
        capacity: max tokens in bucket
        refill_rate: tokens added per second
        """
        self.capacity = capacity
        self.refill_rate = refill_rate
        # 'last_refill' is a time.monotonic() reading. Unlike time.time(),
        # it never jumps backwards when the wall clock is adjusted, which
        # would otherwise produce a negative refill and drain the bucket.
        self.buckets: dict = defaultdict(lambda: {
            'tokens': capacity,
            'last_refill': time.monotonic(),
        })

    def is_allowed(self, user_id: str, tokens_required: int = 1) -> tuple[bool, dict]:
        """Consume ``tokens_required`` tokens from ``user_id``'s bucket if available.

        Returns:
            ``(True, {'remaining', 'reset_at'})`` when allowed. ``reset_at`` is
            the Unix time at which the bucket would be full again.
            ``(False, {'retry_after', 'remaining'})`` when denied, where
            ``retry_after`` is whole seconds until enough tokens accumulate.
        """
        bucket = self.buckets[user_id]
        now = time.monotonic()

        # Refill bucket based on time elapsed
        elapsed = now - bucket['last_refill']
        bucket['tokens'] = min(self.capacity, bucket['tokens'] + elapsed * self.refill_rate)
        bucket['last_refill'] = now

        if bucket['tokens'] >= tokens_required:
            bucket['tokens'] -= tokens_required
            # reset_at is reported to clients as a Unix timestamp, so it is
            # based on the wall clock rather than the monotonic one.
            return True, {
                'remaining': int(bucket['tokens']),
                'reset_at': time.time() + (self.capacity - bucket['tokens']) / self.refill_rate,
            }

        # Calculate when request would be allowed
        time_to_retry = (tokens_required - bucket['tokens']) / self.refill_rate
        return False, {
            'retry_after': int(time_to_retry) + 1,
            'remaining': 0,
        }

class AdaptiveRateLimiter:
    """Route each caller to a token bucket sized for their subscription tier."""

    def __init__(self):
        # tier -> (capacity, refill_rate per second)
        tier_settings = {
            'free': (100, 10),
            'pro': (1000, 100),
            'enterprise': (10000, 1000),
        }
        self.limiters = {
            tier: TokenBucketRateLimiter(capacity=cap, refill_rate=rate)
            for tier, (cap, rate) in tier_settings.items()
        }

    def check_rate_limit(self, user_id: str, user_tier: str) -> tuple[bool, dict]:
        """Apply the limiter for ``user_tier``; unknown tiers use the 'free' limits."""
        if user_tier in self.limiters:
            bucket_limiter = self.limiters[user_tier]
        else:
            bucket_limiter = self.limiters['free']
        return bucket_limiter.is_allowed(user_id)

# FastAPI integration
from fastapi import Header, HTTPException

# One limiter shared by all requests. Building a fresh AdaptiveRateLimiter
# inside the dependency would hand every request a full bucket, so nothing
# would ever be throttled.
_tier_rate_limiter = AdaptiveRateLimiter()


async def rate_limit_check(
    x_user_id: str = Header(...),
    x_user_tier: str = Header(default="free")
):
    """Rate limit dependency keyed on the caller's user id and tier.

    SECURITY NOTE: both values come from request headers the client controls,
    so a caller can choose any id or claim a higher tier. In production,
    derive them from an authenticated identity (e.g. a verified access token).

    Returns:
        The limiter's info dict for an allowed request.

    Raises:
        HTTPException(429): limit exceeded; ``Retry-After`` is set.
    """
    allowed, info = _tier_rate_limiter.check_rate_limit(x_user_id, x_user_tier)

    if not allowed:
        raise HTTPException(
            status_code=429,
            detail="Rate limit exceeded",
            headers={"Retry-After": str(info['retry_after'])}
        )

    return info

@app.get("/api/resource")
async def get_resource(rate_info: dict = Depends(rate_limit_check)):
    """Protected endpoint with rate limiting.

    ``rate_limit_check`` is declared once, as a parameter dependency. It
    raises 429 before this body runs when the caller is over the limit, and
    otherwise supplies the limiter info used for ``remaining_requests``.
    """
    return {
        'data': 'sensitive data',
        'remaining_requests': rate_info['remaining']
    }

Distributed Rate Limiting with Redis

import redis
import time
import json

class DistributedRateLimiter:
    """Sliding-window rate limiter shared across servers through Redis.

    Each user gets a sorted set ``rate_limit:<user_id>``. Every member is one
    request, scored by its Unix timestamp.
    """

    def __init__(self, redis_host='localhost', redis_port=6379):
        self.redis = redis.Redis(host=redis_host, port=redis_port, decode_responses=True)

    def check_rate_limit(self, user_id: str, limit: int, window: int) -> tuple[bool, dict]:
        """
        Check rate limit using sliding window with Redis
        limit: max requests
        window: time window in seconds

        Returns (allowed, info). When allowed, info has 'remaining' and
        'reset_at'. When denied, it has 'retry_after' (seconds) and
        'reset_at'. 'reset_at' is the Unix time at which the oldest request
        in the window expires.
        """
        import secrets

        key = f"rate_limit:{user_id}"
        now = time.time()
        window_start = now - window
        # Unique member per request. Using str(now) alone lets two requests
        # with the same timestamp collapse into a single entry (undercount).
        member = f"{now}:{secrets.token_hex(8)}"

        # Trim, record, count and read the oldest entry in one MULTI/EXEC
        # transaction. Separate commands let concurrent requests on different
        # servers all read the same count before any of them writes.
        pipe = self.redis.pipeline(transaction=True)
        pipe.zremrangebyscore(key, 0, window_start)
        pipe.zadd(key, {member: now})
        pipe.zcard(key)
        pipe.zrange(key, 0, 0, withscores=True)
        pipe.expire(key, window + 1)
        _, _, current_requests, oldest, _ = pipe.execute()

        oldest_ts = float(oldest[0][1]) if oldest else now
        reset_at = oldest_ts + window

        # current_requests includes this request.
        if current_requests <= limit:
            return True, {
                'remaining': limit - current_requests,
                'reset_at': reset_at,
            }

        # Over the limit: withdraw this request so rejected attempts do not
        # count against the caller. A concurrent check between execute() and
        # this zrem may briefly see one extra entry; that errs toward denying.
        self.redis.zrem(key, member)
        return False, {
            'retry_after': int(reset_at - now) + 1,
            'reset_at': reset_at,
        }

# Distributed limiter: a single module-level instance, and therefore a single
# Redis client, reused by every request handled in this process.
limiter = DistributedRateLimiter(redis_host='localhost', redis_port=6379)

@app.get("/api/endpoint")
async def protected_endpoint(user_id: str = Header(...)):
    """Endpoint limited to 100 requests per hour per user via the shared Redis limiter.

    NOTE(review): ``user_id`` comes from a request header the client
    controls; derive it from an authenticated identity in production.

    Raises:
        HTTPException(429): limit exceeded; ``Retry-After`` is set.
    """
    from starlette.concurrency import run_in_threadpool

    # The redis-py client used by the limiter is synchronous. Running it in
    # the threadpool keeps the Redis round trips from blocking the event loop.
    allowed, info = await run_in_threadpool(
        limiter.check_rate_limit,
        user_id=user_id,
        limit=100,  # 100 requests
        window=3600,  # per hour
    )

    if not allowed:
        raise HTTPException(
            status_code=429,
            detail="Rate limit exceeded",
            headers={"Retry-After": str(info['retry_after'])}
        )

    return {"status": "ok", "remaining": info['remaining']}

CORS and CSRF Protection

Secure CORS Configuration

from fastapi.middleware.cors import CORSMiddleware

app.add_middleware(
    CORSMiddleware,
    # Only allow specific origins (not *): a wildcard combined with
    # allow_credentials=True is unsafe.
    allow_origins=[
        "https://example.com",
        "https://www.example.com",
        "https://app.example.com",
        # The admin UI origin must be listed as well. Otherwise the
        # middleware never returns Access-Control-Allow-Origin for it, and
        # browsers block the admin endpoint's responses and preflights.
        "https://admin.example.com",
    ],
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
    allow_headers=[
        "Content-Type",
        "Authorization",
        "X-CSRF-Token",
        "X-Requested-With",
    ],
    expose_headers=["X-Total-Count"],
    max_age=3600,  # cache preflight results for 1 hour
)

# Stricter configuration for sensitive endpoints
@app.get("/api/admin/settings")
async def get_admin_settings(origin: str = Header(...)):
    """Admin endpoint that only answers requests whose Origin is the admin UI.

    Requests without an Origin header are rejected by FastAPI's validation
    (the header is required); other origins get 403.

    NOTE(review): Origin only constrains browsers. Non-browser clients can
    send any value, so this is not authentication or authorization; pair it
    with a real admin auth check.
    """
    admin_origins = ("https://admin.example.com",)
    if origin in admin_origins:
        return {"settings": "admin only"}
    raise HTTPException(status_code=403, detail="CORS policy violation")

CSRF Token Implementation

import secrets
import hashlib
from datetime import datetime, timedelta

class CSRFProtection:
    """Session-bound CSRF tokens, persisted only as SHA-256 hashes in ``db``.

    ``db`` must provide ``store_csrf_token(session_id, token_hash, created_at,
    expires_at)`` and ``get_csrf_token(session_id)``. The latter returns a
    mapping with ``token_hash`` and ``expires_at``, or a falsy value.
    """

    # How long an issued token stays valid.
    TOKEN_TTL = timedelta(hours=2)

    def __init__(self, db):
        self.db = db

    @staticmethod
    def _hash(token: str) -> str:
        """Hex SHA-256 of ``token``; only this digest is ever stored."""
        return hashlib.sha256(token.encode()).hexdigest()

    def generate_csrf_token(self, session_id: str) -> str:
        """Create a random token for ``session_id``, store its hash, return the plaintext."""
        token = secrets.token_urlsafe(32)
        now = datetime.utcnow()

        self.db.store_csrf_token(
            session_id=session_id,
            token_hash=self._hash(token),
            created_at=now,
            expires_at=now + self.TOKEN_TTL,
        )

        return token

    def validate_csrf_token(self, session_id: Optional[str], token: Optional[str]) -> bool:
        """Return True only if ``token`` matches the unexpired token stored for ``session_id``.

        A missing header or cookie reaches here as ``None`` (``Header(None)`` /
        ``Cookie(None)``). Such requests are rejected rather than raising
        while hashing.
        """
        if not session_id or not token:
            return False

        stored_token = self.db.get_csrf_token(session_id)
        if not stored_token:
            return False

        if stored_token['expires_at'] < datetime.utcnow():
            return False

        # Use constant-time comparison to prevent timing attacks
        return secrets.compare_digest(stored_token['token_hash'], self._hash(token))

# FastAPI integration
# Module-level CSRF helper shared by the form endpoints below.
# NOTE(review): `db` is not defined in this snippet; it is assumed to be the
# application's storage object providing store_csrf_token()/get_csrf_token().
csrf = CSRFProtection(db)

@app.post("/api/form/submit")
async def submit_form(
    csrf_token: Optional[str] = Header(None, alias="X-CSRF-Token"),
    session_id: Optional[str] = Cookie(None)
):
    """Form submission with CSRF protection.

    The token is read from the ``X-CSRF-Token`` header, which is the header
    name the CORS configuration allows cross-origin. It must match the token
    issued for the ``session_id`` cookie.

    Raises:
        HTTPException(403): token or session cookie missing, or token invalid.
    """
    # Reject missing values up front so validation never receives None.
    if not csrf_token or not session_id or not csrf.validate_csrf_token(session_id, csrf_token):
        raise HTTPException(status_code=403, detail="CSRF validation failed")

    # Process form...
    return {"status": "success"}

@app.get("/api/form")
async def get_form(session_id: Optional[str] = Cookie(None)):
    """Return form metadata plus a CSRF token bound to the caller's session.

    Raises:
        HTTPException(401): no ``session_id`` cookie; a CSRF token must be
            bound to a session.
    """
    if not session_id:
        raise HTTPException(status_code=401, detail="Session required")

    csrf_token = csrf.generate_csrf_token(session_id)

    return {
        # Placeholder list; populate with the form's field definitions.
        # (A literal `...` here is not JSON-serialisable.)
        "form_fields": [],
        "csrf_token": csrf_token
    }

Injection Attack Prevention

SQL Injection Prevention

from sqlalchemy import text
from sqlalchemy.orm import Session

class SafeQueryBuilder:
    """Build SQL queries safely using parameterized queries.

    NOTE(review): ``User`` (the ORM model queried below) is not defined in
    this snippet.
    """

    # Escape character for LIKE patterns. "/" sidesteps backslash-quoting
    # differences between database dialects.
    _LIKE_ESCAPE = "/"

    @staticmethod
    def _contains_pattern(term: str) -> str:
        """Return a ``%term%`` LIKE pattern in which ``term`` matches literally.

        Bound parameters already prevent SQL injection, but LIKE/ILIKE still
        treat ``%`` and ``_`` inside the bound value as wildcards, so a search
        for "%" would match every row. The escape character is doubled first
        so user input cannot use it to cancel the added escapes.
        """
        esc = SafeQueryBuilder._LIKE_ESCAPE
        escaped = (
            term.replace(esc, esc + esc)
            .replace("%", esc + "%")
            .replace("_", esc + "_")
        )
        return f"%{escaped}%"

    @staticmethod
    def search_users(db: Session, search_term: str):
        """Case-insensitive substring search on email, using a bound, escaped pattern."""

        # WRONG: Vulnerable to SQL injection
        # query = f"SELECT * FROM users WHERE email LIKE '%{search_term}%'"

        # RIGHT: Use parameterized queries (and escape LIKE wildcards)
        query = db.query(User).filter(
            User.email.ilike(
                SafeQueryBuilder._contains_pattern(search_term),
                escape=SafeQueryBuilder._LIKE_ESCAPE,
            )
        )

        return query.all()

    @staticmethod
    def complex_search(db: Session, filters: dict):
        """Apply the optional 'email', 'name' and 'created_after' filters, all bound as parameters."""

        query = db.query(User)

        if 'email' in filters:
            query = query.filter(User.email == filters['email'])

        if 'name' in filters:
            query = query.filter(
                User.name.ilike(
                    SafeQueryBuilder._contains_pattern(filters['name']),
                    escape=SafeQueryBuilder._LIKE_ESCAPE,
                )
            )

        if 'created_after' in filters:
            query = query.filter(User.created_at > filters['created_after'])

        return query.all()

    @staticmethod
    def raw_sql_with_params(db: Session, user_id: int):
        """Raw SQL with proper parameterization"""

        # Using text() and params prevents SQL injection
        result = db.execute(
            text("SELECT * FROM users WHERE id = :user_id"),
            {"user_id": user_id}
        )

        return result.fetchall()

# Input validation
from pydantic import BaseModel, validator

class SearchQuery(BaseModel):
    """Validated search request: a single free-text ``term``."""

    term: str

    @validator('term')
    def validate_search_term(cls, v):
        """Reject terms over 100 characters or containing NUL; otherwise return ``v`` unchanged."""
        problem = None
        if len(v) > 100:
            problem = 'Search term too long'
        elif '\x00' in v:
            problem = 'Invalid characters'

        if problem is not None:
            raise ValueError(problem)
        return v

Command Injection Prevention

import subprocess
import shlex
from pathlib import Path

class SafeProcessRunner:
    """Run system commands safely"""

    @staticmethod
    def run_backup(database_name: str):
        """Dump ``database_name`` with mysqldump and return its stdout bytes.

        Returns None after printing a message if the dump times out or
        mysqldump exits non-zero.

        Raises:
            ValueError: ``database_name`` is not a plain identifier.
        """
        import re

        # WRONG: Vulnerable to command injection
        # os.system(f"mysqldump {database_name} > backup.sql")

        # A list of args avoids the shell, but a value starting with "-"
        # would still be parsed by mysqldump as an option (argument
        # injection, e.g. "--result-file=..."). Accept identifiers only.
        if not re.fullmatch(r"[A-Za-z0-9_$]{1,64}", database_name):
            raise ValueError("Invalid database name")

        # RIGHT: Use subprocess with list of args (no shell expansion)
        # NOTE(review): a password in argv is visible to other local users
        # through the process list; prefer a MySQL option file
        # (--defaults-extra-file) or another out-of-argv credential source.
        try:
            result = subprocess.run(
                [
                    'mysqldump',
                    '--user=backup_user',
                    '--password=secure_password',
                    database_name
                ],
                capture_output=True,
                check=True,
                timeout=300  # 5 minutes timeout
            )
            return result.stdout
        except subprocess.TimeoutExpired:
            print("Backup timeout")
        except subprocess.CalledProcessError as e:
            # stderr is bytes (no text=True); decode for a readable message.
            print(f"Backup failed: {e.stderr.decode(errors='replace')}")
        return None

    @staticmethod
    def process_file(filename: str):
        """Return the text of ``uploads/<filename>``.

        Raises:
            ValueError: not a .txt name, or the path escapes ``uploads``.
            FileNotFoundError: no such regular file.
        """
        if not filename.endswith('.txt'):
            raise ValueError("Only .txt files allowed")

        # Resolve symlinks and "..", then require the result to stay inside
        # uploads/ (this also rejects absolute filenames).
        uploads_dir = Path('uploads').resolve()
        filepath = (uploads_dir / filename).resolve()
        try:
            filepath.relative_to(uploads_dir)
        except ValueError:
            raise ValueError("Path traversal attempt detected") from None

        if not filepath.is_file():
            raise FileNotFoundError(f"File not found: {filename}")

        # Open the resolved, validated path rather than the raw input.
        with open(filepath, 'r') as f:
            return f.read()

# Arguments are passed as a list (no shell), so no quoting is needed here;
# shlex only matters if a command string ever has to be split or quoted.
def safe_grep(pattern: str, filename: str):
    """Return the lines of ``filename`` that match ``pattern`` (grep regex syntax).

    Args:
        pattern: Search pattern; truncated to its first 100 characters.
        filename: File passed to grep. NOTE(review): not restricted to any
            directory — validate it (as SafeProcessRunner.process_file does)
            if it can come from users.

    Returns:
        Matching lines as text ("" when nothing matches).

    Raises:
        ValueError: grep timed out, or failed (exit status > 1, e.g. missing file).
    """
    pattern = pattern[:100]  # Limit pattern length

    try:
        result = subprocess.run(
            # "--" ends option parsing, so a pattern starting with "-" is
            # treated as a pattern, not a grep flag.
            ['grep', '--', pattern, filename],
            capture_output=True,
            check=False,  # exit status 1 just means "no match"
            timeout=10
        )
    except subprocess.TimeoutExpired:
        raise ValueError("Pattern matching timeout") from None

    # Exit status > 1 is a grep error, not "no match"; do not hide it.
    if result.returncode > 1:
        message = result.stderr.decode(errors='replace').strip()
        raise ValueError(f"grep failed: {message}")

    return result.stdout.decode(errors='replace')

API Gateway Patterns

from typing import Callable, Any

class APIGateway:
    """Centralized API gateway with security features.

    Attributes:
        rate_limiters: name -> limiter. Each limiter exposes
            ``is_allowed(user_id) -> (bool, info)``, as TokenBucketRateLimiter
            does. A request must pass every configured limiter.
        request_validators: callables ``(request) -> bool``; any False yields 400.
    """

    # Added to every response, including the gateway's own 429/400 replies.
    SECURITY_HEADERS = {
        'X-Content-Type-Options': 'nosniff',
        'X-Frame-Options': 'DENY',
        # Current guidance is to disable the legacy browser XSS auditor ('0');
        # the filter could itself be abused, and CSP is the real defence.
        'X-XSS-Protection': '0',
        'Strict-Transport-Security': 'max-age=31536000; includeSubDomains',
        'Content-Security-Policy': "default-src 'self'",
        'Referrer-Policy': 'strict-origin-when-cross-origin',
    }

    def __init__(self):
        self.rate_limiters = {}
        self.request_validators = []

    def check_rate_limit(self, user_id: str) -> tuple[bool, dict]:
        """Return ``(False, info)`` from the first limiter that denies ``user_id``.

        If every limiter allows the request, or none is configured, return ``(True, {})``.
        """
        for limiter in self.rate_limiters.values():
            allowed, info = limiter.is_allowed(user_id)
            if not allowed:
                return False, info
        return True, {}

    def _with_security_headers(self, response: dict) -> dict:
        """Merge SECURITY_HEADERS into ``response['headers']`` (security values win)."""
        response['headers'] = {**response.get('headers', {}), **self.SECURITY_HEADERS}
        return response

    async def handle_request(self,
                             request,
                             handler: Callable,
                             user_id: Optional[str] = None) -> Any:
        """Run rate limiting and validators, then ``handler``; add security headers to the result.

        ``handler`` is awaited with ``request`` and must return a response dict.
        """
        # 1. Rate limiting
        if user_id:
            allowed, info = self.check_rate_limit(user_id)
            if not allowed:
                return self._with_security_headers({
                    'status': 429,
                    'body': 'Rate limit exceeded',
                    'headers': {'Retry-After': str(info.get('retry_after', 1))},
                })

        # 2. Request validation
        for is_valid in self.request_validators:
            if not is_valid(request):
                return self._with_security_headers({'status': 400, 'body': 'Invalid request'})

        # 3. Handler, then security headers
        response = await handler(request)
        return self._with_security_headers(response)

Glossary

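  • OAuth2: Authorization framework (RFC 6749) that lets an application obtain scoped, delegated access to an API without handling the user's password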
  • CSRF: Cross-Site Request Forgery
  • CORS: Cross-Origin Resource Sharing
  • JTI: JWT ID (unique token identifier)
  • Token Rotation: Replacing a refresh token with a new one every time it is used, so that replay of an old token can be detected
  • Rate Limiting: Controlling request frequency
  • Parameterized Query: SQL query with separated parameters

Resources

Comments