Introduction
Finding the maximum value in a NumPy array is straightforward with np.max(). Finding the position (index) of that maximum, especially in multi-dimensional arrays, requires argmax() and unravel_index(). This guide covers both functions with practical examples. See Python Guide for more context.
argmax() in 1D Arrays
For a 1D array, argmax() returns the index of the largest element:
import numpy as np
a = np.array([3, 7, 1, 9, 2, 5])
print(np.argmax(a)) # => 3 (index of value 9)
print(a[np.argmax(a)]) # => 9
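argmax() is also available as an array method (a.argmax()). It raises a ValueError on an empty array rather than returning an index. The sketch below wraps it in a hypothetical safe_argmax helper, which is not part of NumPy. The helper returns None for empty input, and that choice is only an assumption about what a caller might want:

```python
import numpy as np

def safe_argmax(values):
    """Hypothetical helper: index of the largest element, or None if empty."""
    arr = np.asarray(values)
    if arr.size == 0:
        return None  # np.argmax itself would raise ValueError here
    return int(np.argmax(arr))

a = np.array([3, 7, 1, 9, 2, 5])
print(a.argmax())                    # method form, same as np.argmax(a): 3
print(safe_argmax(a))                # 3
print(safe_argmax([]))               # None
print(safe_argmax([4.0, 4.5, 1.0]))  # plain Python lists work too: 1
```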
argmax() in 2D Arrays
For multi-dimensional arrays, argmax() by default flattens the array and returns the index in the flattened version:
state = np.array([[8, 5],
                  [3, 2],
                  [5, 1],
                  [9, 6]])
print(state.argmax()) # => 6
# Flattened: [8, 5, 3, 2, 5, 1, 9, 6]
# indices:    0  1  2  3  4  5  6  7
# Value 9 is at flat index 6
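To confirm where flat index 6 comes from, flatten the array yourself. For an array built like this one, np.ravel returns the elements in row-major (C) order, which is also the order np.unravel_index assumes by default. Indexing the raveled array with the argmax() result should therefore give back the maximum:

```python
import numpy as np

state = np.array([[8, 5],
                  [3, 2],
                  [5, 1],
                  [9, 6]])

flat = state.ravel()           # row by row: [8 5 3 2 5 1 9 6]
flat_idx = state.argmax()      # 6

print(flat)
print(flat[flat_idx])                  # 9
print(flat[flat_idx] == state.max())   # True
```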
That flat index 6 isn’t very useful on its own. To convert it back to a 2D index, use np.unravel_index().
unravel_index: Convert Flat Index to Multi-Dimensional
np.unravel_index(flat_index, shape) converts a flat index into a tuple of indices for each dimension:
state = np.array([[8, 5],
                  [3, 2],
                  [5, 1],
                  [9, 6]])
max_flat_idx = state.argmax() # => 6
max_2d_idx = np.unravel_index(max_flat_idx, state.shape)
print(max_2d_idx) # => (3, 0); NumPy 2.x may display it as (np.int64(3), np.int64(0))
print(state[3, 0]) # => 9 (confirmed)
print(state[max_2d_idx]) # => 9
The result (3, 0) means row 3, column 0. Indices are zero-based, so this is the fourth row and the first column, which is exactly where 9 lives.
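For a 2D array in the default row-major order, unravel_index reduces to integer division. The row is flat_idx // n_cols and the column is flat_idx % n_cols, so 6 // 2 = 3 and 6 % 2 = 0. NumPy also provides the inverse, np.ravel_multi_index, which turns (3, 0) back into 6. The manual divmod line below is for illustration only and assumes a 2D, C-ordered shape:

```python
import numpy as np

state = np.array([[8, 5],
                  [3, 2],
                  [5, 1],
                  [9, 6]])

flat_idx = state.argmax()                 # 6
_, n_cols = state.shape                   # shape is (4, 2)

# Manual 2D unravel for row-major (C) order: illustration only
row, col = divmod(int(flat_idx), n_cols)  # (3, 0)

# NumPy's versions: flat -> (row, col), then back to flat
nd_idx = np.unravel_index(flat_idx, state.shape)
back = np.ravel_multi_index(nd_idx, state.shape)

print(row, col)                        # 3 0
print(tuple(int(i) for i in nd_idx))   # (3, 0)
print(int(back))                       # 6
```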
One-Liner Pattern
# Get the 2D position of the maximum value in one line
row, col = np.unravel_index(state.argmax(), state.shape)
print(f"Max value {state[row, col]} at row={row}, col={col}")
# => Max value 9 at row=3, col=0
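The same one-liner works for any number of dimensions, because unravel_index returns one index per axis. Here is a minimal sketch that wraps it in a hypothetical helper, argmax_position. It is not a NumPy function. The example runs it on a 3D array:

```python
import numpy as np

def argmax_position(arr):
    """Hypothetical helper: N-D index tuple of the (first) maximum."""
    arr = np.asarray(arr)
    return tuple(int(i) for i in np.unravel_index(arr.argmax(), arr.shape))

cube = np.arange(24).reshape(2, 3, 4)  # values 0..23
cube[0, 1, 2] = 99                     # plant a new maximum

pos = argmax_position(cube)
print(pos)        # (0, 1, 2)
print(cube[pos])  # 99
```

Converting to plain int inside the helper keeps the printed tuple readable on NumPy 2.x.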
argmax() Along an Axis
You can find the argmax along a specific axis without flattening:
scores = np.array([[85, 92, 78],
                   [90, 88, 95],
                   [70, 85, 80]])
# Index of max value in each column (axis=0)
print(np.argmax(scores, axis=0)) # => [1 0 1]
# Column 0: max is 90 at row 1
# Column 1: max is 92 at row 0
# Column 2: max is 95 at row 1
# Index of max value in each row (axis=1)
print(np.argmax(scores, axis=1)) # => [1 2 1]
# Row 0: max is 92 at col 1
# Row 1: max is 95 at col 2
# Row 2: max is 85 at col 1
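Axis-wise argmax returns indices, not values. To pull the matching values back out, you can pass the indices to np.take_along_axis. That function expects the index array to have the same number of dimensions as the data, which is why np.expand_dims appears below. For the row-wise case, fancy indexing with np.arange does the same job. Either way, the result should match scores.max(axis=...):

```python
import numpy as np

scores = np.array([[85, 92, 78],
                   [90, 88, 95],
                   [70, 85, 80]])

# Row index of the best score in each column
col_best_rows = np.argmax(scores, axis=0)            # [1 0 1]

# take_along_axis needs an index array with the same ndim as scores
idx = np.expand_dims(col_best_rows, axis=0)           # shape (1, 3)
col_best_values = np.take_along_axis(scores, idx, axis=0)[0]
print(col_best_values)                                # [90 92 95]
print(np.array_equal(col_best_values, scores.max(axis=0)))  # True

# Column index of the best score in each row, read back via fancy indexing
row_best_cols = np.argmax(scores, axis=1)             # [1 2 1]
row_best_values = scores[np.arange(scores.shape[0]), row_best_cols]
print(row_best_values)                                # [92 95 85]
```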
Finding Top-N Positions
To find the positions of the top N values:
a = np.array([[3, 8, 1],
              [7, 2, 9],
              [4, 6, 5]])
# Flatten, get indices of top 3 values
flat_top3 = np.argsort(a.flatten())[-3:][::-1]
top3_positions = [np.unravel_index(i, a.shape) for i in flat_top3]
print(top3_positions)
# => [(1, 2), (0, 1), (1, 0)], the positions of 9, 8, 7 (NumPy 2.x may show the entries as np.int64)
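For large arrays, np.argpartition is usually faster because it only partitions around the k-th largest element instead of sorting everything. The trade-off is that the three returned indices come back in no guaranteed order: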
# Faster for large arrays: argpartition doesn't fully sort
flat_top3 = np.argpartition(a.flatten(), -3)[-3:]
top3_positions = [np.unravel_index(i, a.shape) for i in flat_top3]
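np.argpartition only guarantees that the last k indices belong to the k largest values, so their order is unspecified. If you need the positions ranked, you can sort just those k candidates afterwards, which is still cheaper than sorting the whole array. Below is a sketch of that approach. top_k_positions is a hypothetical helper name, and it assumes k is at most the array size:

```python
import numpy as np

def top_k_positions(arr, k):
    """Hypothetical helper: N-D positions of the k largest values, largest first."""
    flat = arr.ravel()
    # Unordered indices of the k largest values (partial sort only)
    part = np.argpartition(flat, -k)[-k:]
    # Rank only those k candidates, in descending order of value
    ranked = part[np.argsort(flat[part])[::-1]]
    return [tuple(int(i) for i in np.unravel_index(j, arr.shape)) for j in ranked]

a = np.array([[3, 8, 1],
              [7, 2, 9],
              [4, 6, 5]])

print(top_k_positions(a, 3))  # [(1, 2), (0, 1), (1, 0)]
```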
Practical Example: Reinforcement Learning Q-Table
A common use case is finding the best action in a Q-table (used in reinforcement learning):
# Q-table: rows = states, columns = actions
q_table = np.array([
[0.1, 0.5, 0.3, 0.8], # state 0: best action is 3 (value 0.8)
[0.9, 0.2, 0.7, 0.4], # state 1: best action is 0 (value 0.9)
[0.3, 0.6, 0.1, 0.5], # state 2: best action is 1 (value 0.6)
])
# Best action for each state
best_actions = np.argmax(q_table, axis=1)
print(best_actions) # => [3 0 1]
# Best action for a specific state
current_state = 1
best_action = np.argmax(q_table[current_state])
print(f"State {current_state}: best action = {best_action}")
# => State 1: best action = 0
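The argmax indices can also be used to read the matching Q-values back out. Under a greedy policy, the value of each state is its largest entry, q_table.max(axis=1). Pairing each row index (np.arange) with its best action through fancy indexing should return the same numbers. Here is a small sketch on the table above:

```python
import numpy as np

q_table = np.array([
    [0.1, 0.5, 0.3, 0.8],
    [0.9, 0.2, 0.7, 0.4],
    [0.3, 0.6, 0.1, 0.5],
])

best_actions = np.argmax(q_table, axis=1)    # [3 0 1]

# Pair each state (row) with its best action (column)
states = np.arange(q_table.shape[0])
best_values = q_table[states, best_actions]

print(best_values)                                    # [0.8 0.9 0.6]
print(np.allclose(best_values, q_table.max(axis=1)))  # True
```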
Handling Ties
When multiple elements share the maximum value, argmax() returns the index of the first occurrence:
a = np.array([5, 9, 3, 9, 1])
print(np.argmax(a)) # => 1 (first occurrence of 9)
# Find ALL positions of the maximum
max_val = a.max()
all_max_positions = np.where(a == max_val)[0]
print(all_max_positions) # => [1 3]
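np.where(...)[0] covers the 1D case. For a 2D array, np.argwhere returns one (row, col) pair for each position that holds the maximum. Sometimes you may prefer to break ties at random instead of always taking the first index. For example, a Q-table that starts out as all zeros would otherwise always pick action 0. In that case you can sample among the tied indices. The sketch below uses a hypothetical random_argmax helper:

```python
import numpy as np

def random_argmax(values, rng):
    """Hypothetical helper: index of a maximum, chosen uniformly among ties."""
    values = np.asarray(values)
    tied = np.flatnonzero(values == values.max())
    return int(rng.choice(tied))

grid = np.array([[5, 9, 3],
                 [9, 1, 9]])

# Every (row, col) position holding the maximum, one row per position
print(np.argwhere(grid == grid.max()))  # rows: [0 1], [1 0], [1 2]

rng = np.random.default_rng(0)
a = np.array([5, 9, 3, 9, 1])
print(random_argmax(a, rng))            # 1 or 3, depending on the draw
```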
Summary
| Function | Purpose |
|---|---|
| np.argmax(a) | Flat index of the max in the entire array |
| np.argmax(a, axis=0) | Index of the max in each column (2D arrays) |
| np.argmax(a, axis=1) | Index of the max in each row (2D arrays) |
| np.unravel_index(idx, shape) | Convert a flat index to an N-D index |
| np.where(a == a.max()) | All positions of the maximum value |