I use numpy.argsort all the time for 1D data, but it seems to behave differently in 2D.
For example, let's say I want to argsort this array along axis 1 so that the items in each row are in ascending order:
>>> import numpy as np
>>> arr = np.eye(4)
>>> arr
array([[1., 0., 0., 0.],
       [0., 1., 0., 0.],
       [0., 0., 1., 0.],
       [0., 0., 0., 1.]])
>>> idx = np.argsort(arr, axis=1)
>>> idx
array([[1, 2, 3, 0],
       [0, 2, 3, 1],
       [0, 1, 3, 2],
       [0, 1, 2, 3]])
All fine so far.
Each row of idx gives the order in which the columns of the corresponding row of a second array should be rearranged.
Let's say we want to reorder the array below using idx.
>>> arr2 = np.arange(16).reshape((4, 4))
>>> arr2
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11],
       [12, 13, 14, 15]])
>>> sorted = arr2[idx]
>>> sorted
array([[[ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15],
        [ 0,  1,  2,  3]],
       ...,
       [[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]]])
>>> sorted.shape
(4, 4, 4)
The result now has an extra dimension, because each entry of idx was used to select a whole row of arr2.
I was expecting to get:
array([[ 1,  2,  3,  0],
       [ 4,  6,  7,  5],
       [ 8,  9, 11, 10],
       [12, 13, 14, 15]])
I can do this by iterating over the rows, which is bad!
>>> rows = []
>>> for i, row in enumerate(arr2):
...     rows.append(row[idx[i]])
...
>>> np.array(rows)
array([[ 1,  2,  3,  0],
       [ 4,  6,  7,  5],
       [ 8,  9, 11, 10],
       [12, 13, 14, 15]])
Answer
The documentation for np.take_along_axis has an example using argsort:
>>> a = np.array([[10, 30, 20], [60, 40, 50]])

We can sort either by using sort directly, or argsort and this function

>>> np.sort(a, axis=1)
array([[10, 20, 30],
       [40, 50, 60]])
>>> ai = np.argsort(a, axis=1); ai
array([[0, 2, 1],
       [1, 2, 0]])
>>> np.take_along_axis(a, ai, axis=1)
array([[10, 20, 30],
       [40, 50, 60]])
This streamlines the process of applying ai to the array itself. We can do that directly, but it requires a bit more thought about what the index actually represents.
In this example, ai holds index values along axis 1 (values like 0, 1, or 2). Since ai has shape (2, 3), it has to broadcast against a (2, 1) array of row indices for axis 0:
In [247]: a[np.arange(2)[:, None], ai]
Out[247]:
array([[10, 20, 30],
       [40, 50, 60]])
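Applied to the arrays from the question, both approaches should give the expected (4, 4) result. This is a minimal sketch rather than the only way to do it: the names by_take, by_index, and row_idx are mine, and kind="stable" is an added assumption so that the tied zeros in each row of np.eye(4) are ordered deterministically. np.take_along_axis needs NumPy 1.15 or later.

```python
import numpy as np

arr = np.eye(4)
arr2 = np.arange(16).reshape(4, 4)

# kind="stable" (an assumption, not in the question) keeps tied values
# in their original order, so idx matches the output shown in the question.
idx = np.argsort(arr, axis=1, kind="stable")

# Option 1: let NumPy pair each row of idx with the same row of arr2.
by_take = np.take_along_axis(arr2, idx, axis=1)

# Option 2: build the row indices by hand. row_idx has shape (4, 1) and
# broadcasts against idx (4, 4), so element [i, j] is arr2[i, idx[i, j]].
row_idx = np.arange(arr2.shape[0])[:, None]
by_index = arr2[row_idx, idx]

print(by_take)
print(np.array_equal(by_take, by_index))

# For comparison, plain arr2[idx] treats every entry of idx as a row
# selector, which is where the extra dimension comes from.
print(arr2[idx].shape)
```

The loop over rows in the question computes the same thing as Option 2, just one row at a time.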