I have just discovered numba, and learnt that optimal performance requires adding @njit to most functions, so that execution rarely leaves nopython mode (compiled machine code) to call back into the Python interpreter.
I still have a few expensive/lookup functions that could benefit from memoization, but so far none of my attempts have found a workable solution that compiles without error.
My attempts so far:
- Putting a memoization decorator above @njit results in numba not being able to do type inference.
- Putting it below @njit fails to compile the decorator.
- Global variables fail to compile, even when using numba.typed.Dict.
- Memoizing a function without @njit also causes type errors when it is called from other @njit functions.

What is the correct way to add memoization to functions when working inside numba?
import functools
import time
import fastcache
import numba
import numpy as np
import toolz
from numba import njit
from functools import lru_cache
from fastcache import clru_cache
from toolz import memoize
# @fastcache.clru_cache(None)  # BUG: Untyped global name 'expensive': cannot determine Numba type of <class 'fastcache.clru_cache'>
# @functools.lru_cache(None)   # BUG: Untyped global name 'expensive': cannot determine Numba type of <class 'functools._lru_cache_wrapper'>
# @toolz.memoize               # BUG: Untyped global name 'expensive': cannot determine Numba type of <class 'function'>
@njit
# @fastcache.clru_cache(None)  # BUG: AttributeError: 'fastcache.clru_cache' object has no attribute '__defaults__'
# @functools.lru_cache(None)   # BUG: AttributeError: 'functools._lru_cache_wrapper' object has no attribute '__defaults__'
# @toolz.memoize               # BUG: CALL_FUNCTION_EX with **kwargs not supported
def expensive():
    bitmasks = np.array([ 1 << n for n in range(0, 64) ], dtype=np.uint64)
    return bitmasks
# @fastcache.clru_cache(None)  # BUG: Untyped global name 'expensive_nojit': cannot determine Numba type of <class 'fastcache.clru_cache'>
# @functools.lru_cache(None)   # BUG: Untyped global name 'expensive_nojit': cannot determine Numba type of <class 'functools._lru_cache_wrapper'>
# @toolz.memoize               # BUG: Untyped global name 'expensive_nojit': cannot determine Numba type of <class 'function'>
def expensive_nojit():
    bitmasks = np.array([ 1 << n for n in range(0, 64) ], dtype=np.uint64)
    return bitmasks
# BUG: Failed in nopython mode pipeline (step: analyzing bytecode)
#      Use of unsupported opcode (STORE_GLOBAL) found
_expensive_cache = None
@njit
def expensive_global():
    global _expensive_cache
    if _expensive_cache is None:
        bitmasks = np.array([ 1 << n for n in range(0, 64) ], dtype=np.uint64)
        _expensive_cache = bitmasks
    return _expensive_cache
# BUG: The use of a DictType[unicode_type,array(int64, 1d, A)] type, assigned to variable 'cache' in globals,
#      is not supported as globals are considered compile-time constants and there is no known way to compile
#      a DictType[unicode_type,array(int64, 1d, A)] type as a constant.
cache = numba.typed.Dict.empty(
    key_type   = numba.types.string,
    value_type = numba.uint64[:]
)
@njit
def expensive_cache():
    global cache
    if "expensive" not in cache:
        bitmasks = np.array([ 1 << n for n in range(0, 64) ], dtype=np.uint64)
        cache["expensive"] = bitmasks
    return cache["expensive"]
# BUG: Cannot capture the non-constant value associated with variable 'cache' in a function that will escape.
@njit()
def _expensive_wrapped():
    cache = []
    def wrapper():
        if len(cache) == 0:
            bitmasks = np.array([ 1 << n for n in range(0, 64) ], dtype=np.uint64)
            cache.append(bitmasks)
        return cache[0]
    return wrapper
expensive_wrapped = _expensive_wrapped()
@njit
def loop(count):
    for n in range(count):
        expensive()
        # expensive_nojit()
        # expensive_cache()
        # expensive_global()
        # expensive_wrapped()
def main():
    time_start = time.perf_counter()
    count = 10000
    loop(count)
    time_taken = time.perf_counter() - time_start
    print(f'{count} loops in {time_taken:.4f}s')
loop(1)  # warm-up call: triggers JIT compilation so it is excluded from the timing
main()
# Pure Python: 10000 loops in 0.2895s
# Numba @njit: 10000 loops in 0.0026s
Large data: for larger input data, the Numba version of a function is much faster than the NumPy version, even taking the compilation time into account. In fact, the ratio of NumPy to Numba run time depends on the data size, the number of loops, and more generally on the nature of the function being compiled.
Both Cython and Numba speed up Python code, even for a small number of operations, and the more operations there are, the larger the speed-up. However, in that comparison the performance gain from Cython saturates at around 100-150 times the speed of Python, whereas the speed-up from Numba keeps increasing steadily with the number of operations.
However, numba and SciPy are still not compatible: SciPy functions generally cannot be called from inside @njit code. Yes, SciPy calls compiled C and Fortran, but it does so in a way that numba can't deal with.
nopython: Numba has two compilation modes, nopython mode and object mode. The former produces much faster code, but has limitations that can force Numba to fall back to the latter. To prevent Numba from falling back, and raise an error instead, pass nopython=True (@njit is shorthand for @jit(nopython=True)).
You already mentioned that your real code is more complex, but looking at your minimal example, I would recommend the following pattern:
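from numba import njit

@njit
def loop(count):
    # Compute the expensive result once, outside the loop
    expensive_result = expensive()
    for i in range(count):
        # do_something is a placeholder for the real per-iteration work (it must also be @njit)
        do_something(i, expensive_result)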
Instead of using a cache, you could pre-compute the result once, outside the loop, and pass it to the loop body. Instead of using globals, I would recommend passing every argument explicitly (always, but especially when using the numba jit).