
Sort numpy matrix based on index-matrix

In numpy, we can reorder an array with an index array like this:

>>> import numpy as np

>>> a = np.array([0, 100, 200])
>>> order = np.array([1, 2, 0])
>>> print(a[order])
[100 200   0]

However, this does not work when the "order" is a matrix:

>>> A = np.array([    [0, 1, 2],
                      [3, 4, 5],
                      [6, 7, 8]])

>>> Ord = np.array([  [1, 0, 2],
                      [0, 2, 1],
                      [2, 1, 0]])

>>> print(A[Ord].shape)
(3, 3, 3)

I would like to have "A" sorted like this:

array([[1, 0, 2],
       [3, 5, 4],
       [8, 7, 6]])

Felip asked Sep 05 '25 03:09


1 Answer

You could use np.take_along_axis for this.

np.take_along_axis(A, Ord, axis=1)

Output

array([[1, 0, 2],
       [3, 5, 4],
       [8, 7, 6]])
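
For comparison, the same result can also be reached with plain integer-array indexing, by pairing every column index in Ord with its row index. This is only a minimal sketch; the names `rows` and `result` are introduced here for illustration and are not part of the question.

```python
import numpy as np

A = np.array([[0, 1, 2],
              [3, 4, 5],
              [6, 7, 8]])
Ord = np.array([[1, 0, 2],
                [0, 2, 1],
                [2, 1, 0]])

# Column vector of row indices, shape (3, 1); it broadcasts against Ord (3, 3).
rows = np.arange(A.shape[0])[:, None]

# result[i, j] == A[i, Ord[i, j]]
result = A[rows, Ord]
print(result)
```

np.take_along_axis builds this kind of broadcast index for you, which is why it is usually the more readable choice.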

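As stated in the documentation, np.take_along_axis is often used together with functions that produce indices, such as argsort. It also generalizes to more than 2 dimensions: the index array must have the same number of dimensions as the input, and the indices are applied along the chosen axis.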
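
A small check along these lines suggests that the argsort pairing also works on a 3-D array. The shape, the seed, and the names `x`, `idx` and `sorted_x` are arbitrary choices made for this sketch.

```python
import numpy as np

# Arbitrary 3-D test data; the seed and shape are assumptions for illustration.
rng = np.random.default_rng(0)
x = rng.integers(0, 100, size=(2, 3, 4))

# Indices that would sort each 1-D slice along the last axis.
idx = np.argsort(x, axis=-1)

# Apply those indices along the same axis.
sorted_x = np.take_along_axis(x, idx, axis=-1)

# Compare against sorting directly along that axis.
print(np.array_equal(sorted_x, np.sort(x, axis=-1)))
```

Here `idx` has the same shape as `x`, so no broadcasting is involved; each 1-D slice along the last axis is reordered independently.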

Kevin answered Sep 07 '25 20:09