
How to get a NumPy array of the first/last n True values in each row

I have a broadcast array whose rows are sorted, and a boolean mask array. I want to get the last n elements (or the first n) of each row where the mask is True, i.e.:

import numpy as np

a = np.array([[0.00298, 0.00455, 0.00767, 0.00939, 0.01104, 0.02351, 0.03370],
              [0.00298, 0.00455, 0.00767, 0.00939, 0.01104, 0.02351, 0.03370],
              [0.00298, 0.00455, 0.00767, 0.00939, 0.01104, 0.02351, 0.03370]])

mask = np.array([[1, 0, 0, 1, 1, 0, 1], [0, 1, 0, 0, 0, 1, 1], [0, 0, 1, 1, 0, 0, 0]], dtype=bool)

# a[mask] --> [0.00298  0.00939  0.01104  0.0337  0.00455  0.02351  0.0337  0.00767  0.00939]

# needed last  two --> [[0.01104  0.0337 ]  [0.02351  0.0337 ]  [0.00767  0.00939]]
# needed first two --> [[0.00298  0.00939]  [0.00455  0.02351]  [0.00767  0.00939]]

Do we have to split the array (using np.cumsum(np.sum(mask, axis=1))), pad and …?
What is the best way to do this using only NumPy?
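For comparison, the split-based route mentioned in the question can be sketched as follows. This is an illustrative baseline, not the answer's method: it uses a Python-level list comprehension over rows, and rows with fewer than n True values simply come out shorter (no padding). The names `counts`, `bounds` and `rows` are hypothetical.

```python
import numpy as np

a = np.array([[0.00298, 0.00455, 0.00767, 0.00939, 0.01104, 0.02351, 0.03370]] * 3)
mask = np.array([[1, 0, 0, 1, 1, 0, 1],
                 [0, 1, 0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 0, 0, 0]], dtype=bool)
n = 2

# Number of True values per row, and the offsets where each row's
# selected values end inside the flat a[mask] result
counts = mask.sum(axis=1)
bounds = np.cumsum(counts)[:-1]

# Split the flat selection back into one array per row
rows = np.split(a[mask], bounds)

# Slice each row's selection (a per-row loop, so not fully vectorized)
first_n = [r[:n] for r in rows]
last_n = [r[-n:] for r in rows]
```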


Answer

Using NumPy, to get the first n True values in each row:

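n = 2
# cumsum counts the Trues seen so far in each row; keep the True positions
# whose running count is at most n, i.e. the first n Trues of every row
a[(np.cumsum(mask, axis=1) <= n) & mask].reshape(-1, n)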

Output:

array([[0.00298, 0.00939],
       [0.00455, 0.02351],
       [0.00767, 0.00939]])
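To see why this works, here are the intermediate arrays for the example mask. Boolean indexing reads the selected elements in row-major order, so once each row has exactly n True positions in `keep`, the flat result can be reshaped to `(-1, n)`. The names `counts` and `keep` are only for illustration.

```python
import numpy as np

mask = np.array([[1, 0, 0, 1, 1, 0, 1],
                 [0, 1, 0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 0, 0, 0]], dtype=bool)
n = 2

# Running count of True values seen so far in each row
counts = np.cumsum(mask, axis=1)
# [[1 1 1 2 3 3 4]
#  [0 1 1 1 1 2 3]
#  [0 0 1 2 2 2 2]]

# Keep a position only if it is True AND among the first n Trues of its row
keep = (counts <= n) & mask
# [[ True False False  True False False False]
#  [False  True False False False  True False]
#  [False False  True  True False False False]]
```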

To get the last n True values in each row:

n = 2
# Reverse the columns so cumsum counts Trues from the right, then flip back
a[(np.cumsum(mask[:, ::-1], axis=1) <= n)[:, ::-1] & mask].reshape(-1, n)

Output:

array([[0.01104, 0.0337 ],
       [0.02351, 0.0337 ],
       [0.00767, 0.00939]])
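Both one-liners can be wrapped in a small helper that also checks the row counts up front, so a row with too few True values fails loudly instead of producing a misaligned reshape. This is a sketch; the function name `n_true_per_row` and its `last` flag are hypothetical, not part of NumPy.

```python
import numpy as np

def n_true_per_row(a, mask, n, last=False):
    """Return the first (or last) n values of `a` per row where `mask` is True.

    Hypothetical helper wrapping the cumsum trick from the answer.
    Raises ValueError if any row has fewer than n True values.
    """
    mask = np.asarray(mask, dtype=bool)
    if (mask.sum(axis=1) < n).any():
        raise ValueError("every row of mask needs at least n True values")
    if last:
        # Count Trues from the right by reversing the columns, then flip back
        keep = (np.cumsum(mask[:, ::-1], axis=1) <= n)[:, ::-1] & mask
    else:
        # Count Trues from the left
        keep = (np.cumsum(mask, axis=1) <= n) & mask
    # Boolean indexing yields values row by row, and each row contributes exactly n
    return np.asarray(a)[keep].reshape(-1, n)
```

Usage: `n_true_per_row(a, mask, 2)` gives the first two values per row and `n_true_per_row(a, mask, 2, last=True)` the last two.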

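NB: every row must contain at least n True values. Otherwise that row contributes fewer than n values, and reshape(-1, n) either raises a ValueError or silently mixes values from different rows.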
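If some rows may have fewer than n True values, one option (an assumption, not part of the original answer) is to pad the result, for example with NaN. The sketch below computes each True value's 0-based position within its row and scatters the first n into a pre-filled `(rows, n)` array. The name `first_n_true_padded` and the NaN fill are hypothetical choices; `a` is cast to float so NaN can be stored.

```python
import numpy as np

def first_n_true_padded(a, mask, n, fill=np.nan):
    """First n values of `a` per row where `mask` is True, padded with `fill`."""
    mask = np.asarray(mask, dtype=bool)
    a = np.asarray(a, dtype=float)
    # 0-based position of each True within its row (meaningless where mask is False)
    rank = np.cumsum(mask, axis=1) - 1
    keep = mask & (rank < n)
    rows, cols = np.nonzero(keep)
    # Start from an all-fill output and drop each kept value into its slot
    out = np.full((mask.shape[0], n), fill, dtype=float)
    out[rows, rank[rows, cols]] = a[rows, cols]
    return out

def last_n_true_padded(a, mask, n, fill=np.nan):
    """Last n values per row, right-aligned, with padding on the left."""
    # Reverse the columns, take the first n, then reverse the result back
    a = np.asarray(a)
    mask = np.asarray(mask, dtype=bool)
    return first_n_true_padded(a[:, ::-1], mask[:, ::-1], n, fill)[:, ::-1]
```

With the question's `a` and `mask` and n = 3, the third row has only two True values, so it gets one NaN slot (at the end for the first-n version, at the start for the last-n version).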

User contributions licensed under: CC BY-SA