Introduction
logsumexp computes log(sum(exp(x))) in a numerically stable way. This operation appears constantly in machine learning: softmax, log-likelihoods, variational inference, and more. Without the stability trick, the naive computation overflows for large values and underflows for very negative ones.
The Problem: Numerical Instability
Computing log(sum(exp(x))) directly fails for large or very negative values:
import numpy as np
# Large values: exp() overflows to inf
x = np.array([1000, 1001, 1002])
print(np.exp(x)) # => [inf inf inf]
print(np.log(np.sum(np.exp(x)))) # => nan
# Very negative values: exp() underflows to 0
x = np.array([-1000, -1001, -1002])
print(np.exp(x)) # => [0. 0. 0.]
print(np.log(np.sum(np.exp(x)))) # => -inf (wrong! the true value is about -999.5924)
The Log-Sum-Exp Trick
The trick: subtract the maximum value before exponentiating, then add it back:
log(sum(exp(x_i))) = max(x) + log(sum(exp(x_i - max(x))))
Since x_i - max(x) <= 0, the exponentials all lie in (0, 1], so nothing overflows. And because at least one term equals exp(0) = 1, the sum never underflows to 0. The maximum is then added back exactly.
def logsumexp_manual(x):
    c = np.max(x)
    return c + np.log(np.sum(np.exp(x - c)))
x = np.array([1000, 1001, 1002])
print(logsumexp_manual(x)) # => 1002.4076059644443 (correct!)
scipy.special.logsumexp
SciPy provides an optimized, production-ready implementation:
from scipy.special import logsumexp
import numpy as np
# 1D array
a = np.array([1, 2, 3])
result = logsumexp(a)
print(result) # => 3.4076059644443804
# Verification: log(exp(1) + exp(2) + exp(3)) = 3.4076...
# 2D array: reduce along axis=1 (row-wise)
a = np.array([[1, 2, 3],
              [4, 5, 6]])
print(logsumexp(a, axis=1))
# => [3.40760596 6.40760596]
# Row 0: log(exp(1) + exp(2) + exp(3)) = 3.4076...
# Row 1: log(exp(4) + exp(5) + exp(6)) = 6.4076...
# Along axis=0 (column-wise)
print(logsumexp(a, axis=0))
# => [4.04858735 5.04858735 6.04858735]
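logsumexp also accepts a keepdims argument. With keepdims=True the reduced axis is kept as size 1, so the result broadcasts directly back against the input, which is handy for row-wise normalization:

```python
import numpy as np
from scipy.special import logsumexp

a = np.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0]])

# keepdims=True keeps the reduced axis as a size-1 dimension...
lse = logsumexp(a, axis=1, keepdims=True)
print(lse.shape)  # => (2, 1)

# ...so the result broadcasts against the input, e.g. a row-wise log-softmax
print(a - lse)
```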
Weighted logsumexp
logsumexp accepts a b parameter for weighted sums, computing log(sum(b * exp(a))):
from scipy.special import logsumexp
a = np.array([1.0, 2.0, 3.0])
b = np.array([0.5, 1.0, 2.0]) # weights
result = logsumexp(a, b=b)
# = log(0.5*exp(1) + 1.0*exp(2) + 2.0*exp(3))
print(result) # => 3.8901714...
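For values this small the naive formula is safe, so we can sanity-check the weighted call against it directly:

```python
import numpy as np
from scipy.special import logsumexp

a = np.array([1.0, 2.0, 3.0])
b = np.array([0.5, 1.0, 2.0])

# No overflow risk here, so the naive computation serves as a reference
naive = np.log(np.sum(b * np.exp(a)))
stable = logsumexp(a, b=b)
print(np.isclose(naive, stable))  # => True
```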
Returning the Sign
For cases where the result might be negative (e.g., log(sum(b * exp(a))) with negative b):
a = np.array([1.0, 2.0, 3.0])
result, sign = logsumexp(a, b=[-1, 1, -1], return_sign=True)
# result = log(|sum(b * exp(a))|)
# sign = sign of the sum (+1 or -1)
print(result, sign)
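One practical use of return_sign is log-space subtraction, i.e. computing log(exp(x) - exp(y)) without leaving log space. A minimal sketch (log_diff_exp is my own helper name, not a SciPy function):

```python
import numpy as np
from scipy.special import logsumexp

def log_diff_exp(log_x, log_y):
    """log|exp(log_x) - exp(log_y)| via logsumexp with weights [1, -1]."""
    res, sign = logsumexp([log_x, log_y], b=[1.0, -1.0], return_sign=True)
    # sign is +1 when log_x > log_y, -1 when log_x < log_y
    return res, sign

res, sign = log_diff_exp(np.log(5.0), np.log(3.0))
print(res, sign)  # => log(2) = 0.6931..., sign +1.0
```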
Practical Applications
Softmax in Log Space
Softmax is exp(x_i) / sum(exp(x)). In log space:
def log_softmax(x):
    """Numerically stable log-softmax."""
    return x - logsumexp(x)

def softmax(x):
    """Numerically stable softmax."""
    return np.exp(log_softmax(x))
logits = np.array([2.0, 1.0, 0.1])
print(softmax(logits))
# => [0.65900114 0.24243297 0.09856589]
print(softmax(logits).sum()) # => 1.0
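To see why this matters, compare against a naive softmax on large logits; the naive version overflows to NaNs while the logsumexp-based one is unaffected:

```python
import numpy as np
from scipy.special import logsumexp

def softmax_naive(x):
    # exp() overflows for large x, giving inf / inf = nan
    return np.exp(x) / np.sum(np.exp(x))

def softmax_stable(x):
    return np.exp(x - logsumexp(x))

big = np.array([1000.0, 1001.0, 1002.0])
with np.errstate(over='ignore', invalid='ignore'):
    print(softmax_naive(big))  # => [nan nan nan]
print(softmax_stable(big))     # => [0.09003057 0.24472847 0.66524096]
```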
Log-Likelihood Aggregation
When computing the total log-likelihood of independent observations:
# Log-probabilities of each observation
log_probs = np.array([-0.5, -1.2, -0.8, -2.1, -0.3])
# Total log-likelihood (sum in log space)
total_log_likelihood = np.sum(log_probs)
print(total_log_likelihood) # => -4.9 (up to floating-point rounding)
# When combining probabilities from different models (mixture):
# log(p1 * w1 + p2 * w2) = logsumexp([log_p1 + log_w1, log_p2 + log_w2])
log_p1, log_p2 = -0.5, -1.0
log_w1, log_w2 = np.log(0.7), np.log(0.3)
log_mixture = logsumexp([log_p1 + log_w1, log_p2 + log_w2])
print(log_mixture)
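The same pattern vectorizes over data points and components via the axis argument. A sketch with made-up numbers, where log_p holds per-component log-densities for each point:

```python
import numpy as np
from scipy.special import logsumexp

# Hypothetical per-component log-densities, shape (n_points, n_components)
log_p = np.array([[-0.5, -1.0, -2.0],
                  [-1.5, -0.2, -0.9],
                  [-0.1, -3.0, -1.2]])
log_w = np.log([0.5, 0.3, 0.2])  # mixture weights (sum to 1)

# log p(x_n) = logsumexp over k of (log w_k + log p_k(x_n))
log_mix = logsumexp(log_p + log_w, axis=1)
print(log_mix.shape)  # => (3,)
```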
Forward Algorithm (HMMs)
In Hidden Markov Models, the forward algorithm sums over all state paths; in log space each sum becomes a logsumexp. (The related Viterbi algorithm replaces the sum with a max.)
def forward_step(log_alpha, log_transition, log_emission):
    """One step of the forward algorithm in log space."""
    # log_alpha: shape (n_states,)
    # log_transition: shape (n_states, n_states)
    # log_emission: shape (n_states,)
    # For each next state j: log(sum_i alpha_i * T_ij) + log(E_j)
    log_alpha_next = logsumexp(
        log_alpha[:, np.newaxis] + log_transition,
        axis=0
    ) + log_emission
    return log_alpha_next
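Chaining the step over a toy sequence yields the sequence log-likelihood. A self-contained sketch with made-up HMM parameters (forward_step is repeated so the snippet runs on its own):

```python
import numpy as np
from scipy.special import logsumexp

def forward_step(log_alpha, log_transition, log_emission):
    return logsumexp(log_alpha[:, np.newaxis] + log_transition, axis=0) + log_emission

# Toy 2-state HMM (illustrative numbers)
log_pi = np.log([0.6, 0.4])          # initial state distribution
log_T = np.log([[0.7, 0.3],
                [0.4, 0.6]])         # transition probabilities
log_E_seq = np.log([[0.9, 0.2],      # emission probabilities of the
                    [0.1, 0.8],      # observed symbol at each time
                    [0.9, 0.2]])     # step, one row per step

log_alpha = log_pi + log_E_seq[0]
for log_E in log_E_seq[1:]:
    log_alpha = forward_step(log_alpha, log_T, log_E)

print(logsumexp(log_alpha))  # total log-likelihood of the sequence
```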
Normalizing Log-Probabilities
# Unnormalized log-probabilities
log_unnorm = np.array([-1.0, -2.0, -0.5, -3.0])
# Normalize: subtract log(sum(exp(log_unnorm)))
log_norm = log_unnorm - logsumexp(log_unnorm)
# Verify they sum to 1 in probability space
print(np.exp(log_norm).sum()) # => 1.0
print(log_norm)
# => [-1.148... -2.148... -0.648... -3.148...]
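Because the values here are small, we can cross-check the log-space normalization against ordinary probability-space normalization:

```python
import numpy as np
from scipy.special import logsumexp

log_unnorm = np.array([-1.0, -2.0, -0.5, -3.0])
log_norm = log_unnorm - logsumexp(log_unnorm)

# Safe to exponentiate directly here, so compare both routes
p = np.exp(log_unnorm)
print(np.allclose(np.exp(log_norm), p / p.sum()))  # => True
```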
NumPy Alternative (without SciPy)
If you can’t use SciPy, implement it with NumPy:
def logsumexp_numpy(a, axis=None):
    """NumPy-only logsumexp."""
    a_max = np.max(a, axis=axis, keepdims=True)
    out = np.log(np.sum(np.exp(a - a_max), axis=axis))
    out += np.squeeze(a_max, axis=axis)
    return out
# Test
a = np.array([[1, 2, 3], [4, 5, 6]])
print(logsumexp_numpy(a, axis=1))
# => [3.40760596 6.40760596]
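A quick sanity check of the NumPy version against SciPy on the extreme inputs from earlier (the function is repeated so the snippet runs on its own):

```python
import numpy as np
from scipy.special import logsumexp as scipy_lse

def logsumexp_numpy(a, axis=None):
    a_max = np.max(a, axis=axis, keepdims=True)
    out = np.log(np.sum(np.exp(a - a_max), axis=axis))
    out += np.squeeze(a_max, axis=axis)
    return out

cases = [np.array([1000.0, 1001.0, 1002.0]),
         np.array([[-1000.0, -1001.0], [-1002.0, -1003.0]])]
for a in cases:
    # Both handle values that overflow/underflow the naive formula
    print(np.allclose(logsumexp_numpy(a), scipy_lse(a)))  # => True
```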
Summary
| Function | Use Case |
|---|---|
| logsumexp(a) | Stable log(sum(exp(a))) |
| logsumexp(a, axis=k) | Reduce along a specific axis |
| logsumexp(a, b=weights) | Weighted: log(sum(b * exp(a))) |
| log_softmax(x) | Stable log-softmax |
| x - logsumexp(x) | Normalize log-probabilities |