⚡ Calmops

Numerical Stability in Python

Overview

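Numerical stability describes how an algorithm behaves in the presence of small numerical errors: a stable algorithm keeps rounding errors from growing until they dominate the result. In Python, and especially when working with probabilities, the most common failures are floating-point overflow and underflow, which can silently turn results into inf, 0, or nan and break your computations.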

The Problem

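NumPy's default float64 type has a limited range: the largest finite value is about 1.8e308 (so np.exp(x) already overflows for x above about 709.8), and positive values smaller than about 4.9e-324 round to 0. When computing probabilities, especially with exponential functions, results can fall below that range (underflow) or exceed it (overflow):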

import numpy as np

# Underflow: exp of a large negative number is too small to represent and becomes 0
log_probs = np.array([-1000, -1001, -1002])
print(np.exp(log_probs))
# Output: [0. 0. 0.]

# Overflow: large numbers become inf
large_values = np.array([1000, 2000, 3000])
print(np.exp(large_values))
# Output: [inf inf inf]

The Solution: Log-Sum-Exp Trick

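Instead of working with raw probabilities, work with their logarithms. Products of probabilities become sums of log probabilities, and a value such as 1e-400, which float64 cannot represent, becomes an ordinary number (about -921). When you need to add probabilities, the log-sum-exp trick lets you do so without leaving log space.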
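To see why logs help, consider the joint probability of many independent events. A minimal sketch with made-up data (200 events, each with probability 0.01): the direct product underflows to 0, while the sum of logs stays an ordinary float.

```python
import numpy as np

# Hypothetical data: 200 independent events, each with probability 0.01
probs = np.full(200, 0.01)

# Direct product: the true value is 1e-400, below float64's smallest
# positive (subnormal) number of about 4.9e-324, so it rounds to 0.0
naive_product = np.prod(probs)

# Log space: the product becomes a sum of logs, which is easy to represent
log_product = np.sum(np.log(probs))

print(naive_product)  # 0.0
print(log_product)    # about -921.03, i.e. 200 * log(0.01)
```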

Basic Implementation

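def normal_scalar(X, mu, var, d):
    """Isotropic d-dimensional normal PDF, N(mu, var * I), assembled in log space

    Args:
        X: Data point, shape (d,)
        mu: Mean, shape (d,)
        var: Variance (scalar, shared by every dimension)
        d: Dimension
    """
    # Build the log density first, then exponentiate once
    log_prob = -(np.linalg.norm(X - mu))**2 / (2*var) - np.log(2*np.pi*var) * d / 2
    return np.exp(log_prob)

# Every term is combined in log space and exp is applied only once, at the end.
# The returned density can still underflow to 0 far from the mean, so return
# log_prob instead when densities will be multiplied or compared.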
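The final np.exp is where underflow can still happen: far from the mean the log density is an ordinary negative number, but its exponential is below float64's range. A sketch of a log-returning variant, assuming you want to keep the value in log space (log_normal_scalar is a hypothetical name; it uses the same formula as normal_scalar without the final exp):

```python
import numpy as np


def log_normal_scalar(X, mu, var, d):
    """Log density of an isotropic d-dimensional normal N(mu, var * I).

    Hypothetical helper: same formula as normal_scalar, minus the final exp.
    """
    sq_dist = np.linalg.norm(X - mu) ** 2
    return -sq_dist / (2 * var) - d / 2 * np.log(2 * np.pi * var)


X = np.array([40.0, 0.0])  # a point 40 standard deviations from the mean
mu = np.zeros(2)

log_p = log_normal_scalar(X, mu, 1.0, 2)
print(log_p)          # about -801.84, still finite
print(np.exp(log_p))  # 0.0: the density itself underflows
```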

Log-Sum-Exp for Summing Probabilities

When you need the log of a sum of probabilities that you only have as log probabilities:

def log_sum_exp(x):
    """Numerically stable way to compute log(sum(exp(x)))
    
    Args:
        x: Array of log probabilities
    Returns:
        Log of the sum of probabilities
    """
    max_x = np.max(x)  # Prevent overflow
    return max_x + np.log(np.sum(np.exp(x - max_x)))

# Usage
log_probs = np.array([-1000, -1001, -1002])
result = log_sum_exp(log_probs)
print(result)  # ~-999.59 (stable)
print(np.log(np.sum(np.exp(log_probs))))  # -inf (unstable! every exp underflows to 0, and log(0) = -inf)
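Why subtracting the maximum is exact rather than an approximation: factor e^m out of the sum, where m is the largest entry. Every remaining exponent x_i - m is at most 0, so no term overflows, and the largest term is e^0 = 1, so the sum is at least 1 and the logarithm never sees 0. The second line works through the example above, where m = -1000.

```latex
\log \sum_{i} e^{x_i}
  = \log\Big(e^{m} \sum_{i} e^{x_i - m}\Big)
  = m + \log \sum_{i} e^{x_i - m},
  \qquad m = \max_i x_i

-1000 + \log\left(1 + e^{-1} + e^{-2}\right)
  \approx -1000 + \log(1.5032)
  \approx -999.59
```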

Using scipy

from scipy.special import logsumexp

log_probs = np.array([-1000, -1001, -1002])

# Simple log-sum-exp
result = logsumexp(log_probs)
print(result)  # ~-999.59

# With axis parameter
log_probs_2d = np.array([[-1000, -1001], [-1002, -1003]])
row_sums = logsumexp(log_probs_2d, axis=1)
print(row_sums)  # ~[-999.69, -1001.69]
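One edge case the hand-written log_sum_exp above does not cover: if every entry is -inf (every probability is 0), then x - max_x evaluates -inf - (-inf) = nan, so the result is nan instead of -inf. SciPy's logsumexp guards against non-finite maxima internally; if you keep your own helper, a minimal sketch of the same guard (log_sum_exp_safe is a hypothetical name):

```python
import numpy as np


def log_sum_exp_safe(x):
    """log(sum(exp(x))) with a guard for non-finite maxima (hypothetical helper)."""
    x = np.asarray(x, dtype=float)
    max_x = np.max(x)
    if not np.isfinite(max_x):
        # max_x is -inf only if every entry is -inf (log(0) = -inf), +inf if any
        # entry is +inf, and nan if any entry is nan. In each case max_x is
        # already the correct result, and shifting by it would produce nan.
        return max_x
    return max_x + np.log(np.sum(np.exp(x - max_x)))


print(log_sum_exp_safe([-np.inf, -np.inf]))     # -inf
print(log_sum_exp_safe([-1000, -1001, -1002]))  # about -999.59
```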

Common Stability Patterns

1. Softmax Function

def stable_softmax(x):
    """Numerically stable softmax
    
    Args:
        x: 1-D array of scores
    Returns:
        Softmax probabilities
    """
    # Subtract max to prevent overflow
    exp_x = np.exp(x - np.max(x))
    return exp_x / np.sum(exp_x)

# Example
x = np.array([1000, 1001, 1002])
print(stable_softmax(x))  # [0.09003057 0.24472847 0.66524096]
print(np.exp(x) / np.sum(np.exp(x)))  # [nan nan nan]: exp overflows to inf, and inf / inf = nan
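If you need log probabilities (for a loss or a likelihood), taking np.log of the softmax output reintroduces the problem: entries that underflowed to 0 become -inf. Computing log-softmax directly as x - logsumexp(x) avoids the round trip through exp. A sketch using scipy.special.logsumexp (stable_log_softmax is a hypothetical name):

```python
import numpy as np
from scipy.special import logsumexp


def stable_log_softmax(x, axis=-1):
    """log(softmax(x)) computed without exponentiating and then taking the log."""
    # keepdims=True keeps the reduced axis so the subtraction broadcasts per row
    return x - logsumexp(x, axis=axis, keepdims=True)


x = np.array([0.0, 800.0])

# Route 1: stable softmax, then log. exp(-800) underflows to 0, so log gives -inf.
shifted = np.exp(x - np.max(x))
with np.errstate(divide="ignore"):
    via_softmax = np.log(shifted / np.sum(shifted))

# Route 2: log-softmax directly keeps the tiny probability representable as -800
via_log_softmax = stable_log_softmax(x)

print(via_softmax)      # first entry is -inf
print(via_log_softmax)  # first entry is -800.0
```

The axis parameter also makes the same function work row by row on a 2-D batch of scores.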

2. Gaussian (Normal) Distribution

def log_gaussian(x, mu, sigma):
    """Log of Gaussian probability density function
    
    Args:
        x: Data point
        mu: Mean
        sigma: Standard deviation
    Returns:
        Log probability density
    """
    return -0.5 * np.log(2 * np.pi * sigma**2) - (x - mu)**2 / (2 * sigma**2)

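def gaussian_pdf(x, mu, sigma):
    """Gaussian PDF computed as exp(log_gaussian(...))

    The exp still underflows to 0 far in the tails, so keep log_gaussian
    when densities will be multiplied or compared.
    """
    return np.exp(log_gaussian(x, mu, sigma))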
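The payoff shows up when densities are combined. A sketch of a dataset log-likelihood under N(0, 1), with made-up evenly spaced data: the product of 2,000 densities underflows to 0, while the sum of log densities stays finite. scipy.stats.norm.logpdf is used only as an independent cross-check.

```python
import numpy as np
from scipy.stats import norm


def log_gaussian(x, mu, sigma):
    # Same formula as log_gaussian above
    return -0.5 * np.log(2 * np.pi * sigma**2) - (x - mu)**2 / (2 * sigma**2)


x = np.linspace(-3, 3, 2000)  # hypothetical data

# Product of densities: every factor is at most about 0.399, so 2000 of them underflow
naive_likelihood = np.prod(np.exp(log_gaussian(x, 0.0, 1.0)))

# Sum of log densities: an ordinary finite number
log_likelihood = np.sum(log_gaussian(x, 0.0, 1.0))

print(naive_likelihood)  # 0.0
print(log_likelihood)
print(norm.logpdf(x, loc=0.0, scale=1.0).sum())  # should match the line above
```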

3. Cross-Entropy Loss

def stable_cross_entropy(y_true, y_pred):
    """Numerically stable cross-entropy
    
    Args:
        y_true: True labels (one-hot encoded)
        y_pred: Predicted probabilities
    Returns:
        Cross-entropy loss
    """
    # Clip predictions to prevent log(0)
    epsilon = 1e-15
    y_pred = np.clip(y_pred, epsilon, 1 - epsilon)
    return -np.sum(y_true * np.log(y_pred))
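Clipping keeps log(0) out of the computation but also caps the loss: with epsilon = 1e-15, no single term can exceed -log(1e-15) ≈ 34.5. When the model produces raw scores (logits) rather than probabilities, a common alternative is to compute cross-entropy from log-softmax, which needs no clipping. A sketch under that assumption (cross_entropy_from_logits is a hypothetical name; y_true is one-hot with shape (n_samples, n_classes)):

```python
import numpy as np
from scipy.special import logsumexp


def cross_entropy_from_logits(y_true, logits):
    """Cross-entropy computed from unnormalized scores via log-softmax.

    y_true: one-hot labels, shape (n_samples, n_classes)
    logits: raw model scores, same shape
    """
    # Row-wise log-softmax: logits minus log(sum(exp(logits))) for each sample
    log_probs = logits - logsumexp(logits, axis=1, keepdims=True)
    return -np.sum(y_true * log_probs)


y_true = np.array([[1.0, 0.0]])
logits = np.array([[0.0, 800.0]])  # a confidently wrong prediction

print(cross_entropy_from_logits(y_true, logits))  # 800.0
```

For comparison, feeding softmax probabilities for the same logits into stable_cross_entropy would report about 34.5: the true-class probability underflows to 0 and is then clipped to 1e-15.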

Practical Examples

1. Hidden Markov Model Forward Algorithm

from scipy.special import logsumexp

def forward_algorithm(log_emission_probs, log_transition_probs, log_initial):
    """Forward algorithm in log space

    Args:
        log_emission_probs: Shape (T, N); entry [t, j] is log P(o_t | state j)
        log_transition_probs: Shape (N, N); entry [i, j] is log P(state j | state i)
        log_initial: Shape (N,); log initial state probabilities
    Returns:
        Log probability of the observation sequence
    """
    log_alpha = log_initial + log_emission_probs[0]
    
    for t in range(1, len(log_emission_probs)):
        # Sum over previous states i in log space: log_alpha[i] + log_transition[i, j]
        log_alpha = logsumexp(
            log_alpha[:, np.newaxis] + log_transition_probs, 
            axis=0
        ) + log_emission_probs[t]
    
    return logsumexp(log_alpha)
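A quick way to check the recursion is to compare it against brute-force enumeration of every hidden-state path on a short sequence, then run it on a sequence long enough that the probability itself underflows. A sketch with a hypothetical 2-state, 2-symbol HMM (all numbers are made up; forward_algorithm is repeated so the block runs on its own):

```python
import itertools

import numpy as np
from scipy.special import logsumexp


def forward_algorithm(log_emission_probs, log_transition_probs, log_initial):
    # Same log-space recursion as forward_algorithm above
    log_alpha = log_initial + log_emission_probs[0]
    for t in range(1, len(log_emission_probs)):
        log_alpha = logsumexp(
            log_alpha[:, np.newaxis] + log_transition_probs, axis=0
        ) + log_emission_probs[t]
    return logsumexp(log_alpha)


# Hypothetical 2-state, 2-symbol HMM
initial = np.array([0.6, 0.4])
transition = np.array([[0.7, 0.3],
                       [0.4, 0.6]])  # transition[i, j] = P(next state j | state i)
emission = np.array([[0.9, 0.1],
                     [0.2, 0.8]])    # emission[j, o] = P(symbol o | state j)
observations = [0, 1, 1, 0]

# Rearrange to shape (T, n_states): log_em[t, j] = log P(o_t | state j)
log_em = np.log(emission[:, observations].T)
log_p = forward_algorithm(log_em, np.log(transition), np.log(initial))

# Brute force in probability space: sum P(path, observations) over all 2**T paths
brute = 0.0
for path in itertools.product(range(2), repeat=len(observations)):
    p = initial[path[0]] * emission[path[0], observations[0]]
    for t in range(1, len(observations)):
        p *= transition[path[t - 1], path[t]] * emission[path[t], observations[t]]
    brute += p

print(np.exp(log_p), brute)  # the two values agree

# A 5,000-step sequence: the probability underflows, but its log is finite
long_obs = [0, 1] * 2500
long_log_p = forward_algorithm(
    np.log(emission[:, long_obs].T), np.log(transition), np.log(initial)
)
print(long_log_p, np.exp(long_log_p))  # a large negative number, then 0.0
```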

2. Gaussian Mixture Model

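from scipy.stats import multivariate_normal

def log_gaussian_mix(X, weights, means, covs):
    """Log density of a Gaussian Mixture Model
    
    Args:
        X: Data points (n_samples, n_features)
        weights: Mixture weights (n_components,), summing to 1
        means: Component means, each of shape (n_features,)
        covs: Component covariance matrices, each (n_features, n_features)
    Returns:
        Log densities (n_samples,)
    """
    n_components = len(weights)
    log_probs = np.zeros((X.shape[0], n_components))
    
    for k in range(n_components):
        # The univariate log_gaussian above expects a scalar standard deviation,
        # not a covariance matrix, so use SciPy's multivariate log density here
        log_probs[:, k] = multivariate_normal.logpdf(X, mean=means[k], cov=covs[k])
    
    # log p(x) = log sum_k w_k N(x | mean_k, cov_k), computed with log-sum-exp
    log_weights = np.log(weights)
    return logsumexp(log_probs + log_weights, axis=1)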
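A self-contained check of the mixture computation, using scipy.stats.multivariate_normal for the component densities (the mixture parameters and data points are made up, and gmm_log_density is a hypothetical name). It compares the log-space result with mixing densities in probability space and shows where the direct route breaks: an outlier whose density underflows.

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import multivariate_normal


def gmm_log_density(X, weights, means, covs):
    """Per-sample log density of a Gaussian mixture (hypothetical helper)."""
    # Column k holds log N(x | means[k], covs[k]) for every row of X
    log_probs = np.column_stack([
        multivariate_normal.logpdf(X, mean=m, cov=c) for m, c in zip(means, covs)
    ])
    # log sum_k w_k N(x | ...) = logsumexp over k of (log w_k + log N(x | ...))
    return logsumexp(log_probs + np.log(weights), axis=1)


# Hypothetical two-component mixture in 2-D
weights = np.array([0.3, 0.7])
means = [np.zeros(2), np.array([3.0, 3.0])]
covs = [np.eye(2), 0.5 * np.eye(2)]
X = np.array([[0.0, 0.0], [3.0, 3.0], [1.5, 1.5], [60.0, -60.0]])

stable = gmm_log_density(X, weights, means, covs)

# Direct route: mix the densities in probability space, then take the log
with np.errstate(divide="ignore"):
    direct = np.log(sum(
        w * multivariate_normal.pdf(X, mean=m, cov=c)
        for w, m, c in zip(weights, means, covs)
    ))

print(stable)  # all finite
print(direct)  # first three match; the outlier's density underflows, giving -inf
```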

Common Pitfalls

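| Pitfall | Solution |
| --- | --- |
| log(0) | Use np.log(x + epsilon) or np.log(np.clip(x, epsilon, None)) |
| exp(large) | Subtract the max before exp |
| 1 / small | Check the denominator before dividing (np.divide accepts a where= mask) |
| sqrt(negative) | Clamp rounding-induced negatives with np.maximum(x, 0) |
| inf - inf | Rearrange so the huge terms never form (subtract the max, stay in log space); np.nan_to_num() only replaces the nan, it does not recover the right value |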
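Sketches of guards for the pitfalls above, in NumPy. The tolerance eps is an assumed value, and the last case uses np.log1p to rearrange log(exp(a) - exp(b)) as one illustrative way to avoid inf - inf, not the only one.

```python
import numpy as np

eps = 1e-12  # hypothetical tolerance; choose one that fits the scale of your data

# log(0): clip into a positive range before taking the log
p = np.array([0.0, 0.5])
safe_log = np.log(np.clip(p, eps, None))  # log(0) becomes log(1e-12) ≈ -27.6, not -inf

# 1 / small: divide only where the denominator is safely nonzero, leave 0 elsewhere
d = np.array([2.0, 0.0])
safe_recip = np.divide(1.0, d, out=np.zeros_like(d), where=np.abs(d) > eps)

# sqrt(negative): rounding can push a quantity that should be >= 0 slightly below 0
v = np.array([4.0, -1e-17])
safe_sqrt = np.sqrt(np.maximum(v, 0.0))  # [2., 0.] instead of [2., nan]

# inf - inf: compute log(exp(a) - exp(b)) for a > b
a, b = 1001.0, 1000.0
with np.errstate(over="ignore", invalid="ignore"):
    raw = np.exp(a) - np.exp(b)  # inf - inf = nan
masked = np.nan_to_num(raw)      # 0.0: no longer nan, but wrong
rearranged = a + np.log1p(-np.exp(b - a))  # stays in log space: about 1000.54
```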

Best Practices

  1. Always work in log space for probability computations
  2. Subtract the maximum before taking exp
  3. Use SciPy's logsumexp instead of summing exponentials by hand
  4. Clip values before taking a logarithm to prevent log(0)
  5. Test edge cases with extreme values: very large, very negative, zero, and -inf
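Edge-case testing can be as small as a loop over extreme inputs with a few invariants. A sketch for the softmax from earlier (check_softmax_extremes is a hypothetical name; the cases and checks are illustrative, so extend them with your own):

```python
import numpy as np


def stable_softmax(x):
    # Same definition as stable_softmax in the softmax section
    exp_x = np.exp(x - np.max(x))
    return exp_x / np.sum(exp_x)


def check_softmax_extremes():
    """Edge-case checks for stable_softmax (a sketch; extend with your own cases)."""
    cases = [
        np.array([1000.0, 1001.0, 1002.0]),     # naive exp overflows to inf
        np.array([-1000.0, -1001.0, -1002.0]),  # naive exp underflows to 0, giving 0 / 0
        np.array([0.0, 800.0]),                 # one output underflows to exactly 0
        np.array([-np.inf, 0.0]),               # a zero-probability entry in log space
        np.array([5.0]),                        # single element
    ]
    for x in cases:
        p = stable_softmax(x)
        assert np.all(np.isfinite(p)), f"non-finite output for {x}"
        assert np.all(p >= 0), f"negative output for {x}"
        assert np.isclose(p.sum(), 1.0), f"outputs do not sum to 1 for {x}"
    return True


print(check_softmax_extremes())  # True
```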

Conclusion

Numerical stability is crucial for any mathematical computation in Python, especially in machine learning and statistics. The log-sum-exp trick and careful value handling can prevent common overflow and underflow issues that would otherwise break your calculations.
