 

Cross products with einsum

I'm trying to compute the cross-products of many 3x1 vector pairs as fast as possible. This

import numpy as np

n = 10000
a = np.random.rand(n, 3)
b = np.random.rand(n, 3)
np.cross(a, b)

gives the correct answer, but motivated by this answer to a similar question, I thought that einsum would get me somewhere. I found that both

eijk = np.zeros((3, 3, 3))
eijk[0, 1, 2] = eijk[1, 2, 0] = eijk[2, 0, 1] = 1
eijk[0, 2, 1] = eijk[2, 1, 0] = eijk[1, 0, 2] = -1

np.einsum('ijk,aj,ak->ai', eijk, a, b)
np.einsum('iak,ak->ai', np.einsum('ijk,aj->iak', eijk, a), b)

compute the cross product, but their performance is disappointing: both are more than ten times slower than np.cross:

%timeit np.cross(a, b)
1000 loops, best of 3: 628 µs per loop
%timeit np.einsum('ijk,aj,ak->ai', eijk, a, b)
100 loops, best of 3: 9.02 ms per loop
%timeit np.einsum('iak,ak->ai', np.einsum('ijk,aj->iak', eijk, a), b)
100 loops, best of 3: 10.6 ms per loop

Any ideas of how to improve the einsums?

Nico Schlömer asked Oct 16 '25 16:10



1 Answer

einsum() performs more multiplications than cross(), and in the newest NumPy version cross() no longer creates many temporary arrays. So einsum() can't be faster than cross().
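As a rough illustration of the multiplication count, here is a small sketch. It assumes a naive contraction that visits every entry of eijk and spends two multiplications on each; what einsum() actually does internally may differ, so treat the numbers as an estimate rather than a measurement.

```python
import numpy as np

# Levi-Civita symbol, as defined in the question.
eijk = np.zeros((3, 3, 3))
eijk[0, 1, 2] = eijk[1, 2, 0] = eijk[2, 0, 1] = 1
eijk[0, 2, 1] = eijk[2, 1, 0] = eijk[1, 0, 2] = -1

# Assumption: a naive contraction visits every (i, j, k) entry and does
# two multiplications per entry (eijk * a_j, then * b_k).
total_terms = eijk.size                      # all entries, zeros included
nonzero_terms = int(np.count_nonzero(eijk))  # entries that actually contribute
naive_einsum_mults = 2 * total_terms

# The explicit formula needs two products per output component.
explicit_cross_mults = 2 * 3

print(total_terms, nonzero_terms, naive_einsum_mults, explicit_cross_mults)
# prints: 27 6 54 6
```

Under that assumption, most of the work in the einsum() version goes into multiplying by the 21 zero entries of eijk.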

Here is the old implementation of cross():

x = a[1]*b[2] - a[2]*b[1]
y = a[2]*b[0] - a[0]*b[2]
z = a[0]*b[1] - a[1]*b[0]

Here is the new implementation of cross():

multiply(a1, b2, out=cp0)
tmp = array(a2 * b1)
cp0 -= tmp
multiply(a2, b0, out=cp1)
multiply(a0, b2, out=tmp)
cp1 -= tmp
multiply(a0, b1, out=cp2)
multiply(a1, b0, out=tmp)
cp2 -= tmp
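The excerpt above leaves out how a0..a2, b0..b2 and cp0..cp2 are set up. A minimal self-contained sketch of the same pattern follows, assuming inputs of shape (n, 3) and that those names are column views. The function name cross_rows is hypothetical (it is not part of NumPy), and unlike the excerpt it preallocates tmp instead of creating it with array(a2 * b1).

```python
import numpy as np

def cross_rows(a, b):
    """Row-wise cross product of two (n, 3) arrays (hypothetical helper)."""
    a = np.asarray(a)
    b = np.asarray(b)
    # Column views of the inputs; no data is copied here.
    a0, a1, a2 = a[:, 0], a[:, 1], a[:, 2]
    b0, b1, b2 = b[:, 0], b[:, 1], b[:, 2]

    # One output array and one scratch buffer, both allocated once.
    cp = np.empty(a.shape, dtype=np.result_type(a, b))
    cp0, cp1, cp2 = cp[:, 0], cp[:, 1], cp[:, 2]
    tmp = np.empty(a.shape[0], dtype=cp.dtype)

    # x = a1*b2 - a2*b1
    np.multiply(a1, b2, out=cp0)
    np.multiply(a2, b1, out=tmp)
    cp0 -= tmp
    # y = a2*b0 - a0*b2
    np.multiply(a2, b0, out=cp1)
    np.multiply(a0, b2, out=tmp)
    cp1 -= tmp
    # z = a0*b1 - a1*b0
    np.multiply(a0, b1, out=cp2)
    np.multiply(a1, b0, out=tmp)
    cp2 -= tmp
    return cp

# Usage: compare against np.cross on random data.
rng = np.random.default_rng(0)
a = rng.random((1000, 3))
b = rng.random((1000, 3))
print(np.allclose(cross_rows(a, b), np.cross(a, b)))
```

Writing every product into an existing buffer with out= is what keeps the number of temporary arrays down. Whether this is any faster than np.cross itself would have to be timed on your own machine.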

To speed it up further, you need Cython or Numba.

HYRY answered Oct 18 '25 05:10



