Introduction
logsumexp computes log(sum(exp(x))) in a numerically stable way. This operation appears constantly in machine learning: softmax, log-likelihood computations, variational inference, and more. Without the stability trick, the naive computation overflows or underflows for large or very negative values. See Python Guide for more context.
The Problem: Numerical Instability
Computing log(sum(exp(x))) directly fails for large or very negative inputs, even though the true result is an ordinary finite number:
import numpy as np
# Large values: exp() overflows to inf
x = np.array([1000, 1001, 1002])
print(np.exp(x)) # => [inf inf inf]
print(np.log(np.sum(np.exp(x)))) # => inf (wrong! the true value is about 1002.41)
# Very negative values: exp() underflows to 0
x = np.array([-1000, -1001, -1002])
print(np.exp(x)) # => [0. 0. 0.]
print(np.log(np.sum(np.exp(x)))) # => -inf (wrong! the true value is about -999.59)
The Log-Sum-Exp Trick
The trick: subtract the maximum value before exponentiating, then add it back:
log(sum(exp(x_i))) = max(x) + log(sum(exp(x_i - max(x))))
Since x_i - max(x) <= 0, the exponentials are all in [0, 1] — no overflow. And the maximum term is added back exactly.
def logsumexp_manual(x):
    c = np.max(x)
    return c + np.log(np.sum(np.exp(x - c)))
x = np.array([1000, 1001, 1002])
print(logsumexp_manual(x)) # => 1002.4076059644443 (correct!)
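The same function handles the underflow example from earlier, since shifting by the maximum makes the largest exponent exp(0) = 1 (printed digits are approximate):
x = np.array([-1000, -1001, -1002])
print(logsumexp_manual(x)) # => -999.59239... (finite, no underflow)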
scipy.special.logsumexp
SciPy provides an optimized, production-ready implementation:
from scipy.special import logsumexp
import numpy as np
# 1D array
a = np.array([1, 2, 3])
result = logsumexp(a)
print(result) # => 3.4076059644443804
# Verification: log(exp(1) + exp(2) + exp(3)) = 3.4076...
# 2D array — along axis=1 (row-wise)
a = np.array([[1, 2, 3],
[4, 5, 6]])
print(logsumexp(a, axis=1))
# => [3.40760596 6.40760596]
# Row 0: log(exp(1) + exp(2) + exp(3)) = 3.4076...
# Row 1: log(exp(4) + exp(5) + exp(6)) = 6.4076...
# Along axis=0 (column-wise)
print(logsumexp(a, axis=0))
# => [4.04858735 5.04858735 6.04858735]
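logsumexp also accepts a keepdims argument; keeping the reduced axis is convenient when you want to subtract the result back from a via broadcasting:
print(logsumexp(a, axis=1, keepdims=True))
# => [[3.40760596]
#     [6.40760596]]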
Weighted logsumexp
logsumexp accepts a b parameter of weights, computing log(sum(b * exp(a))):
from scipy.special import logsumexp
a = np.array([1.0, 2.0, 3.0])
b = np.array([0.5, 1.0, 2.0]) # weights
result = logsumexp(a, b=b)
# = log(0.5*exp(1) + 1.0*exp(2) + 2.0*exp(3))
print(result) # => 3.8901...
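Because these particular values are small, the weighted sum can also be computed directly as a sanity check (digits approximate):
print(np.log(np.sum(b * np.exp(a)))) # => 3.8901..., matches logsumexp(a, b=b)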
Returning the Sign
For cases where the sum inside the logarithm might be negative (e.g., when b contains negative weights), pass return_sign=True to get the magnitude and the sign separately:
result, sign = logsumexp(a, b=[-1, 1, -1], return_sign=True)
# result = log(|sum(b * exp(a))|)
# sign = sign of the sum (+1 or -1)
print(result, sign)
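Here (with a = [1.0, 2.0, 3.0] from above) the signed sum is -exp(1) + exp(2) - exp(3), so sign is -1.0. When it is safe to exponentiate, the original value can be recovered:
print(sign * np.exp(result)) # => about -15.4148 (= -exp(1) + exp(2) - exp(3))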
Practical Applications
Softmax in Log Space
Softmax is exp(x_i) / sum(exp(x)). In log space:
def log_softmax(x):
    """Numerically stable log-softmax."""
    return x - logsumexp(x)

def softmax(x):
    """Numerically stable softmax."""
    return np.exp(log_softmax(x))
logits = np.array([2.0, 1.0, 0.1])
print(softmax(logits))
# => [0.65900114 0.24243297 0.09856589]
print(softmax(logits).sum()) # => 1.0
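The payoff shows up with large logits, where the textbook formula exp(x) / sum(exp(x)) turns into inf / inf:
big_logits = np.array([1000.0, 1001.0, 1002.0])
print(np.exp(big_logits) / np.exp(big_logits).sum()) # => [nan nan nan]
print(softmax(big_logits)) # => [0.09003057 0.24472847 0.66524096]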
Log-Likelihood Aggregation
When computing the total log-likelihood of independent observations:
# Log-probabilities of each observation
log_probs = np.array([-0.5, -1.2, -0.8, -2.1, -0.3])
# Total log-likelihood (sum in log space)
total_log_likelihood = np.sum(log_probs)
print(total_log_likelihood) # => -4.9
# When combining probabilities from different models (mixture):
# log(p1 * w1 + p2 * w2) = logsumexp([log_p1 + log_w1, log_p2 + log_w2])
log_p1, log_p2 = -0.5, -1.0
log_w1, log_w2 = np.log(0.7), np.log(0.3)
log_mixture = logsumexp([log_p1 + log_w1, log_p2 + log_w2])
print(log_mixture)
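Since both log-probabilities here are small in magnitude, the mixture can be checked directly (digits approximate):
print(np.log(0.7 * np.exp(-0.5) + 0.3 * np.exp(-1.0))) # => -0.6256..., matches log_mixture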
Viterbi / Forward Algorithm (HMMs)
In Hidden Markov Models, the forward algorithm sums probabilities over all state paths (where Viterbi would take a max), and in log space that sum becomes a logsumexp:
def forward_step(log_alpha, log_transition, log_emission):
    """One step of the forward algorithm in log space."""
    # log_alpha: shape (n_states,)
    # log_transition: shape (n_states, n_states)
    # log_emission: shape (n_states,)
    # For each next state j: log(sum_i alpha_i * T_ij) + log(E_j)
    log_alpha_next = logsumexp(
        log_alpha[:, np.newaxis] + log_transition,
        axis=0
    ) + log_emission
    return log_alpha_next
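A minimal usage sketch with a made-up two-state model (all numbers below are hypothetical, chosen only to exercise forward_step):
log_pi = np.log(np.array([0.6, 0.4]))                # initial state distribution
log_T = np.log(np.array([[0.7, 0.3],
                         [0.4, 0.6]]))               # T[i, j] = P(next state j | state i)
log_E1 = np.log(np.array([0.9, 0.2]))                # P(first observation | state)
log_E2 = np.log(np.array([0.1, 0.8]))                # P(second observation | state)

log_alpha = log_pi + log_E1                          # initialize with the first observation
log_alpha = forward_step(log_alpha, log_T, log_E2)   # fold in the second observation
print(logsumexp(log_alpha))                          # log P(both observations), about -1.5654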
Normalizing Log-Probabilities
# Unnormalized log-probabilities
log_unnorm = np.array([-1.0, -2.0, -0.5, -3.0])
# Normalize: subtract log(sum(exp(log_unnorm)))
log_norm = log_unnorm - logsumexp(log_unnorm)
# Verify they sum to 1 in probability space
print(np.exp(log_norm).sum()) # => 1.0
print(log_norm)
# => [-1.148... -2.148... -0.648... -3.148...]
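Note that this normalization is exactly the log_softmax defined earlier:
print(np.allclose(log_norm, log_softmax(log_unnorm))) # => True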
NumPy Alternative (without SciPy)
If you can’t use SciPy, implement it with NumPy:
def logsumexp_numpy(a, axis=None):
    """NumPy-only logsumexp."""
    a_max = np.max(a, axis=axis, keepdims=True)      # shift by the max to avoid overflow
    out = np.log(np.sum(np.exp(a - a_max), axis=axis))
    out += np.squeeze(a_max, axis=axis)              # add the shift back
    return out
# Test
a = np.array([[1, 2, 3], [4, 5, 6]])
print(logsumexp_numpy(a, axis=1))
# => [3.40760596 6.40760596]
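A quick sanity check against the direct formula on small values (where the naive form is safe):
small = np.array([1.0, 2.0, 3.0])
print(logsumexp_numpy(small)) # => 3.40760596...
print(np.allclose(logsumexp_numpy(small), np.log(np.sum(np.exp(small))))) # => True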
Summary
| Function | Use Case |
|---|---|
| logsumexp(a) | Stable log(sum(exp(a))) |
| logsumexp(a, axis=k) | Along a specific axis |
| logsumexp(a, b=weights) | Weighted: log(sum(b * exp(a))) |
| log_softmax(x) | Stable log-softmax |
| x - logsumexp(x) | Normalize log-probabilities |