Numpy seterr for all functions in module

I have a module containing functions with numpy operations. I would like to raise an exception for any floating-point error (e.g., division by zero) that occurs in any of the functions.

This line of code causes all floating-point errors to be raised:

np.seterr(all='raise')

I'm wondering how to set this for all functions in the module, without it affecting code outside of the module.

As I understand it, writing the line under if __name__ == '__main__': won't help, because it won't be invoked when the module is imported.

Is there a better way than writing np.seterr(all='raise') inside each function?
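For reference, here is what writing np.seterr inside each function looks like when done carefully. np.seterr changes NumPy's error handling for all NumPy code that runs afterwards, not just the calling module, and it returns the previous settings, so each function has to restore them itself. This is a minimal sketch; divide_strict is a hypothetical example function.

```python
import numpy as np


def divide_strict(a, b):
    """Hypothetical example: divide with all FP errors raised, then restore."""
    # np.seterr is not scoped to this module; it returns the old settings,
    # so save them and put them back before returning.
    old_settings = np.seterr(all='raise')
    try:
        return np.divide(a, b)
    finally:
        np.seterr(**old_settings)


before = np.geterr()
try:
    divide_strict(np.array([1.0, 2.0]), 0.0)
except FloatingPointError as exc:
    print("raised:", exc)
# The caller's settings are back in place after the call.
print(np.geterr() == before)
```

The save/restore boilerplate in every function is exactly what the question hopes to avoid.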

asked by Vermillion, Oct 28 '25 08:10


1 Answer

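This thread has been around for a while and you may have solved it already, but for future readers (like me), here's an approach that can save a few minutes: wrap each function in a decorator that runs it inside np.errstate(all='raise'). Unlike np.seterr, np.errstate is a context manager that restores the previous error settings when the with-block exits, so the 'raise' behaviour only applies while a decorated function is running and code outside the module is unaffected: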

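import functools

import numpy as np


def unmasked_call(func):
    # Run func with every NumPy floating-point error raised as an exception.
    # np.errstate restores the previous settings when the with-block exits,
    # so code outside the decorated functions keeps its own settings.
    @functools.wraps(func)
    def func2(*args, **kwargs):
        with np.errstate(all='raise'):
            return func(*args, **kwargs)
    return func2


@unmasked_call
def one():
    print("One:")
    # Division by zero raises FloatingPointError inside the decorated function.
    return np.array([1.0, 2.0]) / 0.0


print("Starting")
try:
    one()
except FloatingPointError as exc:
    print("Caught:", exc)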
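To avoid decorating every function by hand, one option is to apply the same kind of decorator automatically at the bottom of the module. This is a sketch under stated assumptions, not part of the original answer: raise_on_fp_errors is the same idea as unmasked_call, ratio and mean_ratio are hypothetical example functions, and the loop wraps only plain functions whose __module__ matches the current module, so functions imported from elsewhere are left alone.

```python
import functools
import inspect

import numpy as np


def raise_on_fp_errors(func):
    """Run func with all NumPy floating-point errors raised."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # errstate restores the caller's settings when the block exits.
        with np.errstate(all='raise'):
            return func(*args, **kwargs)
    return wrapper


def ratio(a, b):
    # Hypothetical example function.
    return np.asarray(a, dtype=float) / b


def mean_ratio(a, b):
    # Hypothetical example function that calls another module function.
    return ratio(a, b).mean()


# Place this loop at the very bottom of the module, after all definitions.
# It replaces each plain function defined in this module with a wrapped
# version, skipping the decorator itself and anything imported.
for _name, _obj in list(globals().items()):
    if (inspect.isfunction(_obj)
            and _obj.__module__ == __name__
            and _obj is not raise_on_fp_errors):
        globals()[_name] = raise_on_fp_errors(_obj)
del _name, _obj
```

Internal calls such as mean_ratio calling ratio go through the wrapped global name, which only nests errstate blocks and is harmless. Methods of classes are not covered by this loop, and any reference to a function captured before the loop runs still points at the unwrapped version.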
answered by Dorish, Oct 30 '25 01:10