I need to keep only the N largest values (here N = 3) in each row of an array.
a=np.array([[1,2,3,4],[8,7,6,5],[5,3,1,2]])
a
Out[135]:
array([[1, 2, 3, 4],
       [8, 7, 6, 5],
       [5, 3, 1, 2]])
The indices of those values can be found with np.argpartition:
n=3
np.argpartition(a, -n, axis=1)[:,-n:]
Out[136]:
array([[1, 2, 3],
       [2, 1, 0],
       [3, 0, 1]], dtype=int64)
So my question is: how can I keep the values at those indices and set all the others to zero, to get:
array([[0, 2, 3, 4],
       [8, 7, 6, 0],
       [5, 3, 0, 2]])
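One way is to build a boolean mask of the columns that argpartition places before the last n slots of each row, and zero them:
import numpy as np
a = np.array([[1, 2, 3, 4], [8, 7, 6, 5], [5, 3, 1, 2]])
n = 3
# argsort inverts each row's argpartition permutation, giving every column's slot
# in the partition; slots below a.shape[1] - n hold the values to drop
mask = np.argsort(np.argpartition(a, -n, axis=1), axis=1) < a.shape[1] - n
a[mask] = 0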
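An alternative that avoids building a mask is to copy only the kept values into a zero array with np.put_along_axis and np.take_along_axis (both NumPy 1.15+). This is a minimal sketch, not a library function: keep_top_n is a hypothetical helper name, and it returns a new array instead of modifying a in place. With ties, exactly n values per row are kept, chosen by argpartition.

```python
import numpy as np

def keep_top_n(a: np.ndarray, n: int) -> np.ndarray:
    """Hypothetical helper: return a copy of a where everything except the
    n largest values in each row is set to 0."""
    if not 1 <= n <= a.shape[1]:
        raise ValueError("n must be between 1 and the number of columns")
    # Column indices of the n largest values per row (order is unspecified)
    top = np.argpartition(a, -n, axis=1)[:, -n:]
    # Start from zeros with the same shape and dtype as a
    out = np.zeros_like(a)
    # Write the kept values back into their original columns
    np.put_along_axis(out, top, np.take_along_axis(a, top, axis=1), axis=1)
    return out

a = np.array([[1, 2, 3, 4], [8, 7, 6, 5], [5, 3, 1, 2]])
print(keep_top_n(a, 3))
# [[0 2 3 4]
#  [8 7 6 0]
#  [5 3 0 2]]
```

Returning a copy leaves the input untouched, which is convenient when the original values are still needed; the masked assignment above is the in-place variant.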