
How to force Numba to return a NumPy type?

Tags:

python

numba

I find this behavior quite counter-intuitive, although I suppose there is a reason for it: Numba automatically converts my NumPy integer types into a Python int:

import numba as nb
import numpy as np 

print(f"Numba version: {nb.__version__}")  # 0.59.0
print(f"NumPy version: {np.__version__}")  # 1.23.5

# Explicitly define the signature
sig = nb.uint32(nb.uint32, nb.uint32)

@nb.njit(sig, cache=False)
def test_fn(a, b):
    return a * b

res = test_fn(2, 10)
print(f"Result value: {res}")  # returns 20
print(f"Result type: {type(res)}")  # returns <class 'int'>

This is an issue because I'm using the return value as an input to another njit function, so I get a casting warning (and I also do unnecessary casts in between the njit functions).

Is there any way to force numba to give me np.uint32 as a result instead?

--- EDIT ---

This is the best I've managed to do myself; however, I refuse to believe this is the best implementation out there:

# we manually define a return record and pass it as a parameter
res_type = np.dtype([('res', np.uint32)])
sig = nb.void(nb.uint32, nb.uint32, nb.from_dtype(res_type))

@nb.njit(sig, cache=False)
def test_fn(a:np.uint32, b:np.uint32, res: res_type):
    res['res'] = a * b

# Call with Python ints (Numba should coerce based on signature)
res = np.recarray(1, dtype=res_type)[0]
test_fn(2, 10, res)  # returns None; the result is written into res
print("\nCalled with Python ints:")
print(f"Result value: {res['res']}")  # 20
print(f"Result type: {type(res['res'])}")  # <class 'numpy.uint32'>

--- EDIT 2 ---

As @Nin17 correctly pointed out, returning an int object is still about 3 times quicker when called from a Python context, so it's better to just return a simple int and cast as needed.

Raven asked Dec 06 '25 15:12


2 Answers

Why don't you just return np.uint32(a * b)?

@nb.njit(nb.uint32(nb.uint32, nb.uint32))
def func(a, b):
    return np.uint32(a * b)

It is more readable than the other solutions and, as the benchmarks below show, faster when called from Python:

import numba as nb
import numpy as np

@nb.njit(nb.types.Array(nb.uint32, 0, "C")(nb.uint32, nb.uint32))
def test_fn(a, b):
    res = np.empty((), dtype=np.uint32)
    res[...] = a * b
    return res

res_type = np.dtype([('res', np.uint32)])
sig = nb.void(nb.uint32, nb.uint32, nb.from_dtype(res_type))

@nb.njit(sig)
def test_fn2(a, b, out):
    out['res'] = a * b

res = np.recarray(1, dtype=res_type)[0]
test_fn2(np.uint32(2), np.uint32(10), res)

a = np.uint32(2)
b = np.uint32(10)


%timeit test_fn(a, b)
%timeit test_fn2(a, b, res)
%timeit func(a, b)

Output:

339 ns ± 4.67 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
426 ns ± 1.01 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
126 ns ± 0.111 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

And when called from inside another jitted function:

N = int(1e7)
@nb.njit
def _test_fn(a, b):
    out = np.empty((N,), dtype=np.uint32)
    for i in range(N):
        out[i] = test_fn(a, b).item()
    return out

@nb.njit
def _test_fn2(a, b, res):
    out = np.empty((N,), dtype=np.uint32)
    for i in range(N):
        test_fn2(a, b, res)
        out[i] = res['res']
    return out

@nb.njit
def _func(a, b):
    out = np.empty((N,), dtype=np.uint32)
    for i in range(N):
        out[i] = func(a, b)
    return out

_test_fn(a, b)
_test_fn2(a, b, res)
_func(a, b)

%timeit _test_fn(a, b)
%timeit _test_fn2(a, b, res)
%timeit _func(a, b)

Output:

254 ms ± 508 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
3.44 ms ± 40.4 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
3.37 ms ± 19.9 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Nin17 answered Dec 10 '25 10:12


You might like this one: it is clean, less verbose, and interoperable with other Numba functions, and the return value behaves like a scalar but has the correct dtype. It is best for general-purpose use.

import numba as nb
import numpy as np

@nb.njit
def test_fn(a, b):
    res = np.empty((), dtype=np.uint32)
    res[...] = a * b
    return res

# Test
a = np.uint32(2)
b = np.uint32(10)
result = test_fn(a, b)
print("Result value:", result)
print("Type of result:", type(result))
print("Dtype of result:", result.dtype)
print("As Python int (optional):", result.item())
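To make clear what the function above hands back, here is a NumPy-only sketch (no Numba involved) of how a zero-dimensional array behaves. It is an illustration rather than part of the original answer, and the variable names are assumptions.

```python
import numpy as np

# Build the same kind of object the function returns: a 0-d uint32 array.
zero_d = np.empty((), dtype=np.uint32)
zero_d[...] = 2 * 10

print(type(zero_d), zero_d.shape, zero_d.dtype)

# Indexing with an empty tuple extracts a true NumPy scalar.
scalar = zero_d[()]
print(type(scalar))

# .item() converts to a plain Python int.
print(type(zero_d.item()))
```

The result is an `ndarray` with shape `()`, not an `np.uint32` instance. Code that checks `isinstance(x, np.uint32)` would need the `zero_d[()]` step first.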

Another one (it avoids memory allocation, so it is more efficient in tight loops and best suited to performance-critical scenarios):

res_type = np.dtype([('res', np.uint32)])
sig = nb.void(nb.uint32, nb.uint32, nb.from_dtype(res_type))

@nb.njit(sig)
def test_fn(a, b, out):
    out['res'] = a * b

res = np.recarray(1, dtype=res_type)[0]
test_fn(np.uint32(2), np.uint32(10), res)
print("Result value:", res['res'])
print("Type of result:", type(res['res']))
print("As Python int (optional):", int(res['res']))

Output:

D:\python>python test.py
Result value: 20
Type of result: <class 'numpy.uint32'>
As Python int (optional): 20
Subir Chowdhury answered Dec 10 '25 11:12


