Skip to main content
⚡ Calmops

API Security Beyond JWT: OAuth 2.0, Rate Limiting, and Protection Strategies

API Security Beyond JWT: OAuth 2.0, Rate Limiting, and Protection Strategies

TL;DR: This guide covers API security patterns for Rust web services (built with axum) that go beyond basic JWT authentication. It covers OAuth 2.0, rate-limiting algorithms, CORS configuration, security headers, SQL-injection prevention, and CSRF protection.


Introduction

While JWT authentication is foundational, production APIs require defense in depth. This article explores comprehensive security measures that protect against:

  • Unauthorized access and privilege escalation
  • API abuse and DoS attacks
  • Cross-site attacks (CSRF, XSS)
  • Injection attacks
  • Data exposure

OAuth 2.0 Implementation in Rust

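OAuth 2.0 is a delegated-authorization framework. It lets a user grant a client application limited, scoped access to their resources on another service without sharing their password.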

Installation

[dependencies]
# The OAuth client below uses the oauth2 4.x API (four-argument `BasicClient::new`,
# `oauth2::reqwest::async_http_client`); oauth2 5.x changed both. 4.x pairs with reqwest 0.11.
oauth2 = "4.4"
reqwest = { version = "0.11", features = ["json"] }
serde = { version = "1.0", features = ["derive"] }
tokio = { version = "1", features = ["full"] }
# Crates used by the later sections (middleware, JWT validation, SQL, CSRF tokens).
axum = "0.7"
jsonwebtoken = "9"
sqlx = { version = "0.7", features = ["runtime-tokio", "tls-rustls", "postgres"] }
uuid = { version = "1", features = ["v4"] }

OAuth 2.0 Client Implementation

use oauth2::{
    basic::BasicClient, AuthUrl, ClientId, ClientSecret, RedirectUrl, TokenResponse,
    TokenUrl, AuthorizationCode, Scope,
};
use oauth2::reqwest::http_client;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::RwLock;
use std::collections::HashMap;

/// OAuth 2.0 client wrapper: the configured `oauth2` client plus an
/// in-memory session map.
pub struct OAuthService {
    /// Provider client configured with id, secret, auth/token URLs and redirect URI.
    client: BasicClient,
    // Shared, async-locked session map. It is created empty in `new` and no
    // method shown here reads or writes it yet; presumably keyed by session
    // id — TODO confirm.
    state: Arc<RwLock<HashMap<String, UserSession>>>,
}

/// User data for an authenticated session, built by
/// `OAuthService::get_user_info` from the provider's user-info response.
#[derive(Clone, Serialize, Deserialize)]
pub struct UserSession {
    /// Provider subject identifier (`sub` in the user-info response).
    pub user_id: String,
    pub email: String,
    /// Scopes associated with this session.
    pub scopes: Vec<String>,
}

impl OAuthService {
    pub fn new(
        client_id: String,
        client_secret: String,
        auth_url: String,
        token_url: String,
        redirect_url: String,
    ) -> Self {
        let client = BasicClient::new(
            ClientId::new(client_id),
            Some(ClientSecret::new(client_secret)),
            AuthUrl::new(auth_url).unwrap(),
            Some(TokenUrl::new(token_url).unwrap()),
        )
        .set_redirect_uri(RedirectUrl::new(redirect_url).unwrap());

        Self {
            client,
            state: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    pub fn generate_authorization_url(&self, state: &str, scopes: Vec<&str>) -> String {
        let mut request = self.client.authorize_url(
            Scope::new("read".to_string()),
            Scope::new("write".to_string()),
        );

        request = request.add_extra_param("state", state);

        for scope in scopes {
            request = request.add_scope(Scope::new(scope.to_string()));
        }

        request.url().to_string()
    }

    pub async fn exchange_code_for_token(
        &self,
        code: String,
    ) -> Result<TokenData, OAuthError> {
        let token_result = self.client
            .exchange_code(AuthorizationCode::new(code))
            .http_client(http_client)
            .send()
            .await
            .map_err(|e| OAuthError::TokenExchange(e.to_string()))?;

        let access_token = token_result.access_token().secret().to_string();
        let refresh_token = token_result.refresh_token()
            .map(|t| t.secret().to_string());
        let expires_in = token_result.expires_in();

        Ok(TokenData {
            access_token,
            refresh_token,
            expires_in,
            token_type: "Bearer".to_string(),
        })
    }

    pub async fn get_user_info(&self, access_token: &str) -> Result<UserSession, OAuthError> {
        let client = reqwest::Client::new();
        let response = client
            .get("https://oauth-provider.com/userinfo")
            .bearer_auth(access_token)
            .send()
            .await
            .map_err(|e| OAuthError::UserInfo(e.to_string()))?;

        let user_info: UserInfoResponse = response.json().await
            .map_err(|e| OAuthError::UserInfo(e.to_string()))?;

        Ok(UserSession {
            user_id: user_info.sub,
            email: user_info.email,
            scopes: vec!["read".to_string()],
        })
    }
}

/// Errors produced by the OAuth client service.
///
/// Implements `Display` and `std::error::Error` so it works with `?`,
/// `Box<dyn Error>` and error-reporting crates.
#[derive(Debug)]
pub enum OAuthError {
    /// Exchanging the authorization code for tokens failed.
    TokenExchange(String),
    /// Fetching or decoding the user-info response failed.
    UserInfo(String),
}

impl std::fmt::Display for OAuthError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::TokenExchange(msg) => write!(f, "token exchange failed: {}", msg),
            Self::UserInfo(msg) => write!(f, "user info request failed: {}", msg),
        }
    }
}

impl std::error::Error for OAuthError {}

/// Token material returned by `OAuthService::exchange_code_for_token`.
///
/// NOTE(review): this derives `Serialize`, so serializing it (into a response,
/// a log line, a cookie) exposes the raw access and refresh tokens.
#[derive(Serialize, Deserialize)]
pub struct TokenData {
    pub access_token: String,
    /// `None` when the provider issued no refresh token.
    pub refresh_token: Option<String>,
    /// Access-token lifetime if the provider reported one; presumably whole
    /// seconds (the OAuth `expires_in` convention) — confirm.
    pub expires_in: Option<u64>,
    /// Set to `"Bearer"` by `exchange_code_for_token`.
    pub token_type: String,
}

/// Subset of the provider's user-info JSON that `get_user_info` reads.
#[derive(Deserialize)]
struct UserInfoResponse {
    /// Provider's subject identifier for the user.
    sub: String,
    email: String,
}

OAuth 2.0 Resource Server

use axum::{
    extract::HeaderValue,
    http::{Request, StatusCode},
    middleware::Next,
    response::Response,
};
use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};

/// Verifies JWT access tokens presented to this API (the resource-server side).
pub struct JwtValidator {
    /// Algorithm and claim checks passed to `jsonwebtoken::decode`.
    validation: Validation,
    /// Key used to verify token signatures.
    decoding_key: DecodingKey,
}

impl JwtValidator {
    /// Creates a validator for HS256-signed tokens, verified with the shared
    /// `secret`, whose issuer must be `"oauth-provider"`.
    pub fn new(secret: &[u8]) -> Self {
        let validation = {
            let mut rules = Validation::new(Algorithm::HS256);
            rules.set_issuer(&["oauth-provider"]);
            rules
        };
        let decoding_key = DecodingKey::from_secret(secret);

        Self {
            validation,
            decoding_key,
        }
    }

    /// Checks `token` against the configured key and validation rules.
    ///
    /// Returns the decoded [`Claims`] on success, or the `jsonwebtoken` error
    /// rendered as a string.
    pub fn validate_token(&self, token: &str) -> Result<Claims, String> {
        decode::<Claims>(token, &self.decoding_key, &self.validation)
            .map(|data| data.claims)
            .map_err(|err| err.to_string())
    }
}

/// Claims this API expects in an access token's payload.
#[derive(Serialize, Deserialize)]
pub struct Claims {
    /// Subject: the user the token was issued for.
    pub sub: String,
    pub email: String,
    /// Issued-at (`iat`); presumably Unix seconds per RFC 7519 — confirm.
    pub iat: u64,
    /// Expiry (`exp`); presumably Unix seconds per RFC 7519 — confirm.
    pub exp: u64,
    /// NOTE(review): many issuers send a single space-separated `scope`
    /// string rather than a `scopes` array. Without `#[serde(default)]`,
    /// tokens lacking this field fail to decode — confirm the issuer's format.
    pub scopes: Vec<String>,
}

Rate Limiting Implementation

Protect your API from abuse with rate limiting.

Token Bucket Algorithm

use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use std::collections::HashMap;

/// In-memory token-bucket rate limiter with one bucket per key.
pub struct TokenBucketRateLimiter {
    /// Buckets by key, behind an async lock and shared via `Arc`.
    buckets: Arc<RwLock<HashMap<String, Bucket>>>,
    /// Maximum tokens a bucket can hold; new buckets start this full.
    capacity: u64,
    refill_rate: f64, // tokens per second
}

/// State of a single key's bucket.
struct Bucket {
    /// Tokens currently available; fractional so partial refills accumulate.
    tokens: f64,
    /// When `tokens` was last topped up.
    last_refill: Instant,
}

impl TokenBucketRateLimiter {
    /// Creates a limiter whose buckets hold at most `capacity` tokens and
    /// refill continuously at `refill_rate` tokens per second.
    ///
    /// A `refill_rate` that is zero, negative or NaN is treated as "never
    /// refills". The durations reported by `check_rate_limit` are then
    /// [`Duration::MAX`] instead of a panic.
    pub fn new(capacity: u64, refill_rate: f64) -> Self {
        Self {
            buckets: Arc::new(RwLock::new(HashMap::new())),
            capacity,
            refill_rate,
        }
    }

    /// Tries to spend `cost` tokens from the bucket for `key`.
    ///
    /// A key seen for the first time starts with a full bucket.
    /// - On success, `reset_in` is how long until the bucket is full again.
    /// - On refusal, `retry_after` is how long until `cost` tokens will have
    ///   accumulated.
    ///
    /// A `cost` greater than `capacity` can never be satisfied.
    ///
    /// Buckets are never evicted, so memory grows with the number of distinct
    /// keys seen. Long-running services should prune idle buckets.
    pub async fn check_rate_limit(
        &self,
        key: &str,
        cost: u64,
    ) -> RateLimitResult {
        // NaN and negative rates collapse to 0 ("never refills") so they can
        // neither drain the bucket nor feed NaN/negative values to `Duration`.
        let rate = self.refill_rate.max(0.0);
        let capacity = self.capacity as f64;
        let cost = cost as f64;

        let mut buckets = self.buckets.write().await;
        let bucket = buckets.entry(key.to_string()).or_insert_with(|| Bucket {
            tokens: capacity,
            last_refill: Instant::now(),
        });

        let now = Instant::now();
        let elapsed = now.duration_since(bucket.last_refill).as_secs_f64();
        bucket.tokens = (bucket.tokens + elapsed * rate).min(capacity);
        bucket.last_refill = now;

        if bucket.tokens >= cost {
            bucket.tokens -= cost;
            RateLimitResult::Allowed {
                remaining: bucket.tokens as u64,
                // Time until the bucket is back at full capacity.
                reset_in: Self::time_to_accumulate(capacity - bucket.tokens, rate),
            }
        } else {
            RateLimitResult::Denied {
                retry_after: Self::time_to_accumulate(cost - bucket.tokens, rate),
            }
        }
    }

    /// Time needed to gain `deficit` tokens at `rate` tokens per second.
    ///
    /// Never panics. A deficit of zero or less yields `Duration::ZERO`. A
    /// result that is infinite, NaN or too large for `Duration` (e.g. `rate`
    /// of 0) saturates to `Duration::MAX`. `Duration::from_secs_f64` would
    /// panic on those inputs.
    fn time_to_accumulate(deficit: f64, rate: f64) -> Duration {
        if deficit <= 0.0 {
            return Duration::ZERO;
        }
        let secs = deficit / rate;
        if secs.is_finite() && secs < Duration::MAX.as_secs_f64() {
            Duration::from_secs_f64(secs)
        } else {
            Duration::MAX
        }
    }
}

/// Outcome of a token-bucket rate-limit check.
///
/// Both payloads are small `Copy` values, so the result can be copied and
/// compared freely (handy in tests and logging).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RateLimitResult {
    /// The request may proceed.
    Allowed {
        /// Whole tokens left after this request (reported as `X-RateLimit-Remaining`).
        remaining: u64,
        /// Reset hint reported as `X-RateLimit-Reset`.
        reset_in: Duration,
    },
    /// The request must be rejected.
    Denied {
        /// How long the client should wait (reported as `Retry-After`).
        retry_after: Duration,
    },
}

Axum Rate Limiting Middleware

use axum::{
    extract::Request,
    http::header::{HeaderName, HeaderValue},
    middleware::Next,
    response::Response,
};
use std::sync::Arc;

pub async fn rate_limit_middleware(
    rate_limiter: Arc<TokenBucketRateLimiter>,
    request: Request,
    next: Next,
) -> Response {
    let client_ip = extract_client_ip(&request);
    let cost = determine_request_cost(&request);
    
    match rate_limiter.check_rate_limit(&client_ip, cost).await {
        RateLimitResult::Allowed { remaining, reset_in } => {
            let mut response = next.run(request).await;
            
            response.headers_mut().insert(
                HeaderName::from_static("x-ratelimit-limit"),
                HeaderValue::from_str(&remaining.to_string()).unwrap(),
            );
            response.headers_mut().insert(
                HeaderName::from_static("x-ratelimit-remaining"),
                HeaderValue::from_str(&remaining.to_string()).unwrap(),
            );
            response.headers_mut().insert(
                HeaderName::from_static("x-ratelimit-reset"),
                HeaderValue::from_str(&reset_in.as_secs().to_string()).unwrap(),
            );
            
            response
        }
        RateLimitResult::Denied { retry_after } => {
            let mut response = Response::new("Rate limit exceeded".into());
            *response.status_mut() = StatusCode::TOO_MANY_REQUESTS;
            
            response.headers_mut().insert(
                HeaderName::from_static("retry-after"),
                HeaderValue::from_str(&retry_after.as_secs().to_string()).unwrap(),
            );
            
            response
        }
    }
}

fn extract_client_ip(request: &Request) -> String {
    request
        .headers()
        .get("x-forwarded-for")
        .and_then(|v| v.to_str().ok())
        .map(|s| s.split(',').next().unwrap_or(s).to_string())
        .unwrap_or_else(|| "unknown".to_string())
}

/// Number of tokens a request costs, chosen by path prefix.
///
/// `/api/write` costs 10, `/api/search` 5, `/api/read` and everything else 1.
fn determine_request_cost(request: &Request) -> u64 {
    const COSTS: [(&str, u64); 3] = [("/api/write", 10), ("/api/read", 1), ("/api/search", 5)];

    let path = request.uri().path();
    COSTS
        .iter()
        .find(|(prefix, _)| path.starts_with(prefix))
        .map_or(1, |&(_, cost)| cost)
}

Fixed Window Rate Limiter (Simpler)

use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use std::collections::HashMap;

/// In-memory fixed-window request counter with one window per key.
pub struct FixedWindowRateLimiter {
    /// Current window for each key, behind an async lock.
    windows: Arc<RwLock<HashMap<String, Window>>>,
    /// Requests permitted per key within one window.
    max_requests: u64,
    /// Length of each window.
    window_size: Duration,
}

/// Request count for one key's current window.
struct Window {
    /// Requests accepted so far in this window.
    count: u64,
    /// When this window began.
    window_start: Instant,
}

impl FixedWindowRateLimiter {
    /// Creates a limiter admitting at most `max_requests` per key within each
    /// `window_size` interval.
    pub fn new(max_requests: u64, window_size: Duration) -> Self {
        let windows = Arc::new(RwLock::new(HashMap::new()));
        Self {
            windows,
            max_requests,
            window_size,
        }
    }

    /// Counts a request for `key` and returns whether it is within the limit.
    ///
    /// A key's window opens at its first request. It restarts on the first
    /// request that arrives `window_size` or more after the window opened.
    /// Refused requests are not counted.
    pub async fn is_allowed(&self, key: &str) -> bool {
        let mut guard = self.windows.write().await;
        let now = Instant::now();

        let slot = guard.entry(key.to_string()).or_insert_with(|| Window {
            count: 0,
            window_start: now,
        });

        if now.duration_since(slot.window_start) >= self.window_size {
            *slot = Window {
                count: 0,
                window_start: now,
            };
        }

        let within_limit = slot.count < self.max_requests;
        if within_limit {
            slot.count += 1;
        }
        within_limit
    }
}

CORS Configuration

Proper CORS Headers

use axum::{
    extract::State,
    http::{header, Method, Origin},
    response::Response,
    routing::get,
    Router,
};
use std::sync::Arc;

/// Configuration for the CORS middleware.
#[derive(Clone)]
struct CorsState {
    /// Origins allowed to call the API; an entry of `"*"` acts as a wildcard.
    allowed_origins: Vec<String>,
    /// Methods advertised in `Access-Control-Allow-Methods`.
    allowed_methods: Vec<Method>,
    /// Request headers advertised in `Access-Control-Allow-Headers`.
    allowed_headers: Vec<String>,
    /// Sent as `Access-Control-Max-Age`, which browsers read as seconds.
    max_age: u64,
}

async fn cors_handler(
    State(state): State<Arc<CorsState>>,
    Origin(origin): Origin,
    request: axum::extract::Request,
    next: Next,
) -> Response {
    let allowed = state.allowed_origins.iter()
        .any(|o| o == "*" || o == origin.as_str());
    
    let mut response = next.run(request).await;
    
    if allowed {
        let allowed_origin = if state.allowed_origins.contains(&"*".to_string()) {
            header::Value::from_static("*")
        } else {
            origin.to_str().unwrap_or("*").parse().unwrap()
        };
        
        response.headers_mut().insert(
            header::ACCESS_CONTROL_ALLOW_ORIGIN,
            allowed_origin,
        );
        
        response.headers_mut().insert(
            header::ACCESS_CONTROL_ALLOW_METHODS,
            state.allowed_methods.iter()
                .map(|m| m.as_str())
                .collect::<Vec<_>>()
                .join(", ")
                .parse()
                .unwrap(),
        );
        
        response.headers_mut().insert(
            header::ACCESS_CONTROL_ALLOW_HEADERS,
            state.allowed_headers.join(", ").parse().unwrap(),
        );
        
        response.headers_mut().insert(
            header::ACCESS_CONTROL_MAX_AGE,
            state.max_age.to_string().parse().unwrap(),
        );
        
        response.headers_mut().insert(
            header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
            header::Value::from_static("true"),
        );
    }
    
    response
}

Security Headers

Setting CSP, HSTS, and Other Headers

use axum::{
    http::{header, HeaderName, HeaderValue, Method},
    middleware::Next,
    response::Response,
    extract::Request,
};

/// Adds a baseline set of security headers to every response.
///
/// Values set here overwrite any the handler set for the same header name.
///
/// - **CSP** no longer allows `'unsafe-inline'` in `script-src`. Permitting
///   inline scripts removes most of the XSS protection CSP exists to provide.
///   `connect-src` still names a placeholder API origin; adjust per deployment.
/// - **`X-XSS-Protection`** is `0`. The legacy browser XSS auditor it
///   controlled is gone from modern browsers, and `1; mode=block` could itself
///   be abused in older ones. CSP is the replacement.
pub async fn security_headers_middleware(
    request: Request,
    next: Next,
) -> Response {
    const SECURITY_HEADERS: [(&str, &str); 7] = [
        (
            "content-security-policy",
            "default-src 'self'; \
             script-src 'self'; \
             style-src 'self' 'unsafe-inline'; \
             img-src 'self' data: https:; \
             font-src 'self'; \
             connect-src 'self' https://api.example.com; \
             frame-ancestors 'none';",
        ),
        // One year, including subdomains; only takes effect over HTTPS.
        ("strict-transport-security", "max-age=31536000; includeSubDomains"),
        ("x-content-type-options", "nosniff"),
        // Legacy counterpart of CSP `frame-ancestors 'none'`.
        ("x-frame-options", "DENY"),
        ("x-xss-protection", "0"),
        ("referrer-policy", "strict-origin-when-cross-origin"),
        ("permissions-policy", "geolocation=(), microphone=(), camera=()"),
    ];

    let mut response = next.run(request).await;
    let headers = response.headers_mut();
    for (name, value) in SECURITY_HEADERS {
        headers.insert(HeaderName::from_static(name), HeaderValue::from_static(value));
    }
    response
}

SQL Injection Prevention

Using Parameterized Queries

use sqlx::{PgPool, Row};

/// Deliberately vulnerable counter-example; do not use.
///
/// `user_input` is spliced straight into the SQL text. Input such as
/// `' OR '1'='1` therefore rewrites the query itself (SQL injection). See
/// `safe_query_example` for the parameterized version.
pub async fn unsafe_query_example(
    pool: &PgPool,
    user_input: String,
) -> Result<Vec<User>, sqlx::Error> {
    // ❌ NEVER DO THIS - SQL Injection vulnerability!
    let query = format!(
        "SELECT * FROM users WHERE username = '{}'",
        user_input
    );
    
    // NOTE(review): `sqlx::query(..).fetch_all` presumably yields raw rows,
    // not `User`s — confirm this compiles as written.
    sqlx::query(&query).fetch_all(pool).await
}

pub async fn safe_query_example(
    pool: &PgPool,
    username: &str,
) -> Result<Vec<User>, sqlx::Error> {
    // โœ… Use parameterized queries
    let rows = sqlx::query(
        "SELECT id, username, email, created_at FROM users WHERE username = $1"
    )
    .bind(username)
    .fetch_all(pool)
    .await?;
    
    let users = rows.iter().map(|row| User {
        id: row.get("id"),
        username: row.get("username"),
        email: row.get("email"),
        created_at: row.get("created_at"),
    }).collect();
    
    Ok(users)
}

Using SQLx Query Builder

use sqlx::query_as;

pub async fn query_with_builder(
    pool: &PgPool,
    filters: UserFilters,
) -> Result<Vec<User>, sqlx::Error> {
    let mut query = "SELECT * FROM users WHERE 1=1".to_string();
    let mut param_index = 1;
    let mut params: Vec<Box<dyn sqlx::Encode<'_, sqlx::Postgres>> + Send> = Vec::new();
    
    if let Some(ref username) = filters.username {
        query.push_str(&format!(" AND username = ${}", param_index));
        params.push(Box::new(username.clone()));
        param_index += 1;
    }
    
    if let Some(ref email) = filters.email {
        query.push_str(&format!(" AND email = ${}", param_index));
        params.push(Box::new(email.clone()));
        param_index += 1;
    }
    
    if let Some(limit) = filters.limit {
        query.push_str(&format!(" LIMIT ${}", param_index));
        params.push(Box::new(limit));
    }
    
    // Execute safely
    let mut builder = sqlx::query(&query);
    for param in params {
        builder = builder.bind(param);
    }
    
    builder.fetch_all(pool).await
}

/// Optional filters for `query_with_builder`; `None` fields are not applied.
#[derive(Default)]
struct UserFilters {
    /// Exact-match username filter.
    username: Option<String>,
    /// Exact-match email filter.
    email: Option<String>,
    /// Maximum number of rows to return.
    limit: Option<i64>,
}

CSRF Protection

use axum::{
    extract::FromRef,
    http::{header, Method, StatusCode},
    response::{Html, IntoResponse, Response},
    routing::get,
    Form, Router,
};
use serde::Deserialize;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::RwLock;
use std::collections::HashMap;

/// In-memory CSRF token store, one token per session id.
#[derive(Clone)]
struct CsrfState {
    /// Session id → most recently issued token. The map sits behind an `Arc`,
    /// so clones of `CsrfState` share it.
    tokens: Arc<RwLock<HashMap<String, CsrfToken>>>,
}

/// An issued CSRF token and the moment it was issued.
#[derive(Clone)]
struct CsrfToken {
    token: String,
    /// Compared against the expiry window during validation.
    created_at: std::time::Instant,
}

impl CsrfState {
    /// How long an issued token stays valid.
    const TOKEN_TTL: std::time::Duration = std::time::Duration::from_secs(3600);

    /// Creates an empty token store.
    pub fn new() -> Self {
        Self {
            tokens: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Issues a fresh random token for `session_id`, replacing any earlier one.
    ///
    /// Expired tokens of all sessions are purged at the same time, so the
    /// store does not grow without bound. This costs one O(n) scan per call.
    pub async fn generate_token(&self, session_id: &str) -> String {
        let token = uuid::Uuid::new_v4().to_string();
        let mut tokens = self.tokens.write().await;
        tokens.retain(|_, stored| stored.created_at.elapsed() <= Self::TOKEN_TTL);
        tokens.insert(session_id.to_string(), CsrfToken {
            token: token.clone(),
            created_at: std::time::Instant::now(),
        });
        token
    }

    /// Checks `token` against the one issued to `session_id`.
    ///
    /// Returns `false` if no token was issued, the token is older than
    /// [`Self::TOKEN_TTL`], or it does not match. The comparison does not stop
    /// at the first differing byte, so response timing does not reveal how
    /// much of a guessed token was correct.
    pub async fn validate_token(&self, session_id: &str, token: &str) -> bool {
        let tokens = self.tokens.read().await;
        match tokens.get(session_id) {
            Some(stored) if stored.created_at.elapsed() <= Self::TOKEN_TTL => {
                constant_time_eq(stored.token.as_bytes(), token.as_bytes())
            }
            // No token issued for this session, or it has expired.
            _ => false,
        }
    }
}

impl Default for CsrfState {
    fn default() -> Self {
        Self::new()
    }
}

/// Compares two byte strings without an early exit on the first mismatch.
///
/// Length is checked first; CSRF tokens have a fixed length, so that leaks
/// nothing useful. This is best-effort. The `subtle` crate offers a version
/// hardened against compiler optimizations.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    a.len() == b.len() && a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
}

async fn protected_handler(
    csrf_state: CsrfState<CsrfState>,
    Form(form): Form<CsrfForm>,
) -> impl IntoResponse {
    let session_id = "session-123"; // Get from session
    
    if !csrf_state.validate_token(session_id, &form.csrf_token).await {
        return (StatusCode::FORBIDDEN, "CSRF token invalid");
    }
    
    // Process form
    "Success"
}

#[derive(Deserialize)]
struct CsrfForm {
    #[serde(rename = "csrf_token")]
    csrf_token: String,
    data: String,
}

Complete Security Middleware Stack

use axum::{
    middleware::SelfAndThen,
    routing::{get, post},
    Router,
};

/// Assembles the router with the security middleware stack.
///
/// In axum, each `.layer` wraps everything added before it, so the last layer
/// is outermost. Requests flow security headers → CORS → rate limiting →
/// handler, and responses flow back the opposite way.
///
/// CORS sits outside the rate limiter so that responses the limiter produces
/// itself (e.g. `429 Too Many Requests`) still pass through the CORS layer.
/// Without CORS headers, the browser hides such responses from the calling
/// script.
fn create_secure_app() -> Router {
    let cors_state = Arc::new(CorsState {
        allowed_origins: vec!["https://example.com".to_string()],
        allowed_methods: vec![Method::GET, Method::POST, Method::PUT, Method::DELETE],
        allowed_headers: vec!["Content-Type".to_string(), "Authorization".to_string()],
        max_age: 86400,
    });

    // Buckets of 100 tokens, refilled at 10 tokens per second, per rate-limit key.
    let rate_limiter = Arc::new(TokenBucketRateLimiter::new(100, 10.0));

    Router::new()
        .route("/api/data", get(get_data).post(create_data))
        .layer(axum::middleware::from_fn_with_state(
            rate_limiter,
            rate_limit_middleware,
        ))
        .layer(axum::middleware::from_fn_with_state(
            cors_state,
            cors_handler,
        ))
        .layer(axum::middleware::from_fn(security_headers_middleware))
}

Best Practices Summary

| Security Measure | Implementation | Priority |
|---|---|---|
| OAuth 2.0 | Third-party auth, delegated access | High |
| Rate Limiting | Token bucket, fixed window | High |
| CORS | Origin allowlist, credentials | High |
| Security Headers | CSP, HSTS, X-Frame-Options | High |
| SQL Injection | Parameterized queries | Critical |
| CSRF | Token-based validation | Medium |
| Input Validation | Type-safe validation | High |
| HTTPS | TLS termination | Critical |

Conclusion

API security requires layered defense. This article covered:

  1. OAuth 2.0 - Delegated authorization for third-party apps
  2. Rate Limiting - Token bucket and fixed window algorithms
  3. CORS - Proper cross-origin request handling
  4. Security Headers - CSP, HSTS, X-Frame-Options
  5. SQL Injection Prevention - Always use parameterized queries
  6. CSRF Protection - Token-based validation

Implement these patterns together for comprehensive API security.


External Resources


Comments