Introduction
Finding the maximum value in a NumPy array is straightforward with np.max(). But finding the position (index) of that maximum, especially in multi-dimensional arrays, requires argmax() and unravel_index(). This guide covers both functions with practical examples.
argmax() in 1D Arrays
For a 1D array, argmax() returns the index of the largest element:
import numpy as np
a = np.array([3, 7, 1, 9, 2, 5])
print(np.argmax(a)) # => 3 (index of value 9)
print(a[np.argmax(a)]) # => 9
argmax() in 2D Arrays
For multi-dimensional arrays, argmax() by default flattens the array and returns the index in the flattened version:
state = np.array([[8, 5],
                  [3, 2],
                  [5, 1],
                  [9, 6]])
print(state.argmax()) # => 6
# Flattened: [8, 5, 3, 2, 5, 1, 9, 6]
# indices:    0  1  2  3  4  5  6  7
# Value 9 is at flat index 6
That flat index 6 isn’t very useful on its own. To convert it back to a 2D index, use np.unravel_index().
unravel_index: Convert Flat Index to Multi-Dimensional
np.unravel_index(flat_index, shape) converts a flat index into a tuple of indices for each dimension:
state = np.array([[8, 5],
                  [3, 2],
                  [5, 1],
                  [9, 6]])
max_flat_idx = state.argmax() # => 6
max_2d_idx = np.unravel_index(max_flat_idx, state.shape)
print(max_2d_idx) # => (3, 0); NumPy 2.x typically shows (np.int64(3), np.int64(0))
print(state[3, 0]) # => 9 (confirmed)
print(state[max_2d_idx]) # => 9
The result (3, 0) means row 3, column 0, which is indeed where 9 lives.
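To see what unravel_index computes for a 2D array in the default row-major (C) order, the same result can be sketched by hand with divmod: the row is the flat index divided by the number of columns, and the column is the remainder. This is an illustration only; the variable names are chosen for this example.

```python
import numpy as np

state = np.array([[8, 5],
                  [3, 2],
                  [5, 1],
                  [9, 6]])

flat_idx = int(state.argmax())      # 6
n_rows, n_cols = state.shape        # (4, 2)

# Row-major layout: flat index = row * n_cols + col
row, col = divmod(flat_idx, n_cols)
print(row, col)                     # => 3 0

# Same answer as the library call
lib_row, lib_col = np.unravel_index(flat_idx, state.shape)
assert (row, col) == (int(lib_row), int(lib_col))
```

In practice np.unravel_index is preferable, since it handles any number of dimensions without manual arithmetic.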
One-Liner Pattern
# Get the 2D position of the maximum value in one line
row, col = np.unravel_index(state.argmax(), state.shape)
print(f"Max value {state[row, col]} at row={row}, col={col}")
# => Max value 9 at row=3, col=0
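The same pattern extends to any number of dimensions, because unravel_index returns one index per axis. A minimal sketch with a 3D array; the helper name argmax_nd and the sample data are made up for illustration:

```python
import numpy as np

def argmax_nd(arr):
    """Return the N-D index of the first maximum as a tuple of plain ints."""
    idx = np.unravel_index(np.argmax(arr), arr.shape)
    return tuple(int(i) for i in idx)

# Hypothetical 3D data: 2 blocks, each 2 rows x 3 columns
cube = np.array([[[1, 4, 2],
                  [0, 3, 5]],
                 [[7, 6, 9],
                  [8, 2, 1]]])

pos = argmax_nd(cube)
print(pos)         # => (1, 0, 2)
print(cube[pos])   # => 9
```

Converting to plain ints is optional; it just keeps the printed tuple free of NumPy scalar types.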
argmax() Along an Axis
You can find the argmax along a specific axis without flattening:
scores = np.array([[85, 92, 78],
                   [90, 88, 95],
                   [70, 85, 80]])
# Index of max value in each column (axis=0)
print(np.argmax(scores, axis=0)) # => [1 0 1]
# Column 0: max is 90 at row 1
# Column 1: max is 92 at row 0
# Column 2: max is 95 at row 1
# Index of max value in each row (axis=1)
print(np.argmax(scores, axis=1)) # => [1 2 1]
# Row 0: max is 92 at col 1
# Row 1: max is 95 at col 2
# Row 2: max is 85 at col 1
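Once you have per-row indices, you often want the values they point to. One way to fetch them is np.take_along_axis, which needs the index array to have the same number of dimensions as the data. This is a sketch; np.max(scores, axis=1) gives the same values directly when the indices themselves are not needed.

```python
import numpy as np

scores = np.array([[85, 92, 78],
                   [90, 88, 95],
                   [70, 85, 80]])

best_cols = np.argmax(scores, axis=1)            # [1 2 1]

# take_along_axis expects indices shaped (3, 1) here, so add an axis
best_vals = np.take_along_axis(scores, best_cols[:, None], axis=1).ravel()
print(best_vals)                                 # => [92 95 85]

# Equivalent fancy-indexing form: pair each row number with its column
same_vals = scores[np.arange(scores.shape[0]), best_cols]
assert np.array_equal(best_vals, same_vals)
assert np.array_equal(best_vals, scores.max(axis=1))
```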
Finding Top-N Positions
To find the positions of the top N values:
a = np.array([[3, 8, 1],
              [7, 2, 9],
              [4, 6, 5]])
# Flatten, get indices of top 3 values
flat_top3 = np.argsort(a.flatten())[-3:][::-1]
top3_positions = [np.unravel_index(i, a.shape) for i in flat_top3]
print(top3_positions)
# => [(1, 2), (0, 1), (1, 0)] (positions of 9, 8, 7)
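Or, for large arrays, use np.argpartition, which avoids a full sort. The three indices it returns are the top three, but not in any guaranteed order: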
# Faster for large arrays: argpartition doesn't fully sort
flat_top3 = np.argpartition(a.flatten(), -3)[-3:]
top3_positions = [np.unravel_index(i, a.shape) for i in flat_top3]
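Because argpartition only guarantees that the last N entries are the N largest, not that they are ordered, a follow-up sort of just those N entries is needed when rank matters. A minimal sketch, assuming a hypothetical helper named top_n_positions:

```python
import numpy as np

def top_n_positions(arr, n):
    """Return N-D positions of the n largest values, largest first."""
    flat = arr.ravel()
    # Indices of the n largest values, in arbitrary order
    candidates = np.argpartition(flat, -n)[-n:]
    # Sort only those n candidates by value, descending
    ordered = candidates[np.argsort(flat[candidates])[::-1]]
    # unravel_index accepts an array of flat indices and returns one
    # index array per dimension; zip them into (row, col, ...) tuples
    coords = np.unravel_index(ordered, arr.shape)
    return [tuple(int(i) for i in pos) for pos in zip(*coords)]

a = np.array([[3, 8, 1],
              [7, 2, 9],
              [4, 6, 5]])

print(top_n_positions(a, 3))   # => [(1, 2), (0, 1), (1, 0)]
```

The sort touches only n elements, so the overall cost stays close to that of argpartition alone when n is small relative to the array size.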
Practical Example: Reinforcement Learning Q-Table
A common use case is finding the best action in a Q-table (used in reinforcement learning):
# Q-table: rows = states, columns = actions
q_table = np.array([
    [0.1, 0.5, 0.3, 0.8], # state 0: best action is 3 (value 0.8)
    [0.9, 0.2, 0.7, 0.4], # state 1: best action is 0 (value 0.9)
    [0.3, 0.6, 0.1, 0.5], # state 2: best action is 1 (value 0.6)
])
# Best action for each state
best_actions = np.argmax(q_table, axis=1)
print(best_actions) # => [3 0 1]
# Best action for a specific state
current_state = 1
best_action = np.argmax(q_table[current_state])
print(f"State {current_state}: best action = {best_action}")
# => State 1: best action = 0
Handling Ties
When multiple elements share the maximum value, argmax() returns the index of the first occurrence:
a = np.array([5, 9, 3, 9, 1])
print(np.argmax(a)) # => 1 (first occurrence of 9)
# Find ALL positions of the maximum
max_val = a.max()
all_max_positions = np.where(a == max_val)[0]
print(all_max_positions) # => [1 3]
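For multi-dimensional arrays the same comparison works, but np.where returns one index array per axis. np.argwhere is a convenient alternative that returns one row per matching position. A small sketch with made-up data:

```python
import numpy as np

grid = np.array([[4, 9, 1],
                 [9, 2, 9]])

# argmax still reports only the first maximum (in flattened order)
first = np.unravel_index(grid.argmax(), grid.shape)
print(tuple(int(i) for i in first))      # => (0, 1)

# argwhere lists every position holding the maximum, one per row
all_max = np.argwhere(grid == grid.max())
print(all_max.tolist())                  # => [[0, 1], [1, 0], [1, 2]]
```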
Summary
| Function | Purpose |
|---|---|
| np.argmax(a) | Flat index of the max in the entire array |
| np.argmax(a, axis=0) | Row index of the max in each column (2D) |
| np.argmax(a, axis=1) | Column index of the max in each row (2D) |
| np.unravel_index(idx, shape) | Convert a flat index to an N-D index |
| np.where(a == a.max()) | All positions of the maximum value |