 

Reducing an axis with numpy

Tags:

python

numpy

I have an NxMx3 numpy array with dtype=object. I also have a function f(a,b,c) which takes the three elements in the last axis of this array and returns a np.int32. My question is how do I apply f to my NxMx3 array to yield an NxM array with dtype=np.int32?

My current solution is to use

# dtype is the dtype of the index arrays i and j; np.int was removed in NumPy 1.24
newarr = np.fromfunction(lambda i,j: f(arr[i,j,0], arr[i,j,1], arr[i,j,2]),
                          arr.shape[:2], dtype=int)

although this is a little more verbose than I had hoped.
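One detail of the np.fromfunction approach is easy to miss: NumPy calls the function once, passing whole (N, M) index arrays for i and j, rather than once per (i, j). The approach therefore only works when f happens to operate elementwise on arrays. The sketch below demonstrates this; f, arr, N and M are hypothetical stand-ins, not the asker's actual data.

```python
import numpy as np

# Hypothetical stand-in for the question's f: three inputs, np.int32 output.
def f(a, b, c):
    return np.int32(a + 2 * b + 3 * c)

N, M = 2, 3
# Hypothetical object array of Python ints with shape (N, M, 3).
arr = np.arange(N * M * 3).reshape(N, M, 3).astype(object)

calls = []

def g(i, j):
    # Record what fromfunction actually passes in: whole index arrays.
    calls.append((i.shape, j.shape))
    return f(arr[i, j, 0], arr[i, j, 1], arr[i, j, 2])

# dtype here controls the index arrays i and j, not the result.
newarr = np.fromfunction(g, arr.shape[:2], dtype=int)
print(calls)  # a single call with two (N, M) index arrays
```

If f only accepts scalars (for example, it branches with an if statement), this version fails, whereas the per-element approaches in the answer still work.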

asked Jun 26 '26 03:06 by Freddie Witherden


1 Answer

You could use np.vectorize:

np.vectorize(f, otypes=[np.int32])(arr[:, :, 0], arr[:, :, 1], arr[:, :, 2])
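A self-contained version of that call is below. The function f and the array arr are hypothetical examples. np.vectorize passes f one scalar from each argument at a time. Passing otypes fixes the output dtype up front; without it, vectorize calls f on the first element to infer the output type.

```python
import numpy as np

# Hypothetical f: three scalar inputs, one np.int32 output.
def f(a, b, c):
    return np.int32(a + 2 * b + 3 * c)

N, M = 2, 3
arr = np.arange(N * M * 3).reshape(N, M, 3).astype(object)

# Wrap f so it is applied element by element across the three (N, M) slices.
vf = np.vectorize(f, otypes=[np.int32])
newarr = vf(arr[:, :, 0], arr[:, :, 1], arr[:, :, 2])
print(newarr.shape, newarr.dtype)  # (2, 3) int32
```

Note that np.vectorize is a convenience wrapper around a Python-level loop, so it does not make f itself any faster.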

The vectorize call can be shortened by rolling the last axis to the front and unpacking the result, which iterates over that first axis:

np.vectorize(f, otypes=[np.int32])(*np.rollaxis(arr, 2, 0))
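To see why the unpacking works: np.rollaxis(arr, 2, 0) has shape (3, N, M), and `*` iterates over the first axis, producing three (N, M) arrays. NumPy's documentation recommends np.moveaxis over np.rollaxis in new code, and it gives the same result here. The sketch uses a hypothetical f and arr.

```python
import numpy as np

# Hypothetical f, as in the question's description.
def f(a, b, c):
    return np.int32(a + 2 * b + 3 * c)

arr = np.arange(2 * 3 * 3).reshape(2, 3, 3).astype(object)

rolled = np.rollaxis(arr, 2, 0)   # shape (3, N, M)
moved = np.moveaxis(arr, -1, 0)   # same shape and contents

a, b, c = moved                   # three (N, M) arrays, one per position on the last axis

# The same one-liner written with moveaxis instead of rollaxis.
newarr = np.vectorize(f, otypes=[np.int32])(*np.moveaxis(arr, -1, 0))
```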

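Alternatively, you can split the array explicitly with np.dsplit. Each piece keeps a trailing length-1 axis, so the result has shape NxMx1 and that axis has to be dropped, for example by indexing: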

np.vectorize(f, otypes=[np.int32])(*np.dsplit(arr, 3))[..., 0]

or

np.vectorize(f, otypes=[np.int32])(*np.dsplit(arr, 3)).reshape(arr.shape[:-1])

or

np.vectorize(f, otypes=[np.int32])(*np.dsplit(arr, 3)).squeeze()
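The three dsplit variants give the same (N, M) result in the usual case. They differ when N or M is 1, because squeeze() with no arguments removes every length-1 axis. The sketch below shows this with a hypothetical f and arr. Passing axis=-1 to squeeze (or using indexing or reshape) removes only the trailing axis.

```python
import numpy as np

# Hypothetical f.
def f(a, b, c):
    return np.int32(a + 2 * b + 3 * c)

vf = np.vectorize(f, otypes=[np.int32])

arr = np.arange(2 * 3 * 3).reshape(2, 3, 3).astype(object)
parts = np.dsplit(arr, 3)                  # three arrays of shape (N, M, 1)
raw = vf(*parts)                           # shape (N, M, 1)

by_index = raw[..., 0]                     # (N, M)
by_reshape = raw.reshape(arr.shape[:-1])   # (N, M)
by_squeeze = raw.squeeze()                 # (N, M) here, since N and M are not 1

# With N == 1, a bare squeeze() also drops the N axis.
one_row = np.arange(1 * 3 * 3).reshape(1, 3, 3).astype(object)
squeezed = vf(*np.dsplit(one_row, 3)).squeeze()        # shape (3,)
safe = vf(*np.dsplit(one_row, 3)).squeeze(axis=-1)     # shape (1, 3)
```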

However, np.apply_along_axis is probably simpler to read, although like np.vectorize it still calls f once per (i, j) in a Python-level loop:

np.apply_along_axis(lambda x: f(*x), 2, arr)
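Here is a runnable version, again with a hypothetical f and arr. np.apply_along_axis passes each length-3 slice along axis 2 to the lambda, which unpacks it into f. Because f returns a scalar, that axis is removed from the output. The output dtype comes from the first result, so it is np.int32 here without any extra argument.

```python
import numpy as np

# Hypothetical f.
def f(a, b, c):
    return np.int32(a + 2 * b + 3 * c)

arr = np.arange(2 * 3 * 3).reshape(2, 3, 3).astype(object)

# Each x is a 1-D object array of length 3, unpacked into a, b, c.
newarr = np.apply_along_axis(lambda x: f(*x), 2, arr)
print(newarr.shape, newarr.dtype)  # (2, 3) int32
```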
answered Jun 27 '26 16:06 by ecatmur


