Overview
Numerical stability describes how well an algorithm controls the small rounding errors inherent in floating-point arithmetic. In Python, especially when working with probabilities, overflow and underflow can silently break your computations.
The Problem
When computing probabilities, especially with exponential functions, values can become extremely small (underflow) or extremely large (overflow):
import numpy as np
# Underflow: exponentials of large negative numbers flush to 0
log_values = np.array([-700, -800, -900])
print(np.exp(log_values))
# First value is ~9.9e-305; the other two underflow to 0.0
# Overflow: large numbers become inf
large_values = np.array([1000, 2000, 3000])
print(np.exp(large_values))
# Output: [inf inf inf]
The Solution: Log-Sum-Exp Trick
Instead of working with raw probabilities, work with their logarithms; when you need to add probabilities while staying in log space, use the log-sum-exp trick.
Basic Implementation
def normal_scalar(X, mu, var, d):
    """d-dimensional isotropic Gaussian PDF - numerically stable version
    Args:
        X: Data point (d,)
        mu: Mean (d,)
        var: Shared per-dimension variance
        d: Dimension
    """
    # Work in log space: sum the log terms, exponentiate only once
    log_prob = -np.linalg.norm(X - mu)**2 / (2 * var) - np.log(2 * np.pi * var) * d / 2
    return np.exp(log_prob)
# This is numerically stable because the normalizer and the exponent
# are combined in log space, with a single exp at the very end
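To see the difference, compare the naive product against the log-space sum at a point far from the mean (the numbers here are arbitrary illustrations):

```python
import numpy as np

X, mu, var, d = np.array([60.0, 0.0]), np.zeros(2), 1.0, 2

# Naive: the exp underflows to 0 before the normalizer is even applied
naive = (2 * np.pi * var) ** (-d / 2) * np.exp(-np.linalg.norm(X - mu) ** 2 / (2 * var))

# Log space: the log density stays finite and can still be compared or summed
log_prob = -np.linalg.norm(X - mu) ** 2 / (2 * var) - np.log(2 * np.pi * var) * d / 2

print(naive)     # 0.0 (underflow)
print(log_prob)  # ~-1801.8, perfectly usable
```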
Log-Sum-Exp for Summing Probabilities
When you need to sum probabilities:
def log_sum_exp(x):
    """Numerically stable way to compute log(sum(exp(x)))
    Args:
        x: Array of log probabilities
    Returns:
        Log of the sum of the probabilities
    """
    max_x = np.max(x)  # Subtracting the max prevents overflow in exp
    return max_x + np.log(np.sum(np.exp(x - max_x)))

# Usage
log_probs = np.array([-1000, -1001, -1002])
result = log_sum_exp(log_probs)
print(result)  # ~-999.59 (stable)
print(np.log(np.sum(np.exp(log_probs))))  # -inf (unstable: every exp underflows to 0)
Using scipy
from scipy.special import logsumexp
log_probs = np.array([-1000, -1001, -1002])
# Simple log-sum-exp
result = logsumexp(log_probs)
print(result)  # ~-999.59
# With axis parameter
log_probs_2d = np.array([[-1000, -1001], [-1002, -1003]])
row_sums = logsumexp(log_probs_2d, axis=1)
print(row_sums)  # ~[-999.69 -1001.69]
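A common follow-up is normalizing a set of log probabilities so they sum to 1; subtracting the `logsumexp` of the array before exponentiating keeps every step stable even when the raw values would underflow:

```python
import numpy as np
from scipy.special import logsumexp

log_probs = np.array([-1000.0, -1001.0, -1002.0])

# Subtract the log normalizer in log space, then exponentiate
probs = np.exp(log_probs - logsumexp(log_probs))

print(probs)        # [0.66524096 0.24472847 0.09003057]
print(probs.sum())  # 1.0 (up to float rounding)
```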
Common Stability Patterns
1. Softmax Function
def stable_softmax(x):
    """Numerically stable softmax
    Args:
        x: Input array
    Returns:
        Softmax probabilities
    """
    # Subtract the max to prevent overflow in exp
    exp_x = np.exp(x - np.max(x))
    return exp_x / np.sum(exp_x)
# Example
x = np.array([1000, 1001, 1002])
print(stable_softmax(x)) # [0.09003057 0.24472847 0.66524096]
print(np.exp(x) / np.sum(np.exp(x))) # [nan nan nan] - overflow!
2. Gaussian (Normal) Distribution
def log_gaussian(x, mu, sigma):
    """Log of the Gaussian probability density function
    Args:
        x: Data point
        mu: Mean
        sigma: Standard deviation
    Returns:
        Log probability density
    """
    return -0.5 * np.log(2 * np.pi * sigma**2) - (x - mu)**2 / (2 * sigma**2)

def gaussian_pdf(x, mu, sigma):
    """Gaussian PDF computed via its log for numerical stability"""
    return np.exp(log_gaussian(x, mu, sigma))
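The payoff shows up in the tails: far from the mean, the raw PDF underflows to 0, while the log density remains a finite, comparable number (the point 50 standard deviations out is just an illustration):

```python
import numpy as np

def log_gaussian(x, mu, sigma):
    # Same log density as above
    return -0.5 * np.log(2 * np.pi * sigma**2) - (x - mu)**2 / (2 * sigma**2)

print(log_gaussian(50.0, 0.0, 1.0))          # ~-1250.9, finite
print(np.exp(log_gaussian(50.0, 0.0, 1.0)))  # 0.0 -- the density itself underflows
```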
3. Cross-Entropy Loss
def stable_cross_entropy(y_true, y_pred):
    """Numerically stable cross-entropy
    Args:
        y_true: True labels (one-hot encoded)
        y_pred: Predicted probabilities
    Returns:
        Cross-entropy loss
    """
    # Clip predictions to prevent log(0)
    epsilon = 1e-15
    y_pred = np.clip(y_pred, epsilon, 1 - epsilon)
    return -np.sum(y_true * np.log(y_pred))
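As a quick check, suppose the model assigns zero probability to the true class; the clipped version returns a large but finite loss, while the naive form blows up (the arrays are made up for illustration):

```python
import numpy as np

def stable_cross_entropy(y_true, y_pred):
    # Same clipping as above
    epsilon = 1e-15
    y_pred = np.clip(y_pred, epsilon, 1 - epsilon)
    return -np.sum(y_true * np.log(y_pred))

y_true = np.array([1.0, 0.0, 0.0])
y_pred = np.array([0.0, 0.7, 0.3])  # zero mass on the true class

print(stable_cross_entropy(y_true, y_pred))  # ~34.54, finite
print(-np.sum(y_true * np.log(y_pred)))      # inf -- log(0)
```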
Practical Examples
1. Hidden Markov Model Forward Algorithm
def forward_algorithm(log_emission_probs, log_transition_probs, log_initial):
    """Forward algorithm with numerical stability
    Args:
        log_emission_probs: Log emission probabilities (T, n_states)
        log_transition_probs: Log transition probabilities (n_states, n_states)
        log_initial: Log initial state probabilities (n_states,)
    Returns:
        Log probability of the observation sequence
    """
    log_alpha = log_initial + log_emission_probs[0]
    for t in range(1, len(log_emission_probs)):
        # Log-sum-exp across previous states
        log_alpha = logsumexp(
            log_alpha[:, np.newaxis] + log_transition_probs,
            axis=0
        ) + log_emission_probs[t]
    return logsumexp(log_alpha)
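To sanity-check the recursion, here it is run on a made-up 2-state model (all numbers hypothetical); for a sequence this short, the answer can be verified by brute-force enumeration of state paths:

```python
import numpy as np
from scipy.special import logsumexp

def forward_algorithm(log_emission_probs, log_transition_probs, log_initial):
    # Same recursion as above, repeated so this snippet runs on its own
    log_alpha = log_initial + log_emission_probs[0]
    for t in range(1, len(log_emission_probs)):
        log_alpha = logsumexp(
            log_alpha[:, np.newaxis] + log_transition_probs, axis=0
        ) + log_emission_probs[t]
    return logsumexp(log_alpha)

# Hypothetical 2-state model; transition rows sum to 1
initial = np.array([0.6, 0.4])
transition = np.array([[0.7, 0.3],
                       [0.2, 0.8]])
# Probability of the observed symbol at each of 3 time steps, per state
emissions = np.array([[0.9, 0.1],
                      [0.8, 0.3],
                      [0.1, 0.7]])

log_p = forward_algorithm(np.log(emissions), np.log(transition), np.log(initial))
print(np.exp(log_p))  # total probability of the observation sequence
```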
2. Gaussian Mixture Model
def log_gaussian_mix(X, weights, means, variances):
    """Log probability under a Gaussian mixture with diagonal covariances
    Args:
        X: Data points (n_samples, n_features)
        weights: Mixture weights (n_components,)
        means: Component means (n_components, n_features)
        variances: Per-feature variances (n_components, n_features)
    Returns:
        Log probabilities (n_samples,)
    """
    n_components = len(weights)
    log_probs = np.zeros((X.shape[0], n_components))
    for k in range(n_components):
        # Features are independent under a diagonal covariance, so the
        # joint log density is the sum of the per-feature log densities
        log_probs[:, k] = np.sum(
            log_gaussian(X, means[k], np.sqrt(variances[k])), axis=1
        )
    # Add log weights, then log-sum-exp across components
    return logsumexp(log_probs + np.log(weights), axis=1)
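A minimal usage sketch, assuming diagonal covariances (per-feature variances) and a made-up one-feature, two-component mixture; the second point sits deep in the tails, where the raw mixture density would underflow:

```python
import numpy as np
from scipy.special import logsumexp

def log_gaussian(x, mu, sigma):
    return -0.5 * np.log(2 * np.pi * sigma**2) - (x - mu)**2 / (2 * sigma**2)

def log_gaussian_mix(X, weights, means, variances):
    # Diagonal-covariance mixture: per-feature log densities summed per component
    n_components = len(weights)
    log_probs = np.zeros((X.shape[0], n_components))
    for k in range(n_components):
        log_probs[:, k] = np.sum(log_gaussian(X, means[k], np.sqrt(variances[k])), axis=1)
    return logsumexp(log_probs + np.log(weights), axis=1)

X = np.array([[0.0], [100.0]])   # second point is ~100 sigma from both components
weights = np.array([0.5, 0.5])
means = np.array([[0.0], [1.0]])
variances = np.array([[1.0], [1.0]])

log_p = log_gaussian_mix(X, weights, means, variances)
print(log_p)          # both entries finite
print(np.exp(log_p))  # second entry only underflows at the final exp
```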
Common Pitfalls
| Pitfall | Solution |
|---|---|
| `log(0)` | Clip or add an epsilon: `np.log(x + epsilon)` |
| `exp(large)` | Subtract the max before `exp` |
| `1 / small` | Guard the denominator before dividing |
| `sqrt(negative)` | Clamp first: `np.maximum(x, 0)` |
| `inf - inf` | Replace the resulting NaN with `np.nan_to_num()` |
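A few of these guards in action (toy arrays, purely illustrative):

```python
import numpy as np

# log(0): clip (or add an epsilon) before taking the log
x = np.array([0.0, 1e-20, 0.5])
safe_log = np.log(np.clip(x, 1e-15, None))
print(safe_log)  # all finite

# sqrt of a slightly negative value from rounding error: clamp to 0 first
y = np.array([-1e-17, 4.0])
safe_sqrt = np.sqrt(np.maximum(y, 0.0))
print(safe_sqrt)  # [0. 2.]

# inf - inf produces NaN; nan_to_num replaces it with a finite value
z = np.array([np.inf]) - np.array([np.inf])
print(np.nan_to_num(z))  # [0.]
```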
Best Practices
- Always work in log-space for probability computations
- Subtract the maximum before taking exp
- Use scipy’s logsumexp for complex summations
- Clip values before logarithm to prevent log(0)
- Test edge cases: Try with extreme values
Conclusion
Numerical stability is crucial for any mathematical computation in Python, especially in machine learning and statistics. The log-sum-exp trick and careful value handling can prevent common overflow and underflow issues that would otherwise break your calculations.