API Security Beyond JWT: OAuth 2.0, Rate Limiting, and Protection Strategies
TL;DR: This guide covers advanced API security patterns for Rust web services beyond basic JWT authentication. You’ll learn OAuth 2.0 implementation, rate limiting algorithms, CORS configuration, security headers, and protection against common vulnerabilities.
Introduction
While JWT authentication is foundational, production APIs require defense in depth. This article explores comprehensive security measures that protect against:
- Unauthorized access and privilege escalation
- API abuse and DoS attacks
- Cross-site attacks (CSRF, XSS)
- Injection attacks
- Data exposure
OAuth 2.0 Implementation in Rust
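OAuth 2.0 provides delegated authorization: a user can let an application act on their behalf at another service without sharing their password. The code below implements the authorization code flow. The user is redirected to the provider's authorization URL, the provider sends a one-time code back to the redirect URL, and the server exchanges that code for an access token.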
Installation
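```toml
[dependencies]
# The client code below uses the oauth2 4.x API (5.0 replaced the
# BasicClient constructor with a typestate builder).
oauth2 = "4.4"
reqwest = { version = "0.11", features = ["json"] }
serde = { version = "1.0", features = ["derive"] }
tokio = { version = "1", features = ["full"] }
# Used by the later sections:
axum = "0.7"
jsonwebtoken = "9"
sqlx = { version = "0.8", features = ["runtime-tokio", "postgres", "chrono"] }
chrono = "0.4"
uuid = { version = "1", features = ["v4"] }
subtle = "2"
```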
OAuth 2.0 Client Implementation
```rust
use oauth2::basic::BasicClient;
use oauth2::reqwest::async_http_client;
use oauth2::url::ParseError;
use oauth2::{
    AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, RedirectUrl, Scope,
    TokenResponse, TokenUrl,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;

pub struct OAuthService {
    client: BasicClient,
    // Reused for every user-info call so connections are pooled.
    http: reqwest::Client,
    // Logged-in users keyed by session id (storage only; not used further here).
    sessions: Arc<RwLock<HashMap<String, UserSession>>>,
}

#[derive(Clone, Serialize, Deserialize)]
pub struct UserSession {
    pub user_id: String,
    pub email: String,
    pub scopes: Vec<String>,
}

impl OAuthService {
    pub fn new(
        client_id: String,
        client_secret: String,
        auth_url: String,
        token_url: String,
        redirect_url: String,
    ) -> Result<Self, ParseError> {
        // Invalid URLs are reported to the caller instead of panicking.
        let client = BasicClient::new(
            ClientId::new(client_id),
            Some(ClientSecret::new(client_secret)),
            AuthUrl::new(auth_url)?,
            Some(TokenUrl::new(token_url)?),
        )
        .set_redirect_uri(RedirectUrl::new(redirect_url)?);

        Ok(Self {
            client,
            http: reqwest::Client::new(),
            sessions: Arc::new(RwLock::new(HashMap::new())),
        })
    }

    /// Builds the provider's authorization URL. The crate generates a random
    /// `state` value and adds it to the URL itself; keep the returned `CsrfToken`
    /// (e.g. in the user's session) and compare it with `state` on the callback.
    /// Public clients should also add PKCE (`set_pkce_challenge`).
    pub fn generate_authorization_url(&self, scopes: &[&str]) -> (String, CsrfToken) {
        let (url, csrf_state) = self
            .client
            .authorize_url(CsrfToken::new_random)
            .add_scopes(scopes.iter().map(|s| Scope::new(s.to_string())))
            .url();
        (url.to_string(), csrf_state)
    }

    pub async fn exchange_code_for_token(&self, code: String) -> Result<TokenData, OAuthError> {
        // `async_http_client` is the non-blocking reqwest client for oauth2 4.x.
        let token_result = self
            .client
            .exchange_code(AuthorizationCode::new(code))
            .request_async(async_http_client)
            .await
            .map_err(|e| OAuthError::TokenExchange(e.to_string()))?;

        Ok(TokenData {
            access_token: token_result.access_token().secret().to_string(),
            refresh_token: token_result.refresh_token().map(|t| t.secret().to_string()),
            // `expires_in()` returns a `std::time::Duration`; store whole seconds.
            expires_in: token_result.expires_in().map(|d| d.as_secs()),
            token_type: "Bearer".to_string(),
        })
    }

    pub async fn get_user_info(&self, access_token: &str) -> Result<UserSession, OAuthError> {
        let user_info: UserInfoResponse = self
            .http
            .get("https://oauth-provider.com/userinfo") // provider-specific endpoint
            .bearer_auth(access_token)
            .send()
            .await
            .and_then(|r| r.error_for_status()) // treat 4xx/5xx as errors
            .map_err(|e| OAuthError::UserInfo(e.to_string()))?
            .json()
            .await
            .map_err(|e| OAuthError::UserInfo(e.to_string()))?;

        Ok(UserSession {
            user_id: user_info.sub,
            email: user_info.email,
            // Placeholder: real code should record the scopes actually granted.
            scopes: vec!["read".to_string()],
        })
    }
}

#[derive(Debug)]
pub enum OAuthError {
    TokenExchange(String),
    UserInfo(String),
}

#[derive(Serialize, Deserialize)]
pub struct TokenData {
    pub access_token: String,
    pub refresh_token: Option<String>,
    pub expires_in: Option<u64>, // seconds
    pub token_type: String,
}

#[derive(Deserialize)]
struct UserInfoResponse {
    sub: String,
    email: String,
}
```
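The OAuth `state` parameter only protects the callback if the server remembers the value it sent and checks it exactly once when the provider redirects back. The std-only sketch below assumes a hypothetical in-memory `PendingStates` store with a time-to-live; in a real app the value would usually live in the user's session. Time is passed in explicitly so the logic is easy to test.

```rust
use std::collections::HashMap;
use std::time::{Duration, Instant};

/// Hypothetical store for OAuth `state` values issued in authorization URLs.
pub struct PendingStates {
    ttl: Duration,
    issued: HashMap<String, Instant>,
}

impl PendingStates {
    pub fn new(ttl: Duration) -> Self {
        Self { ttl, issued: HashMap::new() }
    }

    /// Remember a state value when redirecting the user to the provider.
    pub fn remember(&mut self, state: &str, now: Instant) {
        self.issued.insert(state.to_string(), now);
    }

    /// Check the `state` from the callback. `remove` makes every value
    /// single-use, so a replayed callback fails the second time.
    pub fn consume(&mut self, state: &str, now: Instant) -> bool {
        match self.issued.remove(state) {
            Some(issued_at) => now.duration_since(issued_at) <= self.ttl,
            None => false, // never issued, or already used
        }
    }
}
```

Removing the entry on the first check, whether or not it has expired, keeps the store from accumulating stale values.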
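OAuth 2.0 Resource Server
The resource server is your API: it receives access tokens issued by the authorization server and must validate each one (signature, issuer, audience, expiry) before trusting its claims.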
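```rust
use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
use serde::{Deserialize, Serialize};

pub struct JwtValidator {
    validation: Validation,
    decoding_key: DecodingKey,
}

impl JwtValidator {
    /// HS256 with a shared secret only works if the authorization server and this
    /// API share that secret. Most OAuth providers sign access tokens with RS256 or
    /// ES256; in that case build the key with `DecodingKey::from_rsa_pem` (or from
    /// the provider's JWKS) and use `Algorithm::RS256`.
    pub fn new(secret: &[u8], issuer: &str, audience: &str) -> Self {
        let mut validation = Validation::new(Algorithm::HS256);
        validation.set_issuer(&[issuer]); // reject tokens from other issuers
        validation.set_audience(&[audience]); // reject tokens minted for other APIs
        // `exp` is validated by default.
        Self {
            validation,
            decoding_key: DecodingKey::from_secret(secret),
        }
    }

    pub fn validate_token(&self, token: &str) -> Result<Claims, jsonwebtoken::errors::Error> {
        decode::<Claims>(token, &self.decoding_key, &self.validation).map(|data| data.claims)
    }
}

/// Claim names vary by provider; many issue a space-separated `scope` string
/// instead of a `scopes` array.
#[derive(Debug, Serialize, Deserialize)]
pub struct Claims {
    pub sub: String,
    pub email: String,
    pub iat: u64,
    pub exp: u64,
    pub scopes: Vec<String>,
}
```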
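The validator checks the token's signature and standard claims, but the server still has to pull the token out of the `Authorization` header and decide whether its scopes cover the endpoint. Below is a std-only sketch of that step; `parse_bearer`, `authorize`, and `AuthzError` are hypothetical names, and the `validate` closure stands in for `JwtValidator::validate_token`. A missing or invalid token maps to 401, while a valid token that lacks the required scope maps to 403.

```rust
/// Outcome of an authorization check, mapped to HTTP status codes.
#[derive(Debug, PartialEq)]
pub enum AuthzError {
    Unauthorized, // 401: missing, malformed, or invalid token
    Forbidden,    // 403: valid token without the required scope
}

/// Extract the token from an `Authorization: Bearer <token>` header value.
/// HTTP auth scheme names are case-insensitive.
pub fn parse_bearer(header_value: &str) -> Option<&str> {
    let (scheme, token) = header_value.split_once(' ')?;
    let token = token.trim();
    if scheme.eq_ignore_ascii_case("bearer") && !token.is_empty() {
        Some(token)
    } else {
        None
    }
}

/// `validate` returns the token's scopes, or None if the token is invalid.
pub fn authorize<F>(auth_header: Option<&str>, validate: F, required: &str) -> Result<(), AuthzError>
where
    F: Fn(&str) -> Option<Vec<String>>,
{
    let token = auth_header.and_then(parse_bearer).ok_or(AuthzError::Unauthorized)?;
    let scopes = validate(token).ok_or(AuthzError::Unauthorized)?;
    if scopes.iter().any(|s| s == required) {
        Ok(())
    } else {
        Err(AuthzError::Forbidden)
    }
}
```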
Rate Limiting Implementation
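Rate limiting caps how many requests a client may make in a given period. It blunts brute-force login attempts, scraping, and resource-exhaustion attacks, and keeps one noisy client from starving the others.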
Token Bucket Algorithm
```rust
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;

pub struct TokenBucketRateLimiter {
    // One bucket per key. Entries are never evicted in this sketch; a production
    // limiter needs cleanup (or a crate such as `governor`).
    buckets: Arc<RwLock<HashMap<String, Bucket>>>,
    pub capacity: u64, // maximum burst size
    refill_rate: f64,  // tokens added per second; must be > 0
}

struct Bucket {
    tokens: f64,
    last_refill: Instant,
}

impl TokenBucketRateLimiter {
    pub fn new(capacity: u64, refill_rate: f64) -> Self {
        assert!(refill_rate > 0.0, "refill_rate must be positive");
        Self {
            buckets: Arc::new(RwLock::new(HashMap::new())),
            capacity,
            refill_rate,
        }
    }

    pub async fn check_rate_limit(&self, key: &str, cost: u64) -> RateLimitResult {
        // A single write lock serializes all clients: simple, but a contention point.
        let mut buckets = self.buckets.write().await;
        let now = Instant::now();
        let capacity = self.capacity as f64;
        let bucket = buckets.entry(key.to_string()).or_insert_with(|| Bucket {
            tokens: capacity,
            last_refill: now,
        });

        // Refill in proportion to the time since the last check, capped at capacity.
        let elapsed = now.duration_since(bucket.last_refill).as_secs_f64();
        bucket.tokens = (bucket.tokens + elapsed * self.refill_rate).min(capacity);
        bucket.last_refill = now;

        let cost = cost as f64;
        if bucket.tokens >= cost {
            bucket.tokens -= cost;
            RateLimitResult::Allowed {
                remaining: bucket.tokens as u64,
                // Time until the bucket is full again.
                reset_in: Duration::from_secs_f64((capacity - bucket.tokens) / self.refill_rate),
            }
        } else {
            // Note: a request costing more than `capacity` can never succeed.
            RateLimitResult::Denied {
                retry_after: Duration::from_secs_f64((cost - bucket.tokens) / self.refill_rate),
            }
        }
    }
}

pub enum RateLimitResult {
    Allowed { remaining: u64, reset_in: Duration },
    Denied { retry_after: Duration },
}
```
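To see the refill arithmetic in isolation, here is the same token bucket rule with time passed in as seconds instead of read from `Instant::now()`. With a capacity of 10 and a refill rate of 2 tokens per second, draining the bucket and waiting 1.5 s yields 1.5 × 2 = 3 tokens. `SimBucket` is a hypothetical, test-friendly type, not part of the limiter above.

```rust
/// Token bucket with an explicit clock (seconds as f64), for reasoning and tests.
pub struct SimBucket {
    capacity: f64,
    refill_rate: f64, // tokens per second
    tokens: f64,
    last: f64, // time of the last refill, in seconds
}

impl SimBucket {
    pub fn new(capacity: f64, refill_rate: f64) -> Self {
        // Starts full at t = 0.
        Self { capacity, refill_rate, tokens: capacity, last: 0.0 }
    }

    /// Refill for the elapsed time (capped at capacity), then try to spend `cost`.
    pub fn try_take(&mut self, now: f64, cost: f64) -> bool {
        self.tokens = (self.tokens + (now - self.last) * self.refill_rate).min(self.capacity);
        self.last = now;
        if self.tokens >= cost {
            self.tokens -= cost;
            true
        } else {
            false
        }
    }

    pub fn tokens(&self) -> f64 {
        self.tokens
    }
}
```

Because tokens return gradually, no burst can exceed the capacity, however long the client has been idle.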
Axum Rate Limiting Middleware
```rust
use axum::{
    extract::{ConnectInfo, Request, State},
    http::{header, HeaderName, HeaderValue, StatusCode},
    middleware::Next,
    response::{IntoResponse, Response},
};
use std::net::SocketAddr;
use std::sync::Arc;

// `from_fn_with_state` passes the limiter in through the `State` extractor.
pub async fn rate_limit_middleware(
    State(rate_limiter): State<Arc<TokenBucketRateLimiter>>,
    request: Request,
    next: Next,
) -> Response {
    let client_ip = extract_client_ip(&request);
    let cost = determine_request_cost(&request);

    match rate_limiter.check_rate_limit(&client_ip, cost).await {
        RateLimitResult::Allowed { remaining, reset_in } => {
            let mut response = next.run(request).await;
            let headers = response.headers_mut();
            // The limit is the bucket capacity, not the remaining count.
            headers.insert(
                HeaderName::from_static("x-ratelimit-limit"),
                HeaderValue::from(rate_limiter.capacity),
            );
            headers.insert(
                HeaderName::from_static("x-ratelimit-remaining"),
                HeaderValue::from(remaining),
            );
            headers.insert(
                HeaderName::from_static("x-ratelimit-reset"),
                HeaderValue::from(reset_in.as_secs_f64().ceil() as u64),
            );
            response
        }
        RateLimitResult::Denied { retry_after } => {
            // Round up so clients never retry too early (0.4 s becomes 1, not 0).
            let secs = retry_after.as_secs_f64().ceil() as u64;
            (
                StatusCode::TOO_MANY_REQUESTS,
                [(header::RETRY_AFTER, HeaderValue::from(secs))],
                "Rate limit exceeded",
            )
                .into_response()
        }
    }
}

/// Keys the limiter on the TCP peer address. `ConnectInfo` is only present when the
/// server is started with `into_make_service_with_connect_info::<SocketAddr>()`.
/// Do not key on `X-Forwarded-For` blindly: clients can set it to anything. Read it
/// only behind a proxy you control, and use the entry that proxy appended.
fn extract_client_ip(request: &Request) -> String {
    request
        .extensions()
        .get::<ConnectInfo<SocketAddr>>()
        .map(|info| info.0.ip().to_string())
        // Without ConnectInfo, every client shares this one bucket.
        .unwrap_or_else(|| "unknown".to_string())
}

/// Expensive endpoints spend more tokens per request.
fn determine_request_cost(request: &Request) -> u64 {
    match request.uri().path() {
        p if p.starts_with("/api/write") => 10,
        p if p.starts_with("/api/read") => 1,
        p if p.starts_with("/api/search") => 5,
        _ => 1,
    }
}
```
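`X-Forwarded-For` is an ordinary request header, so a client can send any value it likes; keying the limiter on its first entry lets an attacker rotate fake addresses and get a fresh bucket on every request. When the service runs behind proxies you operate, a safer rule is to walk the header from the right and skip addresses that belong to those proxies. The sketch below assumes a hypothetical `trusted_proxies` list and falls back to the TCP peer address.

```rust
use std::net::IpAddr;

/// Pick the client address for rate limiting.
/// `peer` is the TCP peer; `xff` is the raw X-Forwarded-For value, if any.
pub fn client_ip(peer: IpAddr, xff: Option<&str>, trusted_proxies: &[IpAddr]) -> IpAddr {
    // Direct connection from something that is not our proxy: ignore the header.
    if !trusted_proxies.contains(&peer) {
        return peer;
    }
    let mut candidate = peer;
    if let Some(header) = xff {
        // Walk right-to-left: the right-most entry was appended by our own proxy,
        // everything further left was supplied by the client or other hops.
        for entry in header.split(',').rev() {
            match entry.trim().parse::<IpAddr>() {
                Ok(ip) => {
                    candidate = ip;
                    if !trusted_proxies.contains(&ip) {
                        break; // first address we do not own: treat it as the client
                    }
                }
                Err(_) => break, // malformed entry: keep the last good hop
            }
        }
    }
    candidate
}
```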
Fixed Window Rate Limiter (Simpler)
```rust
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;

pub struct FixedWindowRateLimiter {
    windows: Arc<RwLock<HashMap<String, Window>>>,
    max_requests: u64,
    window_size: Duration,
}

struct Window {
    count: u64,
    window_start: Instant,
}

impl FixedWindowRateLimiter {
    pub fn new(max_requests: u64, window_size: Duration) -> Self {
        Self {
            windows: Arc::new(RwLock::new(HashMap::new())),
            max_requests,
            window_size,
        }
    }

    /// Returns true if `key` may make another request in its current window.
    pub async fn is_allowed(&self, key: &str) -> bool {
        let mut windows = self.windows.write().await;
        let now = Instant::now();
        // A key's first request opens its window.
        let window = windows.entry(key.to_string()).or_insert_with(|| Window {
            count: 0,
            window_start: now,
        });
        // Start a fresh window once the current one has expired.
        if now.duration_since(window.window_start) >= self.window_size {
            window.count = 0;
            window.window_start = now;
        }
        if window.count < self.max_requests {
            window.count += 1;
            true
        } else {
            false
        }
    }
}
```
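Fixed windows are cheap, but they reset all at once, so a client can spend its whole quota at the end of one window and again at the start of the next. The deterministic sketch below uses the same reset rule with time passed in as whole seconds (`SimWindow` is a hypothetical name) and shows a 10-per-minute limit admitting 20 requests between second 59 and second 60.

```rust
/// Fixed-window counter with an explicit clock (whole seconds).
pub struct SimWindow {
    max_requests: u64,
    window_secs: u64,
    window_start: u64,
    count: u64,
}

impl SimWindow {
    /// Opens the first window at `now`, like the first request for a key.
    pub fn new(max_requests: u64, window_secs: u64, now: u64) -> Self {
        Self { max_requests, window_secs, window_start: now, count: 0 }
    }

    pub fn is_allowed(&mut self, now: u64) -> bool {
        // Same rule as FixedWindowRateLimiter: reset once the window has expired.
        if now - self.window_start >= self.window_secs {
            self.window_start = now;
            self.count = 0;
        }
        if self.count < self.max_requests {
            self.count += 1;
            true
        } else {
            false
        }
    }
}
```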
CORS Configuration
Proper CORS Headers
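```rust
use axum::{
    extract::{Request, State},
    http::{header, HeaderValue, Method, StatusCode},
    middleware::Next,
    response::{IntoResponse, Response},
};
use std::sync::Arc;

#[derive(Clone)]
pub struct CorsState {
    /// Exact origins such as "https://example.com". No wildcard: this handler
    /// always allows credentials, and browsers reject "*" with credentials.
    pub allowed_origins: Vec<String>,
    pub allowed_methods: Vec<Method>,
    pub allowed_headers: Vec<String>,
    pub max_age: u64, // seconds a browser may cache a preflight result
}

pub async fn cors_handler(
    State(state): State<Arc<CorsState>>,
    request: Request,
    next: Next,
) -> Response {
    // Read what we need before `request` is moved into `next.run`.
    let origin = request
        .headers()
        .get(header::ORIGIN)
        .and_then(|v| v.to_str().ok())
        .map(str::to_owned);
    let allowed_origin = origin.filter(|o| state.allowed_origins.iter().any(|a| a == o));
    let is_preflight = *request.method() == Method::OPTIONS
        && request.headers().contains_key(header::ACCESS_CONTROL_REQUEST_METHOD);

    // Answer preflights here; the router would otherwise reply 405 to OPTIONS.
    let mut response = if is_preflight {
        StatusCode::NO_CONTENT.into_response()
    } else {
        next.run(request).await
    };

    let headers = response.headers_mut();
    // The response depends on Origin, so shared caches must key on it.
    headers.append(header::VARY, HeaderValue::from_static("origin"));

    if let Some(origin) = allowed_origin {
        // Echo the exact origin that matched the allowlist.
        if let Ok(value) = HeaderValue::from_str(&origin) {
            headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, value);
            headers.insert(
                header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
                HeaderValue::from_static("true"),
            );
        }
        if is_preflight {
            let methods = state
                .allowed_methods
                .iter()
                .map(Method::as_str)
                .collect::<Vec<_>>()
                .join(", ");
            if let Ok(value) = HeaderValue::from_str(&methods) {
                headers.insert(header::ACCESS_CONTROL_ALLOW_METHODS, value);
            }
            if let Ok(value) = HeaderValue::from_str(&state.allowed_headers.join(", ")) {
                headers.insert(header::ACCESS_CONTROL_ALLOW_HEADERS, value);
            }
            headers.insert(header::ACCESS_CONTROL_MAX_AGE, HeaderValue::from(state.max_age));
        }
    }
    response
}
```

In production, tower-http's CorsLayer implements these rules, including preflight handling, and is usually the better choice.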
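The core CORS decision is small enough to test on its own. The std-only sketch below (the function name is hypothetical) encodes two rules from the Fetch specification that browsers enforce: `Access-Control-Allow-Origin: *` is not honored on credentialed requests, and an allowed origin must be echoed exactly. Exact string comparison also avoids the classic mistake of substring matching, where `https://example.com.attacker.net` would pass a naive `contains("example.com")` check.

```rust
/// Decide the Access-Control-Allow-Origin value for a request, if any.
/// `allowed` holds exact origins ("scheme://host[:port]") or "*".
pub fn allow_origin_value(allowed: &[&str], origin: Option<&str>, with_credentials: bool) -> Option<String> {
    let origin = origin?; // no Origin header: not a CORS request
    if allowed.iter().any(|a| *a == origin) {
        // Echo the exact origin (callers should also send `Vary: Origin`).
        Some(origin.to_string())
    } else if allowed.contains(&"*") && !with_credentials {
        // The wildcard is only valid for requests without credentials.
        Some("*".to_string())
    } else {
        None
    }
}
```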
Security Headers
Setting CSP, HSTS, and Other Headers
```rust
use axum::{
    extract::Request,
    http::{header, HeaderName, HeaderValue},
    middleware::Next,
    response::Response,
};

pub async fn security_headers_middleware(request: Request, next: Next) -> Response {
    let mut response = next.run(request).await;
    let headers = response.headers_mut();

    // Content Security Policy. 'unsafe-inline' is left out of script-src because
    // allowing inline scripts cancels most of the XSS protection CSP provides.
    headers.insert(
        header::CONTENT_SECURITY_POLICY,
        HeaderValue::from_static(
            "default-src 'self'; \
             script-src 'self'; \
             style-src 'self' 'unsafe-inline'; \
             img-src 'self' data: https:; \
             font-src 'self'; \
             connect-src 'self' https://api.example.com; \
             frame-ancestors 'none'",
        ),
    );

    // HTTP Strict Transport Security: one year. Browsers only honor it over HTTPS.
    headers.insert(
        header::STRICT_TRANSPORT_SECURITY,
        HeaderValue::from_static("max-age=31536000; includeSubDomains"),
    );

    // X-Content-Type-Options: stop browsers from MIME-sniffing responses.
    headers.insert(header::X_CONTENT_TYPE_OPTIONS, HeaderValue::from_static("nosniff"));

    // X-Frame-Options: forbid framing (frame-ancestors above is the modern equivalent).
    headers.insert(header::X_FRAME_OPTIONS, HeaderValue::from_static("DENY"));

    // X-XSS-Protection: the browser XSS auditor it controlled has been removed from
    // modern browsers and could itself be abused, so disable it and rely on CSP.
    headers.insert(header::X_XSS_PROTECTION, HeaderValue::from_static("0"));

    // Referrer Policy
    headers.insert(
        header::REFERRER_POLICY,
        HeaderValue::from_static("strict-origin-when-cross-origin"),
    );

    // Permissions Policy: deny access to powerful browser features.
    headers.insert(
        HeaderName::from_static("permissions-policy"),
        HeaderValue::from_static("geolocation=(), microphone=(), camera=()"),
    );

    response
}
```
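The policy above is written for pages that load scripts, styles, and images. An endpoint that only ever returns JSON can use a much tighter, fixed set. The sketch below is a hypothetical constant table plus a helper that fills in headers without overwriting values a handler has already set; `default-src 'none'` tells the browser to load nothing if a response is ever rendered as a document.

```rust
/// A hypothetical, stricter header set for JSON-only API responses.
pub const API_SECURITY_HEADERS: &[(&str, &str)] = &[
    ("content-security-policy", "default-src 'none'; frame-ancestors 'none'"),
    ("strict-transport-security", "max-age=31536000; includeSubDomains"),
    ("x-content-type-options", "nosniff"),
    ("referrer-policy", "no-referrer"),
    ("cache-control", "no-store"),
];

/// Add each default header unless the response already carries it.
/// Header names are compared case-insensitively, as HTTP requires.
pub fn apply_defaults(headers: &mut Vec<(String, String)>) {
    for (name, value) in API_SECURITY_HEADERS {
        if !headers.iter().any(|(n, _)| n.eq_ignore_ascii_case(name)) {
            headers.push((name.to_string(), value.to_string()));
        }
    }
}
```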
SQL Injection Prevention
Using Parameterized Queries
```rust
use chrono::{DateTime, Utc};
use sqlx::{FromRow, PgPool};

// Assumed schema: users(id BIGINT, username TEXT, email TEXT, created_at TIMESTAMPTZ).
#[derive(Debug, FromRow)]
pub struct User {
    pub id: i64,
    pub username: String,
    pub email: String,
    pub created_at: DateTime<Utc>,
}

pub async fn unsafe_query_example(
    pool: &PgPool,
    user_input: String,
) -> Result<Vec<User>, sqlx::Error> {
    // NEVER DO THIS - SQL injection vulnerability!
    // The input is spliced into the SQL text, so it can change the query itself.
    let query = format!(
        "SELECT id, username, email, created_at FROM users WHERE username = '{}'",
        user_input
    );
    sqlx::query_as::<_, User>(&query).fetch_all(pool).await
}

pub async fn safe_query_example(
    pool: &PgPool,
    username: &str,
) -> Result<Vec<User>, sqlx::Error> {
    // Use parameterized queries: `$1` is sent separately from the SQL text
    // and is always treated as a value, never as SQL.
    sqlx::query_as::<_, User>(
        "SELECT id, username, email, created_at FROM users WHERE username = $1",
    )
    .bind(username)
    .fetch_all(pool)
    .await
}
```
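To see why the first function is dangerous, look at the SQL text it produces. The std-only snippet below builds the query by string formatting, the same way `unsafe_query_example` does (the helper name `build_unsafe_query` is hypothetical), and feeds it a classic payload. The quote in the input closes the string literal early and the rest becomes part of the `WHERE` clause, so the condition is true for every row. With `$1`, the same input is sent as a value and simply matches no username.

```rust
/// Reproduces the string concatenation used by `unsafe_query_example`.
pub fn build_unsafe_query(user_input: &str) -> String {
    // The input lands between two quotes with no escaping at all.
    format!("SELECT * FROM users WHERE username = '{}'", user_input)
}
```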
Using SQLx Query Builder
```rust
use sqlx::{PgPool, Postgres, QueryBuilder};

// `User` is the `FromRow` struct from the previous example.
pub async fn query_with_builder(
    pool: &PgPool,
    filters: UserFilters,
) -> Result<Vec<User>, sqlx::Error> {
    let mut builder: QueryBuilder<Postgres> =
        QueryBuilder::new("SELECT id, username, email, created_at FROM users WHERE 1=1");

    // `push` appends trusted SQL text; `push_bind` appends the next placeholder
    // ($1, $2, ...) and binds the value, so user input never becomes SQL.
    if let Some(username) = filters.username {
        builder.push(" AND username = ").push_bind(username);
    }
    if let Some(email) = filters.email {
        builder.push(" AND email = ").push_bind(email);
    }
    if let Some(limit) = filters.limit {
        builder.push(" LIMIT ").push_bind(limit);
    }

    // Execute safely
    builder.build_query_as::<User>().fetch_all(pool).await
}

#[derive(Default)]
pub struct UserFilters {
    pub username: Option<String>,
    pub email: Option<String>,
    pub limit: Option<i64>,
}
```
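Placeholders only work for values. Identifiers such as column names and sort direction cannot be bound, so when users choose them (for example `?sort=created_at`), map the input through a fixed allowlist and only ever put the allowlisted string into the SQL. The function below is a hypothetical std-only helper; its output would be appended with `QueryBuilder::push`, never with `push_bind`.

```rust
/// Map user-supplied sort options onto a fixed set of SQL fragments.
/// Returns None for anything not on the allowlist.
pub fn order_by_clause(column: &str, direction: &str) -> Option<String> {
    let column = match column {
        "username" => "username",
        "email" => "email",
        "created_at" => "created_at",
        _ => return None,
    };
    let direction = match direction.to_ascii_lowercase().as_str() {
        "asc" => "ASC",
        "desc" => "DESC",
        _ => return None,
    };
    // Only static strings from the match arms reach the SQL text.
    Some(format!(" ORDER BY {column} {direction}"))
}
```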
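CSRF Protection
Cross-site request forgery (CSRF) tricks a logged-in user's browser into sending a state-changing request that automatically carries the user's cookies. It matters when the API authenticates with cookies; the usual defense is a random per-session token that the legitimate page includes in each form and the server checks. Requests authenticated only by a bearer token in the Authorization header are not credentialed automatically by the browser, so they are not exposed in the same way.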
```rust
use axum::{
    extract::State,
    http::StatusCode,
    response::{IntoResponse, Response},
    Form,
};
use serde::Deserialize;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use subtle::ConstantTimeEq;
use tokio::sync::RwLock;

const CSRF_TOKEN_TTL: Duration = Duration::from_secs(3600); // one hour

#[derive(Clone)]
struct CsrfState {
    tokens: Arc<RwLock<HashMap<String, CsrfToken>>>,
}

#[derive(Clone)]
struct CsrfToken {
    token: String,
    created_at: Instant,
}

impl CsrfState {
    pub fn new() -> Self {
        Self {
            tokens: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Issue a new random token for the session; embed it in a hidden form field.
    pub async fn generate_token(&self, session_id: &str) -> String {
        // A v4 UUID carries 122 random bits.
        let token = uuid::Uuid::new_v4().to_string();
        let mut tokens = self.tokens.write().await;
        // Drop expired tokens so the map does not grow without bound.
        tokens.retain(|_, t| t.created_at.elapsed() <= CSRF_TOKEN_TTL);
        tokens.insert(
            session_id.to_string(),
            CsrfToken { token: token.clone(), created_at: Instant::now() },
        );
        token
    }

    pub async fn validate_token(&self, session_id: &str, token: &str) -> bool {
        let tokens = self.tokens.read().await;
        match tokens.get(session_id) {
            // Constant-time comparison: timing does not reveal how many bytes matched.
            Some(stored) if stored.created_at.elapsed() <= CSRF_TOKEN_TTL => {
                stored.token.as_bytes().ct_eq(token.as_bytes()).into()
            }
            _ => false, // unknown session or expired token
        }
    }
}

async fn protected_handler(
    State(csrf_state): State<CsrfState>,
    Form(form): Form<CsrfForm>,
) -> Response {
    let session_id = "session-123"; // placeholder: read the real id from the session
    if !csrf_state.validate_token(session_id, &form.csrf_token).await {
        return (StatusCode::FORBIDDEN, "CSRF token invalid").into_response();
    }
    // Process form.data here
    "Success".into_response()
}

#[derive(Deserialize)]
struct CsrfForm {
    csrf_token: String,
    data: String,
}
```
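A token check can be paired with a cheap header check. Browsers attach an `Origin` header to cross-origin `POST` requests and page scripts cannot set it, so rejecting state-changing requests whose `Origin` is present but not on the allowlist stops most cross-site submissions before the token is even looked at. The helper below is a hypothetical std-only sketch; it lets requests without an `Origin` header fall through to the token check rather than rejecting them, because some clients omit the header.

```rust
/// Methods defined as "safe" (read-only) by HTTP semantics.
const SAFE_METHODS: [&str; 4] = ["GET", "HEAD", "OPTIONS", "TRACE"];

/// Returns false when a state-changing request comes from a foreign origin.
pub fn origin_check_passes(method: &str, origin: Option<&str>, allowed_origins: &[&str]) -> bool {
    if SAFE_METHODS.contains(&method) {
        return true;
    }
    match origin {
        // Includes the literal "null" origin, which is never on the allowlist.
        Some(o) => allowed_origins.contains(&o),
        // No Origin header: leave the decision to the CSRF token check.
        None => true,
    }
}
```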
Complete Security Middleware Stack
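```rust
use axum::{http::Method, middleware, routing::get, Router};
use std::sync::Arc;

// Placeholder handlers so the example compiles.
async fn get_data() -> &'static str {
    "data"
}

async fn create_data() -> &'static str {
    "created"
}

fn create_secure_app() -> Router {
    let cors_state = Arc::new(CorsState {
        allowed_origins: vec!["https://example.com".to_string()],
        allowed_methods: vec![Method::GET, Method::POST, Method::PUT, Method::DELETE],
        allowed_headers: vec!["Content-Type".to_string(), "Authorization".to_string()],
        max_age: 86400,
    });
    // Burst of 100 requests, 10 requests per second sustained.
    let rate_limiter = Arc::new(TokenBucketRateLimiter::new(100, 10.0));

    // `Router::layer` wraps everything added before it, so layers run bottom to top:
    // security headers -> CORS -> rate limiting -> handler.
    // CORS sits outside the rate limiter so that 429 responses still carry CORS
    // headers; without them the browser hides the response from the page.
    Router::new()
        .route("/api/data", get(get_data).post(create_data))
        .layer(middleware::from_fn_with_state(rate_limiter, rate_limit_middleware))
        .layer(middleware::from_fn_with_state(cors_state, cors_handler))
        .layer(middleware::from_fn(security_headers_middleware))
}

// Serve with connection info so the rate limiter can read the peer address:
// axum::serve(listener, create_secure_app()
//     .into_make_service_with_connect_info::<std::net::SocketAddr>()).await
```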
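The order of `.layer()` calls matters because each call wraps everything registered before it, so in axum the last layer added is the outermost and sees the request first. The std-only model below uses plain boxed closures, not axum or tower types, to wrap a handler the same way and record the order in which each layer runs; `layer` and `request_path` are hypothetical names.

```rust
type Handler = Box<dyn Fn(&mut Vec<String>)>;

/// Wrap `inner` so `name` is recorded before the inner handler runs,
/// mimicking how `Router::layer` puts a new layer outside the existing ones.
fn layer(name: &'static str, inner: Handler) -> Handler {
    Box::new(move |log: &mut Vec<String>| {
        log.push(name.to_string());
        inner(log);
    })
}

/// Register layers in the given order, send one "request", return the visit order.
pub fn request_path(layers_in_call_order: &[&'static str]) -> Vec<String> {
    let mut app: Handler = Box::new(|log: &mut Vec<String>| log.push("handler".to_string()));
    for name in layers_in_call_order {
        app = layer(*name, app);
    }
    let mut log = Vec::new();
    app(&mut log);
    log
}
```

Registering rate limiting, then CORS, then security headers therefore runs them in the reverse order on the way in.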
Best Practices Summary
| Security Measure | Implementation | Priority |
|---|---|---|
| OAuth 2.0 | Third-party auth, delegated access | High |
| Rate Limiting | Token bucket, fixed window | High |
| CORS | Origin allowlist, credentials | High |
| Security Headers | CSP, HSTS, X-Frame-Options | High |
| SQL Injection | Parameterized queries | Critical |
| CSRF | Token-based validation | Medium |
| Input Validation | Type-safe validation | High |
| HTTPS | TLS termination | Critical |
Conclusion
API security requires layered defense. This article covered:
- OAuth 2.0 - Delegated authorization for third-party apps
- Rate Limiting - Token bucket and fixed window algorithms
- CORS - Proper cross-origin request handling
- Security Headers - CSP, HSTS, X-Frame-Options
- SQL Injection Prevention - Always use parameterized queries
- CSRF Protection - Token-based validation
Implement these patterns together for comprehensive API security.
Related Articles
- JWT Authentication in Rust Web Services
- Authentication and Authorization in Rust
- Building REST APIs with Axum and Actix-web
- Error Handling Patterns in Rust Web Services