As your application grows, a single database instance becomes a bottleneck. Database sharding horizontally partitions data across multiple database instances, enabling massive scale. This guide covers strategies, trade-offs, and implementation patterns.
Sharding Is a Last Resort
Database sharding distributes data across multiple database instances using horizontal partitioning. It is the most complex scaling strategy and should be your last resort. Before sharding, exhaust all other options: read replicas, connection pooling optimization, query optimization, caching (Redis/CDN), and vertical scaling (bigger instance).
The sharding key is the single most important decision — it determines query performance, data distribution, and rebalancing difficulty. A poor sharding key creates hot spots where one shard handles 80% of traffic while others sit idle. Common sharding keys include user ID (most natural for SaaS), tenant ID (multi-tenant), and geographic region (latency optimization).
Resharding is the hardest problem: when you need to add shards, existing data must be rebalanced — a massive migration. Common rebalancing approaches are offline migration (which requires downtime), online migration using dual writes, and consistent hashing (which minimizes the amount of data that has to move). Alternatives to manual sharding include Citus (PostgreSQL sharding), Vitess (MySQL sharding), CockroachDB (auto-sharding), and Spanner (auto-sharding). Sharding works at scale for companies like Instagram and Uber, but it adds enormous operational complexity.
Understanding Database Sharding
The Scaling Problem
┌─────────────────────────────────────────────────────────────────┐
│ Single Database Scaling Limits │
│ │
│ Requests │
│ │ │
│ ▼ │
│ ┌─────────┐ │
│ │ App │ │
│ └────┬────┘ │
│ │ │
│ ▼ │
│ ┌─────────┐ │
│ │ Database│ ◄── Single point of failure │
│ └────┬────┘ - Vertical scaling has limits │
│ │ - Connection pool exhaustion │
│ ▼ - I/O bandwidth saturation │
│ ┌─────────┐ │
│ │ SSD │ │
│ └─────────┘ │
│ │
│ Typical limits: │
│ - 10,000-50,000 connections max │
│ - ~100TB on single instance │
│ - Vertical scaling: $100K+ for top specs │
└─────────────────────────────────────────────────────────────────┘
Sharding Solution
┌─────────────────────────────────────────────────────────────────┐
│ Horizontal Sharding Architecture │
│ │
│ Requests │
│ │ │
│ ▼ │
│ ┌─────────┐ │
│ │ App │ │
│ └────┬────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────┐ │
│ │ Shard Router │ │
│ │ (or Application Logic) │ │
│ └─────────────────────────────────────────────┘ │
│ │ │ │ │ │
│ ▼ ▼ ▼ ▼ │
│ ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐ │
│ │ Shard 1 │ │ Shard 2 │ │ Shard 3 │ │ Shard N │ │
│ │ (Users │ │ (Users │ │ (Orders)│ │ (Logs) │ │
│ │ A-M) │ │ N-Z) │ │ │ │ │ │
│ └─────────┘ └─────────┘ └─────────┘ └─────────┘ │
│ │
│ Benefits: │
│ ✓ Linear scaling with more shards │
│ ✓ Reduced load per database │
│ ✓ Geographic distribution possible │
│ ✓ Isolation for different data types │
└─────────────────────────────────────────────────────────────────┘
Shard Key Selection
Types of Shard Keys
# Catalogue of the four sharding strategies discussed in this guide. Every
# entry carries a short description, a one-line example, and its trade-offs.
SHARD_KEY_TYPES = {
    # Contiguous key ranges per shard.
    "range_based": {
        "description": "Partition by ranges of the key",
        "example": "users_00: IDs 1-1M, users_01: IDs 1M-2M",
        "pros": ["Simple to implement", "Easy range queries"],
        "cons": ["Hot spots if access is not uniform"],
    },
    # Shard chosen by hashing the key.
    "hash_based": {
        "description": "Hash function determines shard",
        "example": "shard = hash(user_id) % num_shards",
        "pros": ["Even distribution", "Reduced hot spots"],
        "cons": ["Range queries scatter across shards"],
    },
    # Shard chosen by consulting a lookup service.
    "directory_based": {
        "description": "Lookup service maps keys to shards",
        "example": "shard = lookup_service.get_shard(key)",
        "pros": ["Flexible", "Can rebalance dynamically"],
        "cons": ["Lookup service is single point"],
    },
    # Shard chosen by region.
    "geographic": {
        "description": "Partition by geographic region",
        "example": "us-east, eu-west, ap-south",
        "pros": ["Low latency for regional users"],
        "cons": ["Cross-region queries expensive"],
    },
}
Shard Key Analysis
def analyze_shard_key_candidates(schema: dict, query_log: list) -> dict:
    """Score every column of every table as a potential shard key.

    Each column is run through ``analyze_distribution`` and
    ``analyze_query_patterns``; the two analyses are then combined into one
    report entry, including the verdict of ``calculate_recommendation``.

    Args:
        schema: Mapping whose values are table objects exposing ``name`` and
            ``columns``; each column exposes ``name``.
        query_log: Logged queries, handed unchanged to the analysers.

    Returns:
        Dict keyed by ``"<table name>.<column name>"`` whose values hold
        ``distribution_score``, ``query_coverage``, ``hotspot_risk`` and
        ``recommendation``.
    """
    report = {}
    for tbl in schema.values():
        for col in tbl.columns:
            # NOTE(review): the column *object* is passed to the analysers
            # while the report key uses ``col.name`` -- confirm the analysers
            # expect the object rather than its name.
            dist = analyze_distribution(query_log, col, tbl.name)
            patterns = analyze_query_patterns(query_log, col, tbl.name)
            entry = {
                "distribution_score": dist["score"],
                "query_coverage": patterns["coverage"],
                "hotspot_risk": dist["hotspot_risk"],
                "recommendation": calculate_recommendation(dist, patterns),
            }
            report[f"{tbl.name}.{col.name}"] = entry
    return report
