I'm working on a jitclass in which one of the methods can accept an input argument of type int, float, or numpy.ndarray. I need to be able to determine whether the argument is an array or one of the other two types. I've tried using isinstance, as shown in the interp method below:
import numpy as np
from numba import float64
from numba.experimental import jitclass  # Numba < 0.49: from numba import jitclass

spec = [('x', float64[:]),
        ('y', float64[:])]

@jitclass(spec)
class Lookup:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def interp(self, x0):
        if isinstance(x0, (float, int)):
            result = self._interpolate(x0)
        elif isinstance(x0, np.ndarray):
            result = np.zeros(x0.size)
            for i in range(x0.size):
                result[i] = self._interpolate(x0[i])
        else:
            raise TypeError("`interp` method can only accept types of float, int, or ndarray.")
        return result

    def _interpolate(self, x0):
        x = self.x
        y = self.y
        if x0 < x[0]:
            return y[0]
        elif x0 > x[-1]:
            return y[-1]
        else:
            for i in range(len(x) - 1):
                if x[i] <= x0 <= x[i + 1]:
                    x1, x2 = x[i], x[i + 1]
                    y1, y2 = y[i], y[i + 1]
                    return y1 + (y2 - y1) / (x2 - x1) * (x0 - x1)
But I get the following error:
numba.errors.TypingError: Failed at nopython (nopython frontend)
Failed at nopython (nopython frontend)
Untyped global name 'isinstance': cannot determine Numba type of <class 'builtin_function_or_method'>
File "Lookups.py", line 17
[1] During: resolving callee type: BoundFunction((<class 'numba.types.misc.ClassInstanceType'>, 'interp') for instance.jitclass.Lookup#2167664ca28<x:array(float64, 1d, A),y:array(float64, 1d, A)>)
[2] During: typing of call at <string> (3)
Is there a way to determine whether an input argument is of a certain type when using jitclasses or in nopython mode?
I should have mentioned this before but using the type built-in also does not seem to work. For example if I replace the interp method with:
def interp(self, x0):
    if type(x0) == float or type(x0) == int:
        result = self._interpolate(x0)
    elif type(x0) == np.ndarray:
        result = np.zeros(x0.size)
        for i in range(x0.size):
            result[i] = self._interpolate(x0[i])
    else:
        raise TypeError("`interp` method can only accept types of float, int, or ndarray.")
    return result
I get the following error:
numba.errors.TypingError: Failed at nopython (nopython frontend)
Failed at nopython (nopython frontend)
Invalid usage of == with parameters (class(int64), Function(<class 'float'>))
I think this refers to the comparison between Python's float and numba's int64, which happens when I call something like lookup_object.interp(370).
You're out of luck if you need to determine and compare the type inside a numba jitclass or nopython jit function: isinstance isn't supported at all, and type is supported only for a few numeric types and namedtuples. (Note that type just returns the type; it's not suitable for comparisons, because == isn't implemented for classes inside numba functions.)
As of Numba 0.35 the only supported built-ins are (source: numba documentation):
The following built-in functions are supported:
- abs()
- bool
- complex
- divmod()
- enumerate()
- float
- int: only the one-argument form
- iter(): only the one-argument form
- len()
- min()
- max()
- next(): only the one-argument form
- print(): only numbers and strings; no file or sep argument
- range: semantics are similar to those of Python 3 even in Python 2: a range object is returned instead of an array of values.
- round()
- sorted(): the key argument is not supported
- type(): only the one-argument form, and only on some types (e.g. numbers and named tuples)
- zip()
My suggestion: use a normal Python class, determine the type there, and then forward to functions compiled with numba.njit accordingly:
import numba as nb
import numpy as np

@nb.njit
def _interpolate_one(x, y, x0):
    if x0 < x[0]:
        return y[0]
    elif x0 > x[-1]:
        return y[-1]
    else:
        for i in range(len(x) - 1):
            if x[i] <= x0 <= x[i + 1]:
                x1, x2 = x[i], x[i + 1]
                y1, y2 = y[i], y[i + 1]
                return y1 + (y2 - y1) / (x2 - x1) * (x0 - x1)

@nb.njit
def _interpolate_many(x, y, x0):
    # np.float64 instead of np.float_, which was removed in NumPy 2.0
    result = np.zeros(x0.size, dtype=np.float64)
    for i in range(x0.size):
        result[i] = _interpolate_one(x, y, x0[i])
    return result

class Lookup:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def interp(self, x0):
        # Type checks run in plain Python; only the numeric work is compiled.
        if isinstance(x0, (float, int)):
            result = _interpolate_one(self.x, self.y, x0)
        elif isinstance(x0, np.ndarray):
            result = _interpolate_many(self.x, self.y, x0)
        else:
            raise TypeError("`interp` method can only accept types of float, int, or ndarray.")
        return result
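One caveat with the isinstance(x0, (float, int)) check in the wrapper: NumPy scalars such as np.int64 or np.float32 are not instances of Python's int or float, so they would end up in the TypeError branch. Here is a minimal sketch of a broader pure-Python check, assuming it is acceptable to treat any real number as a scalar. The helper name classify_input is hypothetical and not part of the answer above:

```python
import numbers

import numpy as np


def classify_input(x0):
    """Return "array" or "scalar" for supported inputs, else raise TypeError.

    Hypothetical helper for illustration only.
    """
    # Check for arrays first; note that 0-d arrays also count as arrays here.
    if isinstance(x0, np.ndarray):
        return "array"
    # numbers.Real covers Python int/float (and bool) as well as NumPy
    # integer and floating scalars, which NumPy registers with these ABCs.
    if isinstance(x0, numbers.Real):
        return "scalar"
    raise TypeError("expected a real number or a numpy.ndarray")
```

The wrapper's interp method could branch on the returned string instead of the two isinstance calls, and the compiled functions would stay unchanged.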
As of numba 0.52, np.shape() is supported. So if you only want to distinguish between np.ndarray and scalars, the following works:
import numpy as np
from numba import njit

@njit
def test(a):
    if len(np.shape(a)) > 0:
        return 'np.ndarray'
    else:
        return 'not an array'

>>> test(1)
'not an array'
>>> test(np.array([1,2,3]))
'np.ndarray'