Numba Signature for jitted function as argument

I've looked through the Numba documentation but couldn't find anything on this.

I have a function to jit that takes a jitted function as an argument. I want to compile it eagerly by adding a signature, like this:

@jit(float64('jitted_func.type', int32, int32, ...))

where 'jitted_func.type' should be the "function type" of the jitted argument.

When I call:

type(jitted_func)

I get a CPUDispatcher object.

Thanks for your help!

asked Oct 26 '25 04:10 by Darkonia


1 Answer

I am also looking for a solution to this. Unfortunately, @Carbon's suggestion does not work, because numba.typeof returns a different type for a function bar than for a function baz, even when bar and baz have identical signatures.

Example:

import numba

@numba.jit(
    numba.int32(numba.int32),
    nopython=True,
    nogil=True,
)
def bar(a):

    return 2 * a

@numba.jit(
    numba.int32(numba.int32),
    nopython=True,
    nogil=True,
)
def baz(a):

    return 3 * a

@numba.jit(
    numba.int32(numba.typeof(bar), numba.int32),
    nopython=True,
    nogil=True,
)
def foo(fn, a):

    return fn(a)

foo(bar, 2) returns 4

foo(baz, 2) raises the following exception:

Traceback (most recent call last):
  File "test_numba.py", line 33, in <module>
    print(foo(baz, 2))
  File "<snip>\Python38\lib\site-packages\numba\core\dispatcher.py", line 656, in _explain_matching_error
    raise TypeError(msg)
TypeError: No matching definition for argument type(s) type(CPUDispatcher(<function baz at 0x000001DFA8C2D1F0>)), int64

The only workaround I've found is to omit the function signature for foo entirely and let Numba infer it. I don't know what negative consequences this has (if any), but it may get your code running.

Example:

import numba

@numba.jit(
    numba.int32(numba.int32),
    nopython=True,
    nogil=True,
)
def bar(a):

    return 2 * a

@numba.jit(
    numba.int32(numba.int32),
    nopython=True,
    nogil=True,
)
def baz(a):

    return 3 * a

@numba.jit(
    nopython=True,
    nogil=True,
)
def foo(fn, a):

    return fn(a)

foo(bar, 2) returns 4

foo(baz, 2) returns 6

answered Oct 28 '25 18:10 by joseph-fourier