def analyze_distribution(query_log, column, table, num_shards=None) -> dict:
    """Measure how evenly logged accesses on ``column`` would spread over shards.

    Every query against ``table`` that filters on ``column`` is assigned to
    ``hash(value) % num_shards``. All shards take part in the statistics,
    including those that receive no traffic, so an idle shard shows up as
    skew instead of being left out.

    Args:
        query_log: Iterable of query objects exposing ``table``,
            ``where_cols`` and ``where_values``.
        column: Column identifier, looked up in ``query.where_cols`` and used
            as the key into ``query.where_values``.
        table: Table name compared against ``query.table``.
        num_shards: Number of shards to simulate. Defaults to the module-level
            ``NUM_SHARDS`` when omitted.

    Returns:
        Dict with:
          * ``score`` -- ``1 - (max_share - min_share)``; 1.0 is perfectly even.
          * ``hotspot_risk`` -- ``max_share / min_share``; ``inf`` when at
            least one shard receives nothing.
          * ``distribution`` -- per-shard share of accesses, indexed by shard.
        When no logged query matches, ``score`` is 0.0 and ``hotspot_risk`` is
        ``inf``: there is no evidence the column distributes load at all.

    Raises:
        ValueError: If ``num_shards`` is not positive.
    """
    if num_shards is None:
        num_shards = NUM_SHARDS
    if num_shards <= 0:
        raise ValueError(f"num_shards must be positive, got {num_shards}")

    access_counts = [0] * num_shards
    for query in query_log:
        if query.table == table and column in query.where_cols:
            # NOTE: builtin hash() is salted per process for str/bytes, so the
            # shard a string key lands on varies between runs; the resulting
            # shares are still representative.
            shard = hash(query.where_values[column]) % num_shards
            access_counts[shard] += 1

    total = sum(access_counts)
    if total == 0:
        # Nothing to divide by: report the column as an unproven candidate
        # rather than raising ZeroDivisionError / ValueError.
        return {
            "score": 0.0,
            "hotspot_risk": float("inf"),
            "distribution": [0.0] * num_shards,
        }

    distribution = [count / total for count in access_counts]
    busiest = max(distribution)
    quietest = min(distribution)
    return {
        "score": 1 - (busiest - quietest),
        "hotspot_risk": busiest / quietest if quietest > 0 else float("inf"),
        "distribution": distribution,
    }
