 

Algorithm for tensordot implemented in numba is much slower than numpy's

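I am trying to extend numpy's `tensordot` so that expressions like K_ijklm = A_ki * B_jml can be written clearly, like this: `K = mytensordot(A,B,[2,0],[1,4,3])`, where each list gives the position in K of the corresponding axis of A or B (A's axis k goes to position 2, its axis i to position 0, and so on).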

As I understand it, numpy's tensordot (with axes=0) can compute something like K_kijml = A_ki * B_jml, i.e. it keeps the order of the indices. I would then have to chain several np.swapaxes() calls to obtain `K_ijklm`, which in a complicated case is an easy source of errors (potentially very hard to debug).
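For comparison, here is a minimal pure-NumPy sketch of the requested semantics (not the asker's code; `placed_outer` is a hypothetical name). It builds the plain outer product with `np.multiply.outer`, which orders the axes as A's followed by B's, and then uses a single `np.moveaxis` call to send every axis to its target position instead of chaining `swapaxes`.

```python
import numpy as np

def placed_outer(A, B, iA, iB):
    """Outer product of A and B with each input axis placed at a chosen output position.

    Hypothetical helper: axis k of A ends up at output position iA[k],
    axis k of B at output position iB[k].
    """
    iA, iB = list(iA), list(iB)
    if len(iA) != A.ndim or len(iB) != B.ndim:
        raise ValueError("iA/iB must list one output position per axis of A/B")
    # np.multiply.outer puts all of A's axes first, then all of B's axes.
    outer = np.multiply.outer(A, B)
    # Move source axis n of `outer` to output position (iA + iB)[n] in one step.
    return np.moveaxis(outer, list(range(outer.ndim)), iA + iB)

# K_ijklm = A_ki * B_jml, with A of shape (k, i) and B of shape (j, m, l)
A = np.random.random((3, 4))
B = np.random.random((5, 6, 7))
K = placed_outer(A, B, [2, 0], [1, 4, 3])
```

Because `np.moveaxis` takes the whole permutation at once, there is no sequence of swaps to get wrong, and it raises an error if `iA + iB` repeats a position.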

The problem is that my implementation is slow (10x slower than tensordot [EDIT: It is actually much slower than that, roughly 400x in the timing below]), even when using numba. I was wondering if anyone has some insight into what could be done to improve the performance of my algorithm.

MWE

```python
import numpy as np
import numba as nb
import itertools
import timeit

@nb.jit()
def myproduct(dimN):
    N=np.prod(dimN)
    L=len(dimN)
    Product=np.zeros((N,L),dtype=np.int32)
    rn=0
    for n in range(1,N):
        for l in range(L):
            if l==0:
                rn=1
            v=Product[n-1,L-1-l]+rn
            rn = 0
            if v == dimN[L-1-l]:
                v = 0
                rn = 1
            Product[n,L-1-l]=v
    return Product

@nb.jit()
def mytensordot(A,B,iA,iB):
    iA,iB = np.array(iA,dtype=np.int32),np.array(iB,dtype=np.int32)
    dimA,dimB = A.shape,B.shape
    NdimA,NdimB=len(dimA),len(dimB)

    if len(iA) != NdimA: raise ValueError("iA must be same size as dim A")
    if len(iB) != NdimB: raise ValueError("iB must be same size as dim B")

    NdimN = NdimA + NdimB
    dimN=np.zeros(NdimN,dtype=np.int32)
    dimN[iA]=dimA
    dimN[iB]=dimB
    Out=np.zeros(dimN)
    indexes = myproduct(dimN)

    for nidxs in indexes:
        idxA = tuple(nidxs[iA])
        idxB = tuple(nidxs[iB])
        v=A[(idxA)]*B[(idxB)]
        Out[tuple(nidxs)]=v
    return Out



A=np.random.random((4,5,3))
B=np.random.random((6,4))

def runmytdot():
    return mytensordot(A,B,[0,2,3],[1,4])
def runtensdot():
    return np.tensordot(A,B,0).swapaxes(1,3).swapaxes(2,3)


print(np.all(runmytdot()==runtensdot()))
print(timeit.timeit(runmytdot,number=100))
print(timeit.timeit(runtensdot,number=100))
```

Result:

```
True
1.4962144780438393
0.003484356915578246
```
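The `myproduct` function in the MWE is an odometer: it lists every index tuple of an array of shape `dimN`, one per row, with the last index varying fastest (C order). A sketch, assuming only plain NumPy and the standard library, of two ways to build the same table without a hand-written loop; `index_table`, `index_table_product` and `myproduct_py` are hypothetical names, the last being a direct port of the question's loop with the numba decorator removed.

```python
import itertools
import numpy as np

def myproduct_py(dimN):
    # Direct port of the question's myproduct (no numba), used as the reference.
    N = int(np.prod(dimN))
    L = len(dimN)
    Product = np.zeros((N, L), dtype=np.int32)
    for n in range(1, N):
        rn = 0
        for l in range(L):
            if l == 0:
                rn = 1                      # add one to the last digit
            v = Product[n - 1, L - 1 - l] + rn
            rn = 0
            if v == dimN[L - 1 - l]:
                v = 0
                rn = 1                      # carry into the next digit
            Product[n, L - 1 - l] = v
    return Product

def index_table(dimN):
    """All index tuples of shape dimN, one per row, in C order."""
    L = len(dimN)
    # np.indices has shape (L, *dimN); flatten each coordinate grid in C order.
    return np.indices(dimN).reshape(L, -1).T

def index_table_product(dimN):
    # itertools.product also varies the last index fastest.
    return np.array(list(itertools.product(*(range(d) for d in dimN))))
```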
asked Oct 19 '25 21:10 by Miguel


1 Answer

You have run into a known issue. numpy.zeros requires a tuple when creating a multidimensional array. If you pass something other than a tuple, it sometimes works, but that's only because numpy is smart about converting the object into a tuple first.

The trouble is that numba does not currently support conversion of arbitrary iterables into tuples. So this line fails when you try to compile it in nopython=True mode. (A couple of others fail too, but this is the first.)

```python
Out=np.zeros(dimN)
```

In theory you could call np.prod(dimN), create a flat array of zeros, and reshape it, but then you run into the very same problem: the reshape method of numpy arrays requires a tuple!

This is quite a vexing problem with numba; I had not encountered it before. I doubt the solution I have found is the ideal one, but it works, and it allows us to compile a version in nopython=True mode.

The core idea is to avoid using tuples for indexing by directly implementing an indexer that follows the strides of the array:

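```python
import numpy as np
import numba as nb

# Both helpers assume a C-contiguous array, so that ravel() returns a view
# whose element order matches the offsets computed from the strides.
@nb.jit(nopython=True)
def index_arr(a, ix_arr):
    # Byte strides -> element strides (integer division keeps them integers).
    strides = np.array(a.strides) // a.itemsize
    ix = int((ix_arr * strides).sum())
    return a.ravel()[ix]

@nb.jit(nopython=True)
def index_set_arr(a, ix_arr, val):
    strides = np.array(a.strides) // a.itemsize
    ix = int((ix_arr * strides).sum())
    a.ravel()[ix] = val   # writes through the view into a
```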

This allows us to get and set values without needing a tuple.
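The stride arithmetic these helpers rely on can be checked in plain NumPy, without numba. A minimal sketch (`flat_offset` is a hypothetical name): dividing the byte strides by the item size gives element strides, and the dot product with the index gives the element's offset in the data buffer. For a C-contiguous array that offset is also its position in `ravel()`; for a transposed (non-C-contiguous) array it is not, which is why the helpers above should only be given C-contiguous arrays such as the ones `np.random.random` and `np.zeros` return.

```python
import numpy as np

def flat_offset(a, ix_arr):
    """Offset (in elements) of a[tuple(ix_arr)] from the start of a's data buffer."""
    strides = np.array(a.strides) // a.itemsize   # byte strides -> element strides
    return int((np.asarray(ix_arr) * strides).sum())

a = np.arange(24.0).reshape(2, 3, 4)    # C-contiguous, element strides (12, 4, 1)
value = a.ravel()[flat_offset(a, [1, 2, 3])]   # same element as a[1, 2, 3]

b = a.T                                 # same buffer, element strides (1, 4, 12)
# b.ravel() is a C-order copy, so the buffer offset no longer matches it here.
```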

We can also avoid using reshape by allocating the output buffer in a plain Python helper and passing it into the jitted function:

```python
# Plain Python wrapper: np.zeros(dimN) can't be compiled in nopython mode,
# so the output buffer is allocated here. (Recent numba versions no longer
# fall back to object mode for @nb.jit(), so this is left undecorated.)
def mytensordot(A, B, iA, iB):
    iA, iB = np.array(iA, dtype=np.int32), np.array(iB, dtype=np.int32)
    dimA, dimB = A.shape, B.shape
    NdimA, NdimB = len(dimA), len(dimB)

    if len(iA) != NdimA:
        raise ValueError("iA must be same size as dim A")
    if len(iB) != NdimB:
        raise ValueError("iB must be same size as dim B")

    NdimN = NdimA + NdimB
    dimN = np.zeros(NdimN, dtype=np.int32)
    dimN[iA] = dimA
    dimN[iB] = dimB
    Out = np.zeros(dimN)
    return mytensordot_jit(A, B, iA, iB, dimN, Out)
```

The helper adds a little overhead, but since it contains no loops, that overhead is fairly trivial. Here's the final jitted function:

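```python
@nb.jit(nopython=True)
def mytensordot_jit(A, B, iA, iB, dimN, Out):
    # Visit every element of Out by its flat (C-order) position.
    for i in range(np.prod(dimN)):
        # int_to_idx (not shown in this answer) turns the flat position i
        # into an index array with one entry per axis of Out.
        nidxs = int_to_idx(i, dimN)
        a = index_arr(A, nidxs[iA])   # A's axes sit at positions iA of Out
        b = index_arr(B, nidxs[iB])   # B's axes sit at positions iB of Out
        index_set_arr(Out, nidxs, a * b)
    return Out
```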
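`mytensordot_jit` calls `int_to_idx`, which the answer does not show. One possible implementation, offered as a hypothetical sketch rather than the answerer's code: peel off the last axis first, since it varies fastest in C order. It uses only integer arithmetic, a loop and array assignment, which numba's nopython mode supports, so it should be possible to decorate it with `@nb.jit(nopython=True)`; it is shown undecorated here so it runs with NumPy alone.

```python
import numpy as np

def int_to_idx(i, dimN):
    """Hypothetical helper: flat C-order position i -> index array for shape dimN."""
    L = len(dimN)
    idx = np.empty(L, dtype=np.int64)
    for k in range(L - 1, -1, -1):
        idx[k] = i % dimN[k]    # coordinate along axis k
        i //= dimN[k]           # remaining positions for the slower axes
    return idx

dimN = np.array([4, 6, 5, 3, 4], dtype=np.int32)
idx = int_to_idx(7, dimN)
```

For these shapes it agrees with `np.unravel_index`, which does the same conversion outside jitted code.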

Unfortunately, this does not wind up generating as much of a speedup as we might like. On smaller arrays it's about 5x slower than tensordot; on larger arrays it's still 50x slower. (But at least it's not 1000x slower!) This is not too surprising in retrospect, since dot and tensordot are both using BLAS under the hood, as @hpaulj reminds us.

After finishing this code, I saw that einsum has solved your real problem. Nice!
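The thread does not show the einsum call itself. As a sketch of how it could look (my reading, not code from the thread): einsum's subscripts are the index notation from the question, and with no repeated index nothing is summed, so it yields the reordered outer product directly.

```python
import numpy as np

rng = np.random.default_rng(0)

# K_ijklm = A_ki * B_jml: the subscripts are the index notation itself.
A1 = rng.random((3, 4))        # axes (k, i)
B1 = rng.random((5, 6, 7))     # axes (j, m, l)
K1 = np.einsum('ki,jml->ijklm', A1, B1)    # shape (i, j, k, l, m) = (4, 5, 3, 7, 6)

# The MWE case mytensordot(A, B, [0, 2, 3], [1, 4]):
# out[a, b, c, d, e] = A[a, c, d] * B[b, e]
A = rng.random((4, 5, 3))
B = rng.random((6, 4))
K = np.einsum('acd,be->abcde', A, B)
```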

But the underlying issue your original question points to, that indexing with arbitrary-length tuples is not possible in jitted code, is still a frustration. So hopefully this will be useful to someone else!

answered Oct 22 '25 11:10 by senderle


