
logsumexp: Numerically Stable Log-Sum-Exp in Python

Created: June 12, 2019

Introduction

logsumexp computes log(sum(exp(x))) in a numerically stable way. This operation appears constantly in machine learning — softmax, log-likelihood, variational inference, and more. Without the stability trick, naive computation overflows or underflows for large or very negative values. See Python Guide for more context.

The Problem: Numerical Instability

Computing log(sum(exp(x))) directly fails for large or very negative values:

import numpy as np

# Large values: exp() overflows to inf
x = np.array([1000, 1001, 1002])
print(np.exp(x))          # => [inf inf inf]
print(np.log(np.sum(np.exp(x))))  # => inf (overflow; the true answer is about 1002.41)

# Very negative values: exp() underflows to 0
x = np.array([-1000, -1001, -1002])
print(np.exp(x))          # => [0. 0. 0.]
print(np.log(np.sum(np.exp(x))))  # => -inf (wrong! the true answer is about -999.59)

The Log-Sum-Exp Trick

The trick: subtract the maximum value before exponentiating, then add it back:

log(sum(exp(x_i))) = max(x) + log(sum(exp(x_i - max(x))))

This holds because sum(exp(x_i)) = exp(max(x)) * sum(exp(x_i - max(x))), and taking the log turns the product into a sum. Since x_i - max(x) <= 0, the exponentials all lie in (0, 1] — no overflow — and at least one term equals exp(0) = 1, so the sum never underflows to zero. The maximum is added back exactly at the end.

def logsumexp_manual(x):
    c = np.max(x)
    return c + np.log(np.sum(np.exp(x - c)))

x = np.array([1000, 1001, 1002])
print(logsumexp_manual(x))  # => 1002.4076059644443  (correct!)

scipy.special.logsumexp

SciPy provides an optimized, production-ready implementation:

from scipy.special import logsumexp
import numpy as np

# 1D array
a = np.array([1, 2, 3])
result = logsumexp(a)
print(result)  # => 3.4076059644443804
# Verification: log(exp(1) + exp(2) + exp(3)) = 3.4076...

# 2D array — along axis=1 (row-wise)
a = np.array([[1, 2, 3],
              [4, 5, 6]])

print(logsumexp(a, axis=1))
# => [3.40760596 6.40760596]
# Row 0: log(exp(1) + exp(2) + exp(3)) = 3.4076...
# Row 1: log(exp(4) + exp(5) + exp(6)) = 6.4076...

# Along axis=0 (column-wise)
print(logsumexp(a, axis=0))
# => [4.04858735 5.04858735 6.04858735]
# Column 0: log(exp(1) + exp(4)) = 4.0486...
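logsumexp also accepts a keepdims argument, which keeps the reduced axis as size 1 so the result broadcasts back against the input. That makes row-wise normalization a one-liner:

```python
from scipy.special import logsumexp
import numpy as np

a = np.array([[1, 2, 3],
              [4, 5, 6]])

# keepdims=True returns shape (2, 1) instead of (2,),
# so the subtraction broadcasts row by row
log_norm = a - logsumexp(a, axis=1, keepdims=True)

print(np.exp(log_norm).sum(axis=1))  # each row now sums to 1
```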

Weighted logsumexp

logsumexp supports a b parameter for weighted sums: log(sum(b * exp(a))):

from scipy.special import logsumexp

a = np.array([1.0, 2.0, 3.0])
b = np.array([0.5, 1.0, 2.0])  # weights

result = logsumexp(a, b=b)
# = log(0.5*exp(1) + 1.0*exp(2) + 2.0*exp(3))
print(result)  # => 3.89017...

Returning the Sign

For cases where the result might be negative (e.g., log(sum(b * exp(a))) with negative b):

result, sign = logsumexp(a, b=[-1, 1, -1], return_sign=True)
# result = log(|sum(b * exp(a))|)
# sign   = sign of the sum (+1 or -1)
print(result, sign)
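As a sanity check, the pair can be recombined: sign * exp(result) should reproduce the raw signed sum (safe to compare here because the values are small):

```python
import numpy as np
from scipy.special import logsumexp

a = np.array([1.0, 2.0, 3.0])
b = np.array([-1.0, 1.0, -1.0])

result, sign = logsumexp(a, b=b, return_sign=True)

reconstructed = sign * np.exp(result)   # undo the log, reapply the sign
direct = np.sum(b * np.exp(a))          # raw computation for comparison
print(reconstructed, direct)            # both are about -15.4148
```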

Practical Applications

Softmax in Log Space

Softmax is exp(x_i) / sum(exp(x)). In log space:

def log_softmax(x):
    """Numerically stable log-softmax."""
    return x - logsumexp(x)

def softmax(x):
    """Numerically stable softmax."""
    return np.exp(log_softmax(x))

logits = np.array([2.0, 1.0, 0.1])
print(softmax(logits))
# => [0.65900114 0.24243297 0.09856589]
print(softmax(logits).sum())  # => 1.0
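To see why the log-space route matters, compare it against the direct formula on large logits (naive_softmax below is written just for this contrast, not part of any library):

```python
import numpy as np
from scipy.special import logsumexp

def naive_softmax(x):
    # Direct formula: exp() overflows for large inputs
    e = np.exp(x)
    return e / e.sum()

def stable_softmax(x):
    # Shift by logsumexp(x) in log space first
    return np.exp(x - logsumexp(x))

big = np.array([1000.0, 1001.0, 1002.0])
with np.errstate(over="ignore", invalid="ignore"):
    print(naive_softmax(big))   # => [nan nan nan]
print(stable_softmax(big))      # => [0.09003057 0.24472847 0.66524096]
```

The stable version gives the same answer as softmax([0, 1, 2]), since softmax is invariant to shifting all logits by a constant.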

Log-Likelihood Aggregation

When computing the total log-likelihood of independent observations:

# Log-probabilities of each observation
log_probs = np.array([-0.5, -1.2, -0.8, -2.1, -0.3])

# Total log-likelihood (sum in log space)
total_log_likelihood = np.sum(log_probs)
print(total_log_likelihood)  # => -4.9

# When combining probabilities from different models (mixture):
# log(p1 * w1 + p2 * w2) = logsumexp([log_p1 + log_w1, log_p2 + log_w2])
log_p1, log_p2 = -0.5, -1.0
log_w1, log_w2 = np.log(0.7), np.log(0.3)

log_mixture = logsumexp([log_p1 + log_w1, log_p2 + log_w2])
print(log_mixture)
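Because these particular probabilities are nowhere near the overflow regime, the log-space result can be checked against the direct computation in probability space:

```python
import numpy as np
from scipy.special import logsumexp

log_p1, log_p2 = -0.5, -1.0
w1, w2 = 0.7, 0.3

# Log-space mixture, as above
log_mixture = logsumexp([log_p1 + np.log(w1), log_p2 + np.log(w2)])

# Same mixture computed directly in probability space
direct = w1 * np.exp(log_p1) + w2 * np.exp(log_p2)
print(np.exp(log_mixture), direct)  # both are about 0.5349
```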

Viterbi / Forward Algorithm (HMMs)

In Hidden Markov Models, the forward algorithm sums over all paths:

def forward_step(log_alpha, log_transition, log_emission):
    """One step of the forward algorithm in log space."""
    # log_alpha: shape (n_states,)
    # log_transition: shape (n_states, n_states)
    # log_emission: shape (n_states,)

    # For each next state j: log(sum_i alpha_i * T_ij) + log(E_j)
    log_alpha_next = logsumexp(
        log_alpha[:, np.newaxis] + log_transition,
        axis=0
    ) + log_emission

    return log_alpha_next
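A minimal run with a two-state toy model (these transition and emission numbers are invented purely for illustration):

```python
import numpy as np
from scipy.special import logsumexp

def forward_step(log_alpha, log_transition, log_emission):
    """One step of the forward algorithm in log space (as above)."""
    return logsumexp(
        log_alpha[:, np.newaxis] + log_transition,
        axis=0
    ) + log_emission

log_alpha = np.log(np.array([0.6, 0.4]))        # current forward probabilities
log_T = np.log(np.array([[0.7, 0.3],            # rows: from-state, cols: to-state
                         [0.2, 0.8]]))
log_E = np.log(np.array([0.9, 0.1]))            # emission probs of the observed symbol

log_alpha = forward_step(log_alpha, log_T, log_E)
print(np.exp(log_alpha))  # about [0.45, 0.05]
```

Staying in log space matters here because forward probabilities shrink geometrically with sequence length and would underflow after a few hundred steps.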

Normalizing Log-Probabilities

# Unnormalized log-probabilities
log_unnorm = np.array([-1.0, -2.0, -0.5, -3.0])

# Normalize: subtract log(sum(exp(log_unnorm)))
log_norm = log_unnorm - logsumexp(log_unnorm)

# Verify they sum to 1 in probability space
print(np.exp(log_norm).sum())  # => 1.0
print(log_norm)
# => [-1.56...  -2.56...  -1.06...  -3.56...]

NumPy Alternative (without SciPy)

If you can’t use SciPy, implement it with NumPy:

def logsumexp_numpy(a, axis=None):
    """NumPy-only logsumexp."""
    a_max = np.max(a, axis=axis, keepdims=True)
    out = np.log(np.sum(np.exp(a - a_max), axis=axis))
    out += np.squeeze(a_max, axis=axis)
    return out

# Test
a = np.array([[1, 2, 3], [4, 5, 6]])
print(logsumexp_numpy(a, axis=1))
# => [3.40760596 6.40760596]
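A quick check that this NumPy version survives the large inputs from the opening example:

```python
import numpy as np

def logsumexp_numpy(a, axis=None):
    """NumPy-only logsumexp (same as above)."""
    a_max = np.max(a, axis=axis, keepdims=True)
    out = np.log(np.sum(np.exp(a - a_max), axis=axis))
    out += np.squeeze(a_max, axis=axis)
    return out

big = np.array([1000.0, 1001.0, 1002.0])
print(logsumexp_numpy(big))  # about 1002.40760596, no overflow
```

Note that this bare-bones version lacks SciPy's extras (the b weights, return_sign, and keepdims), so prefer scipy.special.logsumexp when SciPy is available.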

Summary

Function                    Use Case
logsumexp(a)                Stable log(sum(exp(a)))
logsumexp(a, axis=k)        Reduce along a specific axis
logsumexp(a, b=weights)     Weighted: log(sum(b * exp(a)))
log_softmax(x)              Stable log-softmax
x - logsumexp(x)            Normalize log-probabilities