Choosing the Right Shard Key
# Shard keys that usually work well. Each entry says when to pick the key,
# shows a representative query, and names how rows are assigned to shards.
GOOD_SHARD_KEYS = {
    "user_id": {
        "when": "Most queries are user-scoped",
        "example": "SELECT * FROM orders WHERE user_id = ?",
        "distribution": "hash(user_id) % N",
    },
    "order_date": {
        "when": "Time-range queries are common",
        "example": "SELECT * FROM logs WHERE date > '2026-01-01'",
        "distribution": "Range-based partitioning",
    },
    "tenant_id": {
        "when": "Multi-tenant SaaS application",
        "example": "SELECT * FROM documents WHERE tenant_id = ?",
        "distribution": "tenant_id % N",
    },
}
# Shard-key anti-patterns. Each value is a free-form, multi-line note that
# states the problem and a suggested alternative. The notes are runtime
# strings: their line breaks and leading whitespace are part of the value.
BAD_SHARD_KEYS = {
# Ever-increasing IDs: the newest rows, and therefore the writes, pile up on
# a single shard.
"auto_increment_id": """
❌ Problem: Creates hot shard
- Newest data gets all writes
- Single shard becomes bottleneck
✅ Solution: Use composite key with random component
""",
# Low-cardinality status columns: most rows share one value.
"status_field": """
❌ Problem: Highly skewed distribution
- Most records have status='active'
- 90% of data on one shard
✅ Solution: Use user_id or date instead
""",
# Keys that queries do not carry: the router cannot pick a shard.
"foreign_key_only": """
❌ Problem: Can't route queries without the key
- Need to join across all shards
- Full scatter-gather operations
✅ Solution: Denormalize or use alternate key
"""
}
Implementation Strategies
Application-Level Sharding
class ShardRouter:
    """Map a shard key to one of the configured shards.

    The configuration dict must contain:
      * ``"shards"`` -- indexable sequence of shard identifiers.
      * ``"shard_function"`` -- ``"hash"`` or ``"range"``.
      * ``"ranges"`` (range mode only) -- sequence of dicts with ``"min"``
        (inclusive) and ``"max"`` (exclusive), positionally aligned with
        ``"shards"``.
    """

    def __init__(self, shard_config: dict):
        # Keep the whole config: range routing needs shard_config["ranges"].
        self.shard_config = shard_config
        self.shards = shard_config["shards"]
        self.shard_function = shard_config["shard_function"]

    def get_shard(self, shard_key) -> str:
        """Return the shard responsible for ``shard_key``.

        Raises:
            ValueError: If ``shard_function`` is neither ``"hash"`` nor
                ``"range"``.
        """
        if self.shard_function == "hash":
            shard_index = self._stable_hash(shard_key) % len(self.shards)
        elif self.shard_function == "range":
            shard_index = self._get_range_shard(shard_key)
        else:
            raise ValueError(f"Unknown function: {self.shard_function}")
        return self.shards[shard_index]

    @staticmethod
    def _stable_hash(shard_key) -> int:
        """Hash ``shard_key`` identically in every process.

        Builtin ``hash()`` is salted per interpreter for ``str`` and
        ``bytes`` (see ``PYTHONHASHSEED``), so two application servers would
        otherwise route the same string key to different shards. Integer keys
        keep using ``hash()``, which is deterministic for ints.
        """
        if isinstance(shard_key, int):
            return hash(shard_key)
        import hashlib

        if isinstance(shard_key, (bytes, bytearray)):
            raw = bytes(shard_key)
        else:
            raw = str(shard_key).encode("utf-8")
        return int.from_bytes(hashlib.sha256(raw).digest()[:8], "big")

    def _get_range_shard(self, key) -> int:
        """Return the index of the first range containing ``key``.

        Keys that match no range fall through to the last shard.
        """
        for i, range_def in enumerate(self.shard_config["ranges"]):
            if range_def["min"] <= key < range_def["max"]:
                return i
        return len(self.shards) - 1
class ShardedConnection:
    """Hand out one cached connection per shard, created on first use.

    NOTE(review): ``_create_connection`` is called below but is not defined
    in this class; it has to be supplied elsewhere (e.g. by a subclass).
    """

    def __init__(self, router: ShardRouter):
        self.router = router
        # shard identifier -> connection opened for that shard
        self.connections: dict[str, Connection] = {}

    def get_connection(self, shard_key) -> Connection:
        """Return the connection for the shard that owns ``shard_key``."""
        target = self.router.get_shard(shard_key)
        if target in self.connections:
            return self.connections[target]
        opened = self._create_connection(target)
        self.connections[target] = opened
        return self.connections[target]

    def execute_on_shard(self, shard_key, query, params):
        """Run ``query`` with ``params`` on the shard owning ``shard_key``."""
        return self.get_connection(shard_key).execute(query, params)
Sharding at the Database Level
# Vitess-style horizontal sharding (MySQL).
# Four key ranges split at -4e18, 0 and +4e18. An entry with a single bound
# is open-ended on the other side.
SHARD_CONFIG = {
    "shard_ranges": [
        "-4000000000000000000",    # Shard 0: below -4e18
        "-4000000000000000000-0",  # Shard 1: -4e18 up to 0
        "0-4000000000000000000",   # Shard 2: 0 up to 4e18
        "4000000000000000000-",    # Shard 3: above 4e18
    ],
    # One hash vindex on user_id spanning all four shards.
    "vindex": {
        "user_vindex": {
            "type": "hash",
            "column": "user_id",
            "shard_count": 4,
        },
    },
}
# Routing a query through VTGate using a vindex
def route_query(vtgate_conn, query):
    """Send the sample ``orders`` lookup through ``vtgate_conn``.

    The statement and its bind variables are fixed inside this function; the
    ``query`` argument is accepted but not used.

    Args:
        vtgate_conn: Object exposing ``execute(bound_query)``.
        query: Unused.

    Returns:
        Whatever ``vtgate_conn.execute`` returns.
    """
    statement = "SELECT * FROM orders WHERE user_id = :user_id"
    bind_vars = {"user_id": 12345}
    # Filtering on the vindexed column lets the gateway pick a single shard.
    return vtgate_conn.execute({"sql": statement, "bind_vars": bind_vars})
