I am trying to use Numba to parallelize a Python function that takes two NumPy ndarrays, alpha and beta, as arguments. They have shapes of the form (a, m, n) and (b, m, n) respectively, and are thus broadcastable over the last two dimensions. The function computes the matrix dot product (Frobenius product) of 2D slices of the arguments and, for each slice of alpha, finds the slice of beta that maximizes this product. In code:
@njit(parallel=True)
def parallel_value(alpha, beta):
    values = np.empty(alpha.shape[0])
    indices = np.empty(alpha.shape[0])
    for i in prange(alpha.shape[0]):
        dot = np.sum(alpha[i, :, :] * beta, axis=(1, 2))
        index = np.argmax(dot)
        values[i] = dot[index]
        indices[i] = index
    return values, indices
This runs fine without the njit decorator, but the Numba compiler complains:
No implementation of function Function(<built-in function setitem>) found for signature:
>>> setitem(array(float64, 1d, C), int64, array(float64, 1d, C))
The offending line is apparently values[i]=dot[index]. I have no idea why this is problematic. What is the cause of this issue, and how do I fix it?
Also, would there be any advantage to adding nogil=True to the arguments of @njit?
I managed to reproduce your problem. When running the code:
import numpy as np
from numba import njit, prange

@njit(parallel=True)
def parallel_value(alpha, beta):
    values = np.empty(alpha.shape[0])
    indices = np.empty(alpha.shape[0])
    for i in prange(alpha.shape[0]):
        dot = np.sum(alpha[i, :, :] * beta, axis=(1, 2))
        index = np.argmax(dot)
        values[i] = dot[index]
        indices[i] = index
    return values, indices

a, b, m, n = 6, 5, 4, 3
parallel_value(np.random.rand(a, m, n), np.random.rand(b, m, n))
I get the error message:
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<built-in function setitem>) found for signature:
>>> setitem(array(float64, 1d, C), int64, array(float64, 1d, C))
There are 16 candidate implementations:
- Of which 16 did not match due to:
Overload of function 'setitem': File: <numerous>: Line N/A.
With argument(s): '(array(float64, 1d, C), int64, array(float64, 1d, C))':
No match.
During: typing of setitem at <ipython-input-41-44518cf5219f> (11)
File "<ipython-input-41-44518cf5219f>", line 11:
def parallel_value(alpha,beta):
<source elided>
index=np.argmax(dot)
values[i]=dot[index]
^
According to an issue on the Numba GitHub tracker, the likely culprit is the np.sum call: Numba's nopython mode only supports an integer axis argument, not a tuple like axis=(1, 2), so dot ends up mistyped with the wrong dimensionality and the assignment values[i] = dot[index] fails with the setitem error above.
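A minimal sketch that isolates this (assuming the tuple-axis sum is indeed the unsupported part; the function names here are mine):

import numpy as np
from numba import njit

@njit
def tuple_axis_sum(x):
    # np.sum over a tuple of axes: fails to type in nopython mode
    return np.sum(x, axis=(1, 2))

@njit
def single_axis_sums(x):
    # two single-axis sums compile fine and give the same result
    return np.sum(np.sum(x, axis=2), axis=1)

x = np.random.rand(5, 4, 3)
print(single_axis_sums(x))   # works
# tuple_axis_sum(x)          # raises a TypingError like the one above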
When I rewrote the code using explicit loops, it worked:
import numpy as np
from numba import njit, prange

@njit(parallel=True)
def parallel_value_numba(alpha, beta):
    values = np.empty(alpha.shape[0])
    indices = np.empty(alpha.shape[0])
    for i in prange(alpha.shape[0]):
        # explicit loops replace np.sum over a tuple of axes; note that
        # only the outermost prange is parallelized, the inner ones
        # behave like a plain range
        dot = np.zeros(beta.shape[0])
        for j in prange(beta.shape[0]):
            for k in prange(beta.shape[1]):
                for l in prange(beta.shape[2]):
                    dot[j] += alpha[i, k, l] * beta[j, k, l]
        index = np.argmax(dot)
        values[i] = dot[index]
        indices[i] = index
    return values, indices

def parallel_value_nonumba(alpha, beta):
    # pure-Python reference; outside of njit, prange is just range
    values = np.empty(alpha.shape[0])
    indices = np.empty(alpha.shape[0])
    for i in prange(alpha.shape[0]):
        dot = np.sum(alpha[i, :, :] * beta, axis=(1, 2))
        index = np.argmax(dot)
        values[i] = dot[index]
        indices[i] = index
    return values, indices
a, b, m, n = 6, 5, 4, 3
np.random.seed(42)
A = np.random.rand(a, m, n)
B = np.random.rand(b, m, n)
res_num = parallel_value_numba(A, B)
res_nonum = parallel_value_nonumba(A, B)
print(f'res_num = {res_num}')
print(f'res_nonum = {res_nonum}')
Output:
res_num = (array([3.52775653, 2.49947515, 3.33824146, 2.9669794 , 3.78968905,
3.43988156]), array([1., 3., 1., 1., 1., 1.]))
res_nonum = (array([3.52775653, 2.49947515, 3.33824146, 2.9669794 , 3.78968905,
3.43988156]), array([1., 3., 1., 1., 1., 1.]))
As far as I can see, the explicit loops don't hinder performance. I can't compare against the same jitted code without them, since the np.sum version doesn't compile under Numba, but my guess is that it wouldn't matter:
%timeit res_num = parallel_value_numba(A, B)
%timeit res_nonum = parallel_value_nonumba(A, B)
Output:
The slowest run took 1472.03 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 5: 4.92 µs per loop
10000 loops, best of 5: 76.9 µs per loop
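As an aside: if you time a jitted function in a fresh session, call it once before benchmarking so that JIT compilation isn't folded into the measurement. A minimal sketch:

parallel_value_numba(A, B)        # warm-up call triggers compilation
%timeit parallel_value_numba(A, B)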
Lastly, you can do this more efficiently with plain NumPy by vectorizing your original code. It's almost as fast as Numba with explicit loops, and you won't have the initial compilation delay. Here's how you might do it:
def parallel_value_np(alpha, beta):
    # broadcast alpha against beta: (a, 1, m, n) * (1, b, m, n) -> (a, b, m, n)
    alpha = alpha.reshape(alpha.shape[0], 1, alpha.shape[1], alpha.shape[2])
    beta = beta.reshape(1, beta.shape[0], beta.shape[1], beta.shape[2])
    dot = np.sum(alpha * beta, axis=(2, 3))   # (a, b) matrix of Frobenius products
    indices = np.argmax(dot, axis=1)
    values = dot[np.arange(len(indices)), indices]
    return values, indices

res_np = parallel_value_np(A, B)
print(f'res_np = {res_np}')
%timeit res_np = parallel_value_np(A, B)
Output:
res_np = (array([3.52775653, 2.49947515, 3.33824146, 2.9669794 , 3.78968905,
3.43988156]), array([1, 3, 1, 1, 1, 1]))
The slowest run took 5.46 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 5: 16.1 µs per loop
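For what it's worth, the same computation can also be written with np.einsum, which can avoid materializing the full (a, b, m, n) broadcast product. This variant is a sketch of mine, not part of the benchmark above:

def parallel_value_einsum(alpha, beta):
    # contract directly over m and n: dot[i, j] = sum(alpha[i] * beta[j])
    dot = np.einsum('imn,jmn->ij', alpha, beta)
    indices = np.argmax(dot, axis=1)
    values = dot.max(axis=1)
    return values, indices

print(parallel_value_einsum(A, B))   # should match res_np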