Introduction
Building an AI API is different from building a traditional API: you deal with variable latency, high per-request costs, upstream rate limits, and the challenge of delivering consistent output quality. This guide covers the core patterns you need to build production-ready AI APIs.
API Design Principles
RESTful Design
# FastAPI example for an AI completion endpoint
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

app = FastAPI()

class CompletionRequest(BaseModel):
    model: str
    prompt: str
    max_tokens: int = 1000
    temperature: float = 0.7

class CompletionResponse(BaseModel):
    id: str
    text: str
    usage: dict
    model: str

@app.post("/v1/completions", response_model=CompletionResponse)
async def create_completion(request: CompletionRequest):
    # Validate the model name (AVAILABLE_MODELS is assumed defined elsewhere)
    if request.model not in AVAILABLE_MODELS:
        raise HTTPException(400, "Model not available")
    # Delegate to your inference backend
    result = await generate_completion(request)
    return CompletionResponse(**result)
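For reference, calling this endpoint from Python might look like the following sketch; the host, port, and model name are placeholders for illustration.

import httpx

response = httpx.post(
    "http://localhost:8000/v1/completions",  # placeholder host/port
    json={"model": "my-model", "prompt": "Hello", "max_tokens": 64},
)
print(response.json()["text"])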
WebSocket for Streaming
# Streaming responses over WebSocket
from fastapi import WebSocket, WebSocketDisconnect

@app.websocket("/v1/chat/stream")
async def chat_stream(websocket: WebSocket):
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_json()
            # generate_stream (defined elsewhere) yields chunks as the model produces them
            async for chunk in generate_stream(data):
                await websocket.send_json(chunk)
    except WebSocketDisconnect:
        pass  # client disconnected; nothing to clean up
    except Exception:
        await websocket.close(code=1011)  # internal error
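WebSockets are not the only option: Server-Sent Events over plain HTTP are often simpler for one-way token streams. A minimal sketch using FastAPI's StreamingResponse, reusing the same assumed generate_stream generator:

from fastapi.responses import StreamingResponse
import json

@app.post("/v1/chat/sse")
async def chat_sse(request: CompletionRequest):
    async def event_source():
        # Each chunk goes out as one SSE "data:" frame
        async for chunk in generate_stream(request.dict()):
            yield f"data: {json.dumps(chunk)}\n\n"
    return StreamingResponse(event_source(), media_type="text/event-stream")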
Rate Limiting
Token Bucket Implementation
from fastapi import Request, HTTPException
from datetime import datetime

class RateLimiter:
    """Token bucket: allows `rate` requests per `per` seconds, with bursts up to `rate`."""
    def __init__(self, rate: int, per: int):
        self.rate = rate
        self.per = per
        self.buckets = {}  # key -> {"tokens": float, "last_update": datetime}

    async def check(self, key: str) -> bool:
        now = datetime.now()
        if key not in self.buckets:
            self.buckets[key] = {
                "tokens": self.rate,
                "last_update": now
            }
        bucket = self.buckets[key]
        # Refill tokens in proportion to elapsed time, capped at the bucket size
        elapsed = (now - bucket["last_update"]).total_seconds()
        refill = elapsed * (self.rate / self.per)
        bucket["tokens"] = min(self.rate, bucket["tokens"] + refill)
        bucket["last_update"] = now
        if bucket["tokens"] >= 1:
            bucket["tokens"] -= 1
            return True
        return False

# Usage
rate_limiter = RateLimiter(rate=100, per=60)

@app.post("/v1/completions")
async def create_completion(request: Request):
    client_id = request.client.host
    if not await rate_limiter.check(client_id):
        raise HTTPException(429, "Rate limit exceeded")
    return await generate_completion(request)
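The in-memory bucket above only works within a single process. With multiple replicas you would typically share state in Redis; here is a sketch of a fixed-window variant using redis-py's asyncio client (the connection details are placeholders):

import redis.asyncio as redis

r = redis.Redis(host="localhost", port=6379)  # placeholder connection

async def check_rate_limit(key: str, rate: int = 100, per: int = 60) -> bool:
    # Fixed window: count requests per key, reset every `per` seconds
    window_key = f"ratelimit:{key}"
    count = await r.incr(window_key)
    if count == 1:
        await r.expire(window_key, per)  # start the window on first request
    return count <= rate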
Caching Strategies
Semantic Caching
import hashlib
from sklearn.metrics.pairwise import cosine_similarity

class SemanticCache:
    def __init__(self, threshold: float = 0.95):
        self.threshold = threshold
        self.cache = []       # list of {"key", "prompt", "response"}
        self.embeddings = []  # embedding vectors, parallel to self.cache

    def get_cache_key(self, prompt: str) -> str:
        # Simple hash for exact matches (not used for security)
        return hashlib.md5(prompt.encode()).hexdigest()

    async def get(self, prompt: str) -> str | None:
        # Check for an exact match first (cheap)
        key = self.get_cache_key(prompt)
        for cached in self.cache:
            if cached["key"] == key:
                return cached["response"]
        # Fall back to semantic similarity (get_embedding is assumed defined elsewhere)
        embedding = await get_embedding(prompt)
        for i, cached_emb in enumerate(self.embeddings):
            similarity = cosine_similarity([embedding], [cached_emb])[0][0]
            if similarity >= self.threshold:
                return self.cache[i]["response"]
        return None

    async def set(self, prompt: str, response: str):
        embedding = await get_embedding(prompt)
        self.cache.append({
            "key": self.get_cache_key(prompt),
            "prompt": prompt,
            "response": response
        })
        self.embeddings.append(embedding)
        # Evict the oldest entry once the cache grows past 1000 items
        if len(self.cache) > 1000:
            self.cache.pop(0)
            self.embeddings.pop(0)
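Wiring the cache into the completion endpoint is then straightforward; a sketch, where generate_completion is the same assumed backend call as above:

semantic_cache = SemanticCache(threshold=0.95)

@app.post("/v1/completions")
async def create_completion(request: CompletionRequest):
    cached = await semantic_cache.get(request.prompt)
    if cached is not None:
        return {"text": cached, "cached": True}
    result = await generate_completion(request)
    await semantic_cache.set(request.prompt, result["text"])
    return result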
Authentication
API Key Management
from fastapi import Security, HTTPException
from fastapi.security import APIKeyHeader
import secrets
import hashlib

api_key_header = APIKeyHeader(name="X-API-Key")

class AuthManager:
    def __init__(self):
        # Store only hashes of keys: {key_hash: {user_id, permissions, rate_limit}}
        self.keys = {}

    def create_key(self, user_id: str) -> str:
        # Generate a key, return it once, and store only its hash (defaults illustrative)
        api_key = secrets.token_urlsafe(32)
        key_hash = hashlib.sha256(api_key.encode()).hexdigest()
        self.keys[key_hash] = {"user_id": user_id, "permissions": [], "rate_limit": 100}
        return api_key

    async def verify_key(self, api_key: str = Security(api_key_header)) -> dict:
        key_hash = hashlib.sha256(api_key.encode()).hexdigest()
        if key_hash not in self.keys:
            raise HTTPException(401, "Invalid API key")
        return self.keys[key_hash]

auth = AuthManager()

@app.post("/v1/completions")
async def create_completion(
    request: Request,
    auth_data: dict = Security(auth.verify_key)
):
    # Apply the per-user rate limit attached to the key
    user_rate_limit = auth_data["rate_limit"]
    ...  # process the request as in earlier examples
JWT Authentication
from jose import JWTError, jwt
from fastapi.security import OAuth2PasswordBearer
from datetime import datetime, timedelta

SECRET_KEY = "your-secret-key"  # load from the environment in production
ALGORITHM = "HS256"
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

def create_access_token(data: dict, expires_delta: timedelta | None = None):
    to_encode = data.copy()
    if expires_delta:
        expire = datetime.utcnow() + expires_delta
    else:
        expire = datetime.utcnow() + timedelta(hours=24)
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)

async def get_current_user(token: str = Security(oauth2_scheme)):
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        user_id: str = payload.get("sub")
        if user_id is None:
            raise HTTPException(401, "Invalid token")
        return user_id
    except JWTError:
        raise HTTPException(401, "Invalid token")
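Issuing a token after a (hypothetical) credential check then looks like:

# Mint a short-lived token for an authenticated user
token = create_access_token({"sub": "user-123"}, expires_delta=timedelta(hours=1))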
Error Handling
Graceful Error Responses
from fastapi import Request
from fastapi.responses import JSONResponse
from enum import Enum

class ErrorCode(str, Enum):
    RATE_LIMIT = "rate_limit_exceeded"
    INVALID_MODEL = "invalid_model"
    CONTENT_FILTER = "content_filter"
    TIMEOUT = "timeout"
    SERVER_ERROR = "server_error"

class AIException(Exception):
    def __init__(self, code: ErrorCode, message: str):
        self.code = code
        self.message = message

@app.exception_handler(AIException)
async def ai_exception_handler(request: Request, exc: AIException):
    # Map application error codes to HTTP status codes
    status_codes = {
        ErrorCode.RATE_LIMIT: 429,
        ErrorCode.INVALID_MODEL: 400,
        ErrorCode.CONTENT_FILTER: 400,
        ErrorCode.TIMEOUT: 504,
        ErrorCode.SERVER_ERROR: 500
    }
    return JSONResponse(
        status_code=status_codes.get(exc.code, 500),
        content={
            "error": {
                "code": exc.code,
                "message": exc.message
            }
        }
    )
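Any handler can then raise AIException and get a consistent error body, for example:

# Inside any endpoint (AVAILABLE_MODELS assumed defined, as above)
if request.model not in AVAILABLE_MODELS:
    raise AIException(ErrorCode.INVALID_MODEL, f"Unknown model: {request.model}")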
Scaling Strategies
Load Balancing
# Multiple model instances
class ModelPool:
    def __init__(self, instances: list):
        self.instances = instances
        self.current = 0

    async def get_instance(self):
        instance = self.instances[self.current]
        self.current = (self.current + 1) % len(self.instances)
        return instance

# Round-robin selection
model_pool = ModelPool([
    "http://model-1:8000",
    "http://model-2:8000",
    "http://model-3:8000"
])

@app.post("/v1/completions")
async def create_completion(request: CompletionRequest):
    instance = await model_pool.get_instance()
    response = await call_model_instance(instance, request)
    return response
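call_model_instance is left undefined above; one plausible implementation forwards the request with httpx (the path and timeout are assumptions):

import httpx

async def call_model_instance(instance: str, request: CompletionRequest) -> dict:
    # POST the request body to the selected backend replica
    async with httpx.AsyncClient(timeout=60.0) as client:
        resp = await client.post(f"{instance}/v1/completions", json=request.dict())
        resp.raise_for_status()
        return resp.json()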
Queue-Based Processing
import asyncio
from collections import deque

class RequestQueue:
    """Fire-and-forget queue: callers get no result back (see the asyncio.Queue variant below)."""
    def __init__(self, max_size: int = 100):
        self.queue = deque(maxlen=max_size)
        self.processing = False

    async def add(self, request: dict):
        if len(self.queue) >= self.queue.maxlen:
            raise HTTPException(503, "Server busy")
        self.queue.append(request)
        if not self.processing:
            asyncio.create_task(self.process_queue())

    async def process_queue(self):
        self.processing = True
        while self.queue:
            request = self.queue.popleft()
            await process_request(request)  # process_request is assumed defined elsewhere
        self.processing = False
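When callers need the response back, asyncio's built-in Queue with one Future per request is a more idiomatic pattern; a sketch, where process_request is the same assumed worker as above:

import asyncio

request_queue: asyncio.Queue = asyncio.Queue(maxsize=100)

async def worker():
    # Start once at startup with asyncio.create_task(worker())
    while True:
        request, future = await request_queue.get()
        try:
            future.set_result(await process_request(request))
        except Exception as exc:
            future.set_exception(exc)
        finally:
            request_queue.task_done()

async def enqueue(request: dict):
    future = asyncio.get_running_loop().create_future()
    try:
        request_queue.put_nowait((request, future))
    except asyncio.QueueFull:
        raise HTTPException(503, "Server busy")
    return await future  # resolves when the worker finishes this request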
Monitoring
Metrics Collection
import time
from prometheus_client import Counter, Histogram, Gauge

# Define metrics
REQUESTS = Counter(
    "ai_api_requests_total",
    "Total API requests",
    ["model", "status"]
)
LATENCY = Histogram(
    "ai_api_request_duration_seconds",
    "Request latency",
    ["model"]
)
ACTIVE_REQUESTS = Gauge(
    "ai_api_active_requests",
    "Active requests"
)
TOKEN_USAGE = Counter(
    "ai_api_tokens_total",
    "Total tokens used",
    ["model", "type"]
)

# Use in endpoint
@app.post("/v1/completions")
async def create_completion(request: CompletionRequest):
    ACTIVE_REQUESTS.inc()
    start_time = time.time()
    try:
        result = await generate_completion(request)
        REQUESTS.labels(model=request.model, status="success").inc()
        TOKEN_USAGE.labels(
            model=request.model,
            type="prompt"
        ).inc(result.usage.prompt_tokens)
        return result
    finally:
        # Record latency and decrement the gauge even on failure
        LATENCY.labels(model=request.model).observe(
            time.time() - start_time
        )
        ACTIVE_REQUESTS.dec()
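To let Prometheus scrape these metrics, mount the ASGI app that prometheus_client ships:

from prometheus_client import make_asgi_app

# Expose all registered metrics at /metrics
app.mount("/metrics", make_asgi_app())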
Best Practices
1. Request Validation
# Always validate inputs
from pydantic import BaseModel, Field, validator

class CompletionRequest(BaseModel):
    model: str
    prompt: str
    max_tokens: int = Field(default=1000, ge=1, le=4096)
    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
    top_p: float = Field(default=1.0, ge=0.0, le=1.0)

    @validator("prompt")
    def validate_prompt(cls, v):
        if not v or len(v.strip()) == 0:
            raise ValueError("Prompt cannot be empty")
        if len(v) > 100000:
            raise ValueError("Prompt too long")
        return v
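FastAPI turns these constraint failures into 422 responses automatically before any model call; constructing the model directly shows the same errors:

# Raises a ValidationError: empty prompt and max_tokens out of range
CompletionRequest(model="my-model", prompt="   ", max_tokens=0)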
2. Timeout Handling
import asyncio

async def generate_with_timeout(prompt: str, timeout: float = 30.0):
    try:
        return await asyncio.wait_for(
            generate_completion(prompt),
            timeout=timeout
        )
    except asyncio.TimeoutError:
        # Surface as the structured timeout error defined above
        raise AIException(ErrorCode.TIMEOUT, "Request timed out")
3. Logging
import structlog

logger = structlog.get_logger()

@app.post("/v1/completions")
async def create_completion(request: CompletionRequest):
    # Log structured context per request (get_current_user_id assumed defined elsewhere)
    logger.info(
        "completion_request",
        model=request.model,
        prompt_length=len(request.prompt),
        user_id=get_current_user_id()
    )
    return await generate_completion(request)
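structlog emits key-value events; to render them as JSON for a log aggregator, one common configuration is:

import structlog

structlog.configure(
    processors=[
        structlog.processors.add_log_level,
        structlog.processors.TimeStamper(fmt="iso"),
        structlog.processors.JSONRenderer(),  # must be the final processor
    ]
)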
Conclusion
Building production AI APIs requires careful attention to rate limiting, caching, authentication, and monitoring. Use these patterns to build reliable services.
Key takeaways:
- Implement rate limiting - Protect your service
- Add semantic caching - Reduce costs
- Use proper auth - Secure endpoints
- Monitor everything - Know your system