Range-Based Sharding
class RangeBasedPartitioner:
    """Assign keys to partitions by half-open ``[min, max)`` ranges."""

    def __init__(self, ranges: list[tuple]):
        # Store the ranges ordered by lower bound; partition numbers are
        # positions in this ordered list.
        ordered = list(ranges)
        ordered.sort(key=lambda bounds: bounds[0])
        self.ranges = ordered

    def get_partition(self, key) -> int:
        """Return the index of the first range containing ``key``.

        A key that matches no range is assigned to the last partition
        (``len(ranges) - 1``).
        """
        hits = (
            index
            for index, (lower, upper) in enumerate(self.ranges)
            if lower <= key < upper
        )
        return next(hits, len(self.ranges) - 1)
# Example: date-based partitioning for logs, one partition per quarter of
# 2025. Bounds are ISO date strings, which compare chronologically.
LOG_PARTITIONER = RangeBasedPartitioner(
    [
        ("2025-01-01", "2025-04-01"),  # Q1
        ("2025-04-01", "2025-07-01"),  # Q2
        ("2025-07-01", "2025-10-01"),  # Q3
        ("2025-10-01", "2026-01-01"),  # Q4
    ]
)
def get_log_shard(timestamp) -> str:
    """Return the log shard name, ``logs_shard_<n>``, for ``timestamp``.

    ``<n>`` is the partition index that ``LOG_PARTITIONER`` assigns to
    ``timestamp``.
    """
    return f"logs_shard_{LOG_PARTITIONER.get_partition(timestamp)}"
Consistent Hashing
import hashlib
class ConsistentHashRing:
    """Consistent-hash ring with virtual nodes.

    Each physical node is placed on the ring ``virtual_nodes`` times, at the
    MD5 hash of ``"<node>:<i>"``. A key is owned by the first virtual node at
    or after the key's own hash, wrapping around to the start of the ring.

    Attributes:
        ring: Mapping from ring position to the node that owns it.
        sorted_keys: Ring positions in ascending order, without duplicates.
        virtual_nodes: Number of positions created per physical node.
    """

    def __init__(self, nodes: list[str], virtual_nodes: int = 150):
        self.ring = {}
        self.sorted_keys = []
        self.virtual_nodes = virtual_nodes
        for node in nodes:
            self._add_node(node)

    def _add_node(self, node: str):
        """Place ``node`` on the ring; adding it twice is harmless."""
        for i in range(self.virtual_nodes):
            key = self._hash(f"{node}:{i}")
            # Record a position only once. Duplicates in sorted_keys would
            # outlive a later remove_node() and point at deleted entries.
            if key not in self.ring:
                self.sorted_keys.append(key)
            self.ring[key] = node
        self.sorted_keys.sort()

    def _hash(self, key: str) -> int:
        """Return the ring position of ``key`` (MD5, as a 128-bit int)."""
        return int(hashlib.md5(key.encode()).hexdigest(), 16)

    def get_node(self, key: str) -> str:
        """Return the node that owns ``key``.

        Raises:
            IndexError: If the ring has no nodes.
        """
        import bisect

        if not self.sorted_keys:
            raise IndexError("consistent hash ring has no nodes")
        # Binary search for the first position >= hash(key).
        position = bisect.bisect_left(self.sorted_keys, self._hash(key))
        if position == len(self.sorted_keys):
            position = 0  # past the last position: wrap around
        return self.ring[self.sorted_keys[position]]

    def add_node(self, node: str):
        """Add ``node`` to the ring."""
        self._add_node(node)

    def remove_node(self, node: str):
        """Remove every ring position owned by ``node``.

        Raises:
            KeyError: If ``node`` owns no position on the ring.
        """
        owned = set()
        for i in range(self.virtual_nodes):
            key = self._hash(f"{node}:{i}")
            if self.ring.get(key) == node:
                owned.add(key)
        if not owned:
            raise KeyError(node)
        for key in owned:
            del self.ring[key]
        # One pass over the ring instead of a list.remove() per virtual node.
        self.sorted_keys = [k for k in self.sorted_keys if k not in owned]
