
Determine widening conversion of NumPy types

I have a library based on NumPy with a few classes that overload the arithmetic operations. The internals are a little hairy due to a significant amount of error checking, but I've run across a serious problem with how I'm doing it.

The idea behind the library is to make it very easy and intuitive to use with minimal syntax and effort by the programmer. As such, I want it to be easy to combine arrays of different data types and simply convert the narrower data type to the wider case.

For example, if I have two arrays, one with dtype float64 and the other with dtype complex128, when adding them together I would like to convert the float64 to complex128, but if it's float64 and complex192, I want to convert to complex192 instead. However, a float64 cannot be converted to complex64 without losing precision, so for a combination of float64 and complex64 I would want to convert both to complex128.
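In NumPy's terms, the "valid conversion" described above is a safe cast. As a small illustration (not part of the original question), np.can_cast with its default casting="safe" reports which of these conversions preserve every value:

```python
import numpy as np

# A safe cast means every value of the source dtype is representable in the target.
print(np.can_cast(np.float64, np.complex128))   # True: float64 fits in the real part
print(np.can_cast(np.float64, np.complex64))    # False: complex64 stores two float32s
print(np.can_cast(np.complex64, np.complex128)) # True
```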

The problem is that, for the library to be fully robust, I would have to enumerate every combination of types and determine their narrowest common widened type (think least common multiple). I don't want to convert everything to the widest type possible, since that becomes memory inefficient very quickly, and I often have very large arrays stored in memory.

Is there a good way to determine the narrowest common widened type between two NumPy types?

asked Sep 14 '25 10:09 by bheklilr



2 Answers

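@amaurea has the right idea; in fact, functions already exist in numpy for this. Take a look at numpy.result_type and numpy.promote_types: promote_types returns the smallest dtype to which two given dtypes can both be safely cast, and result_type applies NumPy's type promotion rules to any number of arrays or dtypes.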
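A short sketch of both functions applied to the combinations from the question (complex192 is omitted because its availability is platform-dependent). For more than two dtypes, folding promote_types with functools.reduce is one way to get the "least common multiple" the question asks for; the dtype list here is just an illustrative choice:

```python
import functools
import numpy as np

# promote_types takes exactly two dtypes.
print(np.promote_types(np.float64, np.complex128))  # complex128
print(np.promote_types(np.float64, np.complex64))   # complex128
print(np.promote_types(np.float32, np.complex64))   # complex64

# result_type accepts any number of arrays and/or dtypes.
a = np.ones(3, dtype=np.float64)
b = np.ones(3, dtype=np.complex64)
print(np.result_type(a, b))  # complex128

# Fold promote_types over a list of dtypes: int8 + uint8 -> int16, then + float32 -> float32.
dtypes = [np.int8, np.uint8, np.float32]
print(functools.reduce(np.promote_types, dtypes))  # float32
```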

answered Sep 15 '25 23:09 by Warren Weckesser



I think numpy already does this internally, so how about just asking numpy?

import numpy as np

(np.zeros([], dtype=np.float64) + np.zeros([], dtype=np.complex64)).dtype
=> dtype('complex128')

(np.zeros([], dtype=np.float32) + np.zeros([], dtype=np.complex64)).dtype
=> dtype('complex64')

You could generalize this into a function:

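def common_dtype(dtypes):
    # Builtin sum, not np.sum: np.sum upcasts bool/small ints to the platform int
    arrays = [np.zeros([], dtype=d) for d in dtypes]
    return sum(arrays[1:], arrays[0]).dtype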
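To connect this back to the operator-overloading classes in the question, here is a minimal sketch of where such a dtype lookup could sit. TypedArray is a hypothetical stand-in for the library's classes, and it uses np.result_type (from the other answer) rather than the arithmetic trick. NumPy's ufuncs already apply this promotion on their own, so computing the target dtype explicitly mainly helps when the class wants to validate or preallocate the result first (np.broadcast_shapes requires NumPy 1.20 or later):

```python
import numpy as np


class TypedArray:
    """Hypothetical stand-in for one of the question's array classes."""

    def __init__(self, data):
        self.data = np.asarray(data)

    def __add__(self, other):
        if not isinstance(other, TypedArray):
            return NotImplemented
        # Narrowest dtype that both operands can be safely cast to.
        target = np.result_type(self.data, other.data)
        # Knowing the dtype and shape up front allows checks or preallocation.
        shape = np.broadcast_shapes(self.data.shape, other.data.shape)
        out = np.empty(shape, dtype=target)
        np.add(self.data, other.data, out=out)
        return TypedArray(out)


x = TypedArray(np.ones(4, dtype=np.float64))
y = TypedArray(np.ones(4, dtype=np.complex64))
z = x + y
print(z.data.dtype)  # complex128
```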
answered Sep 16 '25 00:09 by amaurea
