
Slice multidimensional numpy array from max in a given axis

I have a 3-dimensional array a of shape (n, m, l). I extract one column j from its last axis and compute the index of the maximum along the first axis as follows:

import numpy as np
sub = a[:, :, j]  # shape (n, m)
wheremax = np.argmax(sub, axis=0)  # shape (m,): row index of the max in each column
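For a concrete picture of what wheremax holds, here is a small sketch on a toy array. The values, the shape (n, m, l) = (3, 2, 2) and j = 0 are arbitrary assumptions chosen only for illustration.

```python
import numpy as np

# Toy array (an assumption for illustration): n=3, m=2, l=2
a = np.array([
    [[1, 10], [5, 20]],
    [[7, 30], [2, 40]],
    [[3, 50], [9, 60]],
])
j = 0

sub = a[:, :, j]                   # shape (3, 2): [[1, 5], [7, 2], [3, 9]]
wheremax = np.argmax(sub, axis=0)  # row of the maximum in each column of sub
print(sub.shape, wheremax)         # (3, 2) [1 2]
```

Column 0 of sub is [1, 7, 3], whose maximum sits in row 1; column 1 is [5, 2, 9], whose maximum sits in row 2.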

Now I’d like to slice the original array a so that, for each position i along the second axis, I get the full last-axis vector a[idx, i, :] taken at the row idx where column j is maximal. That is, I’d like a numpythonic way to do the following, using array broadcasting or NumPy functions:

new_arr = np.zeros((m, l), dtype=a.dtype)  # keep a's dtype
for i, idx in enumerate(wheremax):  # idx = row of the max for column i
    new_arr[i, :] = a[idx, i, :]
a = new_arr

Is there one?


Answer

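As @hpaulj mentioned in the comments, using a[wheremax, np.arange(m)] did the trick. The two integer index arrays wheremax and np.arange(m) both have shape (m,), so they are broadcast together and element i selects a[wheremax[i], i, :]; the result has shape (m, l), exactly what the loop builds.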
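A quick way to check that the one-liner matches the loop is to run both on random data. The sizes n, m, l = 4, 5, 3, the seed and j = 1 are arbitrary assumptions. The np.take_along_axis variant is an alternative added here for comparison; it is not part of the original answer.

```python
import numpy as np

# Arbitrary test sizes and column (assumptions for illustration)
rng = np.random.default_rng(0)
n, m, l = 4, 5, 3
a = rng.random((n, m, l))
j = 1

wheremax = np.argmax(a[:, :, j], axis=0)  # shape (m,)

# Loop version from the question
loop = np.zeros((m, l))
for i, idx in enumerate(wheremax):
    loop[i, :] = a[idx, i, :]

# Advanced indexing: wheremax and np.arange(m) broadcast together,
# so row i of the result is a[wheremax[i], i, :]
fancy = a[wheremax, np.arange(m)]  # shape (m, l)

# Alternative (not from the answer): indices reshaped to (1, m, 1)
# broadcast against a's other axes, giving shape (1, m, l)
taken = np.take_along_axis(a, wheremax[None, :, None], axis=0)[0]

print(fancy.shape)                                               # (5, 3)
print(np.array_equal(fancy, loop), np.array_equal(taken, loop))  # True True
```

The fancy-indexing form is the shortest; np.take_along_axis is handy when the axis you took argmax over is a variable, since only the axis argument and the reshape of the indices change.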

User contributions licensed under: CC BY-SA