Cross-Shard Queries
Scatter-Gather Pattern
class ScatterGatherQuery:
    """Fan a query out to every shard and combine the answers.

    Args:
        shard_connections: Mapping of shard name to a connection object
            exposing an awaitable ``execute(query, params)``.
    """

    def __init__(self, shard_connections: dict):
        self.connections = shard_connections

    async def query_all_shards(self, query: str, params: dict) -> list:
        """Execute ``query`` on all shards concurrently.

        Returns:
            One ``{"shard": name, "data": result}`` dict per shard that
            answered, in the iteration order of the connection mapping.
            Shards that raised are left out of the result and reported
            through the ``logging`` module, so a partial answer is not
            mistaken for a complete one.
        """
        import logging

        async def fetch_from_shard(shard_name, conn):
            try:
                result = await conn.execute(query, params)
                return {"shard": shard_name, "data": result}
            except Exception as e:
                # Best effort: one failing shard must not fail the fan-out.
                return {"shard": shard_name, "error": str(e)}

        tasks = [
            fetch_from_shard(shard, conn)
            for shard, conn in self.connections.items()
        ]
        outcomes = await asyncio.gather(*tasks)

        log = logging.getLogger(__name__)
        for outcome in outcomes:
            if "error" in outcome:
                log.warning(
                    "Scatter-gather: shard %s failed and is missing from "
                    "the result: %s",
                    outcome["shard"],
                    outcome["error"],
                )
        return [outcome for outcome in outcomes if "error" not in outcome]

    def aggregate_results(self, shard_results: list, aggregator: callable):
        """Concatenate each shard's ``data`` and pass it to ``aggregator``.

        Entries without a ``"data"`` key contribute nothing.
        """
        all_data = []
        for result in shard_results:
            all_data.extend(result.get("data", []))
        return aggregator(all_data)
Handling Joins Across Shards
class ShardedJoinExecutor:
    """Join two result sets that live on different shards, in memory.

    The two ``_fetch_*`` methods are extension points: a subclass supplies
    the scatter/gather logic for its storage layer.
    """

    def __init__(self, router: "ShardRouter"):
        self.router = router

    async def join_across_shards(
        self,
        left_query,
        right_query,
        join_key: str,
        join_type: str = "inner"
    ):
        """Join the rows of ``left_query`` with those of ``right_query``.

        Args:
            left_query: Passed to ``_fetch_left_table``.
            right_query: Passed to ``_fetch_right_table`` together with the
                set of join-key values found on the left side.
            join_key: Name of the column present in both row dicts.
            join_type: ``"left"`` also keeps left rows without a match; any
                other value behaves as an inner join.

        Returns:
            List of merged row dicts, one per matching (left, right) pair.
            On a column-name clash the right row's value wins.
        """
        # Step 1: Fetch data from first table
        left_rows = await self._fetch_left_table(left_query)
        if not left_rows:
            # Nothing to join against; skip the second scatter query.
            return []

        # Step 2: Collect the distinct join keys
        join_keys = {row[join_key] for row in left_rows}

        # Step 3: Fetch matching rows from second table
        right_rows = await self._fetch_right_table(right_query, join_keys)

        # Step 4: Perform in-memory join. Index *all* right rows per key: a
        # single-row index would silently drop one-to-many matches.
        right_index = {}
        for row in right_rows:
            right_index.setdefault(row[join_key], []).append(row)

        results = []
        for left_row in left_rows:
            matches = right_index.get(left_row[join_key])
            if matches:
                results.extend({**left_row, **match} for match in matches)
            elif join_type == "left":
                results.append({**left_row})
        return results

    async def _fetch_left_table(self, query):
        """Return the left-side rows (list of dicts) from all relevant shards."""
        raise NotImplementedError(
            "_fetch_left_table must be implemented by a subclass"
        )

    async def _fetch_right_table(self, query, keys):
        """Return right-side rows whose join key is in ``keys``."""
        raise NotImplementedError(
            "_fetch_right_table must be implemented by a subclass"
        )
