⚡ Calmops

NumPy argmax and unravel_index: Finding the Position of the Maximum Value

Introduction

Finding the maximum value in a NumPy array is straightforward with np.max(). Finding the position (index) of that maximum, especially in a multi-dimensional array, takes argmax() and unravel_index(). This guide covers both functions with practical examples.

argmax() in 1D Arrays

For a 1D array, argmax() returns the index of the largest element:

import numpy as np

a = np.array([3, 7, 1, 9, 2, 5])

print(np.argmax(a))   # => 3  (index of value 9)
print(a[np.argmax(a)])  # => 9
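To make the semantics concrete, here is a minimal sketch of what argmax() computes for a 1D array, written as a plain Python loop. The helper name argmax_loop is hypothetical and only for illustration; np.argmax() is the vectorized way to do this, and the sketch does not try to match its behavior on arrays containing NaN.

```python
import numpy as np

def argmax_loop(values):
    """Hypothetical helper: index of the largest element; the first occurrence wins on ties."""
    if len(values) == 0:
        raise ValueError("argmax of an empty sequence is undefined")
    best_idx = 0
    for i in range(1, len(values)):
        # Strict '>' means a later equal value never replaces an earlier one
        if values[i] > values[best_idx]:
            best_idx = i
    return best_idx

a = np.array([3, 7, 1, 9, 2, 5])
print(argmax_loop(a))                  # => 3
print(argmax_loop(a) == np.argmax(a))  # => True
```

The strict comparison is what produces the "first occurrence" rule discussed under Handling Ties.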

argmax() in 2D Arrays

For multi-dimensional arrays, argmax() with no axis argument treats the array as if it were flattened in row-major (C) order and returns an index into that flattened version:

state = np.array([[8, 5],
                  [3, 2],
                  [5, 1],
                  [9, 6]])

print(state.argmax())   # => 6
# Flattened: [8, 5, 3, 2, 5, 1, 9, 6]
#  indices:   0  1  2  3  4  5  6  7
# Value 9 is at flat index 6
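The flattening follows row-major (C) order, the same order state.ravel() uses by default. A quick check against the state array above:

```python
import numpy as np

state = np.array([[8, 5],
                  [3, 2],
                  [5, 1],
                  [9, 6]])

flat_idx = state.argmax()
# ravel() lists the elements in the same row-major order argmax() uses
print(state.ravel())            # => [8 5 3 2 5 1 9 6]
print(state.ravel()[flat_idx])  # => 9
```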

That flat index 6 isn’t very useful on its own. To convert it back to a 2D index, use np.unravel_index().

unravel_index: Convert Flat Index to Multi-Dimensional

np.unravel_index(flat_index, shape) converts a flat index into a tuple of indices for each dimension:

state = np.array([[8, 5],
                  [3, 2],
                  [5, 1],
                  [9, 6]])

max_flat_idx = state.argmax()          # => 6
max_2d_idx   = np.unravel_index(max_flat_idx, state.shape)

print(max_2d_idx)        # => (3, 0)  (NumPy 2.x shows (np.int64(3), np.int64(0)))
print(state[3, 0])       # => 9  (confirmed)
print(state[max_2d_idx]) # => 9

The result (3, 0) means row 3, column 0, which is indeed where 9 lives.
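For a 2D array in the default row-major layout, unravel_index() amounts to integer division by the number of columns: the row is flat // n_cols and the column is flat % n_cols. The reverse conversion is np.ravel_multi_index(). A small sketch checking both against the state array:

```python
import numpy as np

state = np.array([[8, 5],
                  [3, 2],
                  [5, 1],
                  [9, 6]])

flat = int(state.argmax())   # 6
_, n_cols = state.shape      # each row holds n_cols elements

# Row-major arithmetic: quotient is the row, remainder is the column
row, col = divmod(flat, n_cols)
print(row, col)  # => 3 0

# Same answer from unravel_index (converted to plain ints for printing)
print(tuple(int(i) for i in np.unravel_index(flat, state.shape)))  # => (3, 0)

# ravel_multi_index goes the other way: (row, col) -> flat index
print(int(np.ravel_multi_index((row, col), state.shape)))  # => 6
```

The divmod shortcut only covers the 2D case; unravel_index() handles any number of dimensions.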

One-Liner Pattern

# Get the 2D position of the maximum value in one line
row, col = np.unravel_index(state.argmax(), state.shape)
print(f"Max value {state[row, col]} at row={row}, col={col}")
# => Max value 9 at row=3, col=0
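The same pattern works unchanged for any number of dimensions, because unravel_index() returns one index per axis. A sketch wrapping it in a hypothetical helper, argmax_nd(), and applying it to a 3D array:

```python
import numpy as np

def argmax_nd(arr):
    """Hypothetical helper: N-D position of the (first) maximum, as plain ints."""
    return tuple(int(i) for i in np.unravel_index(np.argmax(arr), arr.shape))

cube = np.arange(24).reshape(2, 3, 4)
cube[0, 2, 1] = 100  # plant a clear maximum

pos = argmax_nd(cube)
print(pos)        # => (0, 2, 1)
print(cube[pos])  # => 100
```

Converting to plain ints keeps the printed tuple readable across NumPy versions.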

argmax() Along an Axis

You can find the argmax along a specific axis without flattening:

scores = np.array([[85, 92, 78],
                   [90, 88, 95],
                   [70, 85, 80]])

# Index of max value in each column (axis=0)
print(np.argmax(scores, axis=0))  # => [1 0 1]
# Column 0: max is 90 at row 1
# Column 1: max is 92 at row 0
# Column 2: max is 95 at row 1

# Index of max value in each row (axis=1)
print(np.argmax(scores, axis=1))  # => [1 2 1]
# Row 0: max is 92 at col 1
# Row 1: max is 95 at col 2
# Row 2: max is 85 at col 1
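Once you have per-row indices, you often want the values they point to. np.take_along_axis() does that lookup, provided the index array has the same number of dimensions as the data; plain fancy indexing with a row range is an equivalent alternative. A sketch using the scores array:

```python
import numpy as np

scores = np.array([[85, 92, 78],
                   [90, 88, 95],
                   [70, 85, 80]])

col_of_max = np.argmax(scores, axis=1)   # [1 2 1]

# take_along_axis needs a 2D index array here, so add a trailing axis,
# then drop it again after the lookup
row_max = np.take_along_axis(scores, col_of_max[:, np.newaxis], axis=1)[:, 0]
print(row_max)  # => [92 95 85]

# Equivalent fancy indexing: pair each row number with its argmax column
print(scores[np.arange(scores.shape[0]), col_of_max])  # => [92 95 85]

# Sanity check against max() along the same axis
print(np.array_equal(row_max, scores.max(axis=1)))  # => True
```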

Finding Top-N Positions

To find the positions of the top N values:

a = np.array([[3, 8, 1],
              [7, 2, 9],
              [4, 6, 5]])

# Flatten, get indices of top 3 values
flat_top3 = np.argsort(a.flatten())[-3:][::-1]
top3_positions = [np.unravel_index(i, a.shape) for i in flat_top3]

print(top3_positions)
# => [(1, 2), (0, 1), (1, 0)], the positions of 9, 8, 7 (NumPy 2.x shows each index as np.int64(...))

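Or, for large arrays, use np.argpartition, which avoids a full sort. It only guarantees that the top 3 indices land in the last 3 slots, so their order among themselves is arbitrary:

# Faster for large arrays: argpartition doesn't fully sort (result order not guaranteed)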
flat_top3 = np.argpartition(a.flatten(), -3)[-3:]
top3_positions = [np.unravel_index(i, a.shape) for i in flat_top3]
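Because argpartition() leaves the top N in no particular order, a common follow-up is to sort just those N indices by value. unravel_index() also accepts an array of flat indices and returns one coordinate array per axis, so the list comprehension is optional. A sketch under those assumptions, with n as an illustrative parameter:

```python
import numpy as np

a = np.array([[3, 8, 1],
              [7, 2, 9],
              [4, 6, 5]])
n = 3

flat = a.ravel()  # view when possible; flatten() always copies
top_unordered = np.argpartition(flat, -n)[-n:]  # top n, arbitrary order
# Sort only those n indices by value, largest first
top_sorted = top_unordered[np.argsort(flat[top_unordered])[::-1]]

rows, cols = np.unravel_index(top_sorted, a.shape)  # one array per axis
positions = list(zip(rows.tolist(), cols.tolist()))
print(positions)      # => [(1, 2), (0, 1), (1, 0)]
print(a[rows, cols])  # => [9 8 7]
```

Sorting n indices costs far less than sorting the whole array when n is small relative to a.size.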

Practical Example: Reinforcement Learning Q-Table

A common use case is finding the best action in a Q-table (used in reinforcement learning):

# Q-table: rows = states, columns = actions
q_table = np.array([
    [0.1, 0.5, 0.3, 0.8],   # state 0: best action is 3 (value 0.8)
    [0.9, 0.2, 0.7, 0.4],   # state 1: best action is 0 (value 0.9)
    [0.3, 0.6, 0.1, 0.5],   # state 2: best action is 1 (value 0.6)
])

# Best action for each state
best_actions = np.argmax(q_table, axis=1)
print(best_actions)  # => [3 0 1]

# Best action for a specific state
current_state = 1
best_action = np.argmax(q_table[current_state])
print(f"State {current_state}: best action = {best_action}")
# => State 1: best action = 0
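Alongside the best action, you usually want its estimated value. Row-wise max() gives that directly, and indexing the Q-table with the argmax result confirms the two agree. A sketch reusing the same q_table values:

```python
import numpy as np

q_table = np.array([
    [0.1, 0.5, 0.3, 0.8],
    [0.9, 0.2, 0.7, 0.4],
    [0.3, 0.6, 0.1, 0.5],
])

best_actions = np.argmax(q_table, axis=1)  # [3 0 1]
best_values = q_table.max(axis=1)          # value of each state's best action

# Picking each state's entry at its best action gives the same values
picked = q_table[np.arange(len(q_table)), best_actions]
print(best_values)                          # => [0.8 0.9 0.6]
print(np.array_equal(picked, best_values))  # => True
```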

Handling Ties

When multiple elements share the maximum value, argmax() returns the index of the first occurrence:

a = np.array([5, 9, 3, 9, 1])
print(np.argmax(a))  # => 1  (first occurrence of 9)

# Find ALL positions of the maximum
max_val = a.max()
all_max_positions = np.where(a == max_val)[0]
print(all_max_positions)  # => [1 3]
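For multi-dimensional arrays, np.argwhere() returns every tied position as one row of coordinates, in row-major order. If you want ties broken randomly instead of always by the first index (something argmax() does not do for you), one option is to sample from those positions with NumPy's Generator API. A sketch; the grid values and the seed are arbitrary:

```python
import numpy as np

grid = np.array([[4, 9, 2],
                 [9, 1, 9]])

# Every (row, col) where the maximum occurs, one row per match
ties = np.argwhere(grid == grid.max())
print(ties.tolist())  # => [[0, 1], [1, 0], [1, 2]]

# argmax() still reports only the first of them
first = tuple(int(i) for i in np.unravel_index(grid.argmax(), grid.shape))
print(first)  # => (0, 1)

# Random tie-break: choose one of the tied positions uniformly
rng = np.random.default_rng(seed=0)
choice = tuple(ties[rng.integers(len(ties))].tolist())
print(choice in [tuple(t) for t in ties.tolist()])  # => True
```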

Summary

Function                        Purpose
np.argmax(a)                    Flat index of the maximum in the entire array
np.argmax(a, axis=0)            Row index of the maximum in each column
np.argmax(a, axis=1)            Column index of the maximum in each row
np.unravel_index(idx, shape)    Convert a flat index to an N-D index
np.where(a == a.max())          All positions of the maximum (a tuple of index arrays, one per axis)