Distributed Transactions
import asyncio
class TwoPhaseCommit:
    """Coordinator skeleton for a two-phase commit across shards.

    Args:
        shard_connections: Mapping of shard identifier to a connection
            exposing an awaitable ``execute``.
        router: Object exposing ``get_shard(shard_key)``; required to
            resolve each operation's shard during the prepare phase.

    NOTE(review): ``_commit_phase`` and ``_rollback_phase`` are called by
    ``execute_transaction`` but are not defined in this class; they have to
    be supplied elsewhere (e.g. by a subclass).
    """

    def __init__(self, shard_connections: dict, router=None):
        self.connections = shard_connections
        self.router = router

    async def execute_transaction(
        self,
        operations: list[dict]
    ) -> bool:
        """Run prepare, then commit or roll back.

        Args:
            operations: Dicts carrying at least ``"shard_key"`` and
                ``"transaction_id"``.

        Returns:
            True if every participant prepared and the commit phase ran,
            False if any participant refused and the rollback phase ran.
        """
        # Phase 1: Prepare
        prepared = await self._prepare_phase(operations)
        if not all(prepared.values()):
            # Rollback on any failure
            await self._rollback_phase(operations)
            return False
        # Phase 2: Commit
        await self._commit_phase(operations)
        return True

    async def _prepare_phase(self, operations: list[dict]) -> dict:
        """Ask every participant to prepare.

        Returns:
            Mapping of ``transaction_id`` to True only if *every* operation
            carrying that id prepared successfully.

        Raises:
            RuntimeError: If no router was supplied.
        """
        if self.router is None:
            raise RuntimeError(
                "TwoPhaseCommit needs a router to resolve shard keys; "
                "pass router=... to the constructor"
            )

        async def prepare_op(op):
            shard = self.router.get_shard(op["shard_key"])
            conn = self.connections[shard]
            try:
                await conn.execute("PREPARE TRANSACTION", op["transaction_id"])
                return True
            except Exception:
                # A participant that cannot prepare votes "no".
                return False

        votes = await asyncio.gather(*[
            prepare_op(op) for op in operations
        ])

        # Operations of one distributed transaction usually share an id.
        # AND the votes so a later success cannot hide an earlier failure.
        outcome = {}
        for op, vote in zip(operations, votes):
            txn_id = op["transaction_id"]
            outcome[txn_id] = outcome.get(txn_id, True) and vote
        return outcome
Rebalancing Shards
Online Re-sharding
class ShardRebalancer:
    """Copy the ``users`` table from one shard to another in id order.

    NOTE(review): ``_save_checkpoint(last_id, total_moved)`` is awaited by
    ``move_data`` but is not defined in this class; it has to be supplied
    elsewhere (e.g. by a subclass).
    """

    def __init__(
        self,
        source_shard: "Connection",
        target_shard: "Connection"
    ):
        self.source = source_shard
        self.target = target_shard

    async def move_data(
        self,
        batch_size: int = 1000,
        checkpoint_interval: int = 10000
    ):
        """Copy all rows in batches, checkpointing progress along the way.

        Args:
            batch_size: Maximum rows read from the source per query.
            checkpoint_interval: A checkpoint is saved once at least this
                many rows have been copied since the previous one.

        Returns:
            Total number of rows copied.
        """
        last_id = 0
        total_moved = 0
        since_checkpoint = 0
        while True:
            # Read batch from source (keyset pagination on id)
            batch = await self.source.execute(
                "\n"
                "SELECT * FROM users\n"
                "WHERE id > %s\n"
                "ORDER BY id\n"
                "LIMIT %s\n",
                (last_id, batch_size),
            )
            if not batch:
                break
            # Write to target
            for row in batch:
                await self.target.insert("users", row)
            last_id = batch[-1]["id"]
            total_moved += len(batch)
            since_checkpoint += len(batch)
            # Checkpoint progress. Counting rows since the last checkpoint
            # works for any batch size; ``total % interval == 0`` only fired
            # when the running total happened to hit an exact multiple.
            if since_checkpoint >= checkpoint_interval:
                await self._save_checkpoint(last_id, total_moved)
                since_checkpoint = 0
        if since_checkpoint:
            # Record the tail that did not reach a full interval.
            await self._save_checkpoint(last_id, total_moved)
        return total_moved

    async def verify_integrity(self):
        """Return True if source and target hold the same number of users.

        Only row counts are compared; row contents are not checked.
        """
        source_count = await self.source.count("users")
        target_count = await self.target.count("users")
        return source_count == target_count
Dual-Write Pattern During Migration
class DualWriteRouter:
    """Route reads and writes while data moves from an old to a new layout.

    ``migration_state`` is one of:
      * ``"old"`` -- only the old layout is used.
      * ``"migrating"`` -- writes go to both layouts; reads are served from
        the old layout and compared against the new one.
      * ``"new"`` -- only the new layout is used.

    NOTE(review): ``write_to_shard(shard, data)`` and
    ``read_from_shard(shard, key)`` are awaited below but are not defined in
    this class; they have to be supplied elsewhere (e.g. by a subclass).
    """

    _STATES = ("old", "migrating", "new")

    def __init__(self, old_router: "ShardRouter", new_router: "ShardRouter"):
        self.old_router = old_router
        self.new_router = new_router
        self.migration_state = "old"  # old, migrating, new

    async def write(self, key, data):
        """Write ``data`` to every layout that is live in the current state.

        Raises:
            ValueError: If ``migration_state`` is not a known state; without
                this check the write would be silently dropped.
        """
        state = self.migration_state
        if state not in self._STATES:
            raise ValueError(f"Unknown migration state: {state}")
        # Write to both during migration
        if state in ("old", "migrating"):
            await self.write_to_shard(self.old_router.get_shard(key), data)
        if state in ("migrating", "new"):
            await self.write_to_shard(self.new_router.get_shard(key), data)

    async def read(self, key):
        """Read ``key`` from the layout that is authoritative right now.

        During migration both copies are read and a mismatch is logged, but
        the old copy is returned: the new layout may not be fully backfilled
        yet, so it cannot be trusted before cutover.
        """
        if self.migration_state == "migrating":
            import logging

            old_data = await self.read_from_shard(
                self.old_router.get_shard(key), key
            )
            new_data = await self.read_from_shard(
                self.new_router.get_shard(key), key
            )
            if old_data != new_data:
                logging.getLogger(__name__).error(
                    "Data mismatch: %s vs %s", old_data, new_data
                )
            return old_data
        if self.migration_state == "new":
            return await self.read_from_shard(
                self.new_router.get_shard(key), key
            )
        return await self.read_from_shard(self.old_router.get_shard(key), key)
Managing Reference Data
Distributed Reference Tables
class ReferenceDataManager:
    """Serve small reference tables through a cache in front of a shard.

    Args:
        cache: Object exposing awaitable ``get(key)`` and
            ``setex(key, ttl, value)``.
        any_shard: Connection to any shard; it is queried with
            ``execute(sql)`` on a cache miss.
    """

    def __init__(self, cache: "Redis", any_shard=None):
        self.cache = cache
        self.any_shard = any_shard

    async def get_country_codes(self) -> dict:
        """Return a ``{code: name}`` mapping of countries.

        The mapping is read from the cache when present; otherwise it is
        loaded from ``any_shard`` and cached as JSON for 3600 seconds.

        Raises:
            RuntimeError: On a cache miss when no ``any_shard`` connection
                was supplied.
        """
        import json

        # Try cache first
        cached = await self.cache.get("country_codes")
        if cached:
            return json.loads(cached)

        if self.any_shard is None:
            raise RuntimeError(
                "ReferenceDataManager needs any_shard=... to load "
                "reference data on a cache miss"
            )

        # Fetch from any shard (replicated)
        data = await self.any_shard.execute(
            "SELECT code, name FROM countries"
        )
        result = {row["code"]: row["name"] for row in data}

        # Cache with TTL
        await self.cache.setex(
            "country_codes",
            3600,
            json.dumps(result)
        )
        return result
# Reference tables that are copied in full to every shard.
REFERENCE_TABLES = [
    "countries",
    "currencies",
    "timezones",
    "subscription_plans",
]
def replicate_reference_tables():
    """Overwrite every shard's copy of each reference table.

    For each name in ``REFERENCE_TABLES`` the table is read once from
    ``master_db``; then, shard by shard, the local copy is truncated and
    refilled with a bulk insert.
    """
    for table_name in REFERENCE_TABLES:
        rows = master_db.fetch_table(table_name)
        for target in all_shards:
            # NOTE(review): truncate and refill are separate calls here, so a
            # shard's table is empty in between -- confirm whether callers
            # wrap this in a transaction.
            target.execute(f"TRUNCATE TABLE {table_name}")
            target.bulk_insert(table_name, rows)
Connection Pooling Per Shard
class ShardedConnectionPool:
    """One connection pool per shard, addressed by shard key.

    Args:
        config: Dict whose ``"shards"`` entry maps shard name to a dict
            with ``"host"`` and ``"port"``.
        router: Object exposing ``get_shard(shard_key)`` that returns a
            shard name present in ``config["shards"]``. Required by
            ``get_connection`` and ``release_connection``.
    """

    def __init__(self, config: dict, router=None):
        self.pools: dict[str, ConnectionPool] = {}
        self.config = config
        self.router = router
        self._initialize_pools()

    def _initialize_pools(self):
        """Create a pool for every configured shard."""
        for shard_name, shard_config in self.config["shards"].items():
            self.pools[shard_name] = ConnectionPool(
                host=shard_config["host"],
                port=shard_config["port"],
                min_connections=5,
                max_connections=50,
                max_idle_time=300
            )

    def _shard_for(self, shard_key: str) -> str:
        """Resolve ``shard_key`` to a shard name through the router."""
        if self.router is None:
            raise RuntimeError(
                "ShardedConnectionPool needs a router to resolve shard keys; "
                "pass router=... to the constructor"
            )
        return self.router.get_shard(shard_key)

    async def get_connection(self, shard_key: str) -> "Connection":
        """Acquire a connection from the pool of the shard owning the key.

        Raises:
            RuntimeError: If no router was supplied.
            ConnectionPoolExhaustedError: If the shard's pool already has
                ``max_connections`` active connections.
        """
        shard = self._shard_for(shard_key)
        pool = self.pools[shard]
        # Check connection limit per shard: fail fast instead of queueing.
        if pool.active_connections >= pool.max_connections:
            raise ConnectionPoolExhaustedError(shard)
        return await pool.acquire()

    async def release_connection(self, shard_key: str, conn: "Connection"):
        """Return ``conn`` to the pool of the shard owning ``shard_key``.

        Raises:
            RuntimeError: If no router was supplied.
        """
        shard = self._shard_for(shard_key)
        await self.pools[shard].release(conn)
Monitoring Sharded Systems
Key Metrics
# Metrics to collect for a sharded deployment, plus the alert thresholds
# applied to them.
SHARD_METRICS = {
    # Collected separately for every shard.
    "per_shard_metrics": [
        "queries_per_second",
        "rows_returned",
        "rows_modified",
        "avg_query_duration",
        "p99_query_duration",
        "active_connections",
        "cpu_usage",
        "disk_usage",
        "memory_usage",
    ],
    # Collected across the cluster as a whole.
    "cross_shard_metrics": [
        "scatter_gather_queries",
        "distributed_transaction_count",
        "failed_transactions",
        "rebalance_progress",
        "replication_lag",
    ],
    "alert_thresholds": {
        "shard_imbalance": 0.3,  # 30% skew triggers alert
        "high_latency_p99": "500ms",
        "connection_pool_usage": 0.9,  # 90% triggers alert
        "disk_usage": 0.85,
    },
}
Health Checks
class ShardHealthChecker:
    """Probe shards with a trivial query and report pool statistics.

    Args:
        pool_manager: Object whose ``pools`` attribute maps shard name to a
            pool exposing ``acquire()`` (usable with ``async with``) and
            ``get_stats()``.
    """

    def __init__(self, pool_manager: "ShardedConnectionPool"):
        self.pools = pool_manager

    async def check_shard_health(self, shard: str) -> dict:
        """Run ``SELECT 1`` on ``shard`` and describe the outcome.

        Returns:
            On success: ``shard``, ``healthy=True``, ``latency_ms`` and the
            pool's ``active`` / ``idle`` / ``waiting`` counters.
            On any exception from the probe: ``shard``, ``healthy=False``
            and ``error`` (the exception text).

        Raises:
            KeyError: If ``shard`` has no pool.
        """
        import time

        pool = self.pools.pools[shard]
        try:
            # Test query. perf_counter() is monotonic, so the measurement is
            # not distorted by wall-clock adjustments.
            start = time.perf_counter()
            async with pool.acquire() as conn:
                await conn.execute("SELECT 1")
            latency = time.perf_counter() - start
            # Get pool stats
            stats = pool.get_stats()
            return {
                "shard": shard,
                "healthy": True,
                "latency_ms": latency * 1000,
                "active_connections": stats["active"],
                "idle_connections": stats["idle"],
                "wait_queue": stats["waiting"]
            }
        except Exception as e:
            # A failed probe is a result, not an error for the caller.
            return {
                "shard": shard,
                "healthy": False,
                "error": str(e)
            }

    async def check_all_shards(self) -> list[dict]:
        """Probe every known shard concurrently; results follow pool order."""
        tasks = [
            self.check_shard_health(shard)
            for shard in self.pools.pools
        ]
        return await asyncio.gather(*tasks)
Related Articles
Summary
Database sharding enables horizontal scaling beyond single database limits:
- Shard Key Selection is critical - choose keys that distribute load evenly and support your most common query patterns
- Range-based sharding works well for time-series data, while hash-based provides even distribution
- Cross-shard queries require scatter-gather patterns; minimize them for performance
- Rebalancing can be done online with dual-write patterns to avoid downtime
- Connection pooling must be managed per-shard to prevent resource exhaustion
- Monitoring both per-shard and cross-shard metrics is essential
Sharding is a significant architectural decision - consider using managed solutions like Vitess, CockroachDB, or Spanner if possible before implementing custom sharding.
Comments