 

Python - How to wrap the np.ndarray class?

Tags:

python

numpy

I'm trying to add functionality to the np.ndarray class.

I was hoping it would be as simple as the following:

import numpy as np

class myArray(np.ndarray):

    def __init__(self, *args, **kwargs):
        self = np.array(*args, **kwargs)
        # super(myArray, self).__init__(*args, **kwargs) <-- my first attempt didn't work
        if self.ndim == 4:
            self.myShape = (self.shape[0]*self.shape[1], self.shape[2]*self.shape[3])
        else:
            self.myShape = self.shape

    def myStuff(self):
        self = self.reshape(self.myShape)

a = np.zeros([2, 2, 2, 2])
myArray(a)
# TypeError: only length-1 arrays can be converted to Python scalars
a = range(10)
myArray(a)
# AttributeError: 'numpy.ndarray' object has no attribute 'myShape'

Please let me know if you need more information.

EDIT:

To give a bit more information about why I want to do this (someone suggested that simply writing a function might be more appropriate):

I want to add the following:

A.newFun(B)

which would be the same as:

import numpy as np

def newFun(A,B):
    oldShapeA = A.shape
    A = np.matrix( A.reshape([A.shape[0]*A.shape[1], A.shape[2]*A.shape[3]]) )
    oldShapeB = B.shape
    B = np.matrix( B.reshape([1, -1]) )
    out = A*B
    out = np.reshape(out.A, oldShapeA[:2]+oldShapeB)
    return out

I have left out a lot of checks (for example, that the dimensions are correct), but hopefully you get the point.

asked Aug 30 '25 18:08 by evan54


2 Answers

Subclassing ndarray can be done, but has some subtleties. These are explained at length in the NumPy manual.
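One of those subtleties is that an ndarray subclass normally sets up its extra attributes in `__array_finalize__` rather than `__init__`, because NumPy also creates instances through view casting and slicing, where `__init__` is never called. A minimal sketch, using a hypothetical `Tracked` class, of what the `obj` argument holds in the three creation routes the NumPy subclassing guide describes:

```python
import numpy as np

class Tracked(np.ndarray):
    """Hypothetical subclass that records what __array_finalize__ received."""

    def __array_finalize__(self, obj):
        # obj is None for an explicit constructor call; otherwise it is the
        # array this instance was derived from (view casting or slicing).
        self.source = None if obj is None else type(obj).__name__

# 1. Explicit constructor call: obj is None
explicit = Tracked((2, 3))

# 2. View casting: obj is the plain ndarray being viewed
viewed = np.arange(6).view(Tracked)

# 3. New from template (slicing): obj is the Tracked parent
sliced = viewed[1:]

print(explicit.source, viewed.source, sliced.source)  # None ndarray Tracked
```

Any attribute that should exist on every instance therefore has to be set (or copied from `obj`) in `__array_finalize__`.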

I don't really follow what you're trying to do in the subclass, but it's worth considering whether subclassing is the right approach to the problem at all.
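For comparison, the reshaping logic from the question can live in a plain function with no subclass involved. A minimal sketch; `collapse_4d` is a hypothetical name, and the (p*q, r*s) grouping is taken from the question's `myShape` computation:

```python
import numpy as np

def collapse_4d(a):
    """Hypothetical helper: reshape a 4-D array (p, q, r, s) to 2-D (p*q, r*s).

    Arrays with any other number of dimensions are returned unchanged,
    mirroring the question's myShape logic.
    """
    a = np.asarray(a)  # accept lists, ranges and other array-likes
    if a.ndim == 4:
        p, q, r, s = a.shape
        return a.reshape(p * q, r * s)
    return a

print(collapse_4d(np.zeros((2, 2, 2, 2))).shape)  # (4, 4)
print(collapse_4d(range(10)).shape)               # (10,)
```

Both inputs from the question (`np.zeros([2, 2, 2, 2])` and `range(10)`) go through this without a custom class, and the result is an ordinary ndarray that works with every other NumPy function.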

answered Sep 02 '25 06:09 by NPE


Subclassing np.ndarray requires a bit of finesse. The gory details are here: http://docs.scipy.org/doc/numpy/user/basics.subclassing.html

Specifically, I think this does what you wanted:

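import numpy as np

class myArray(np.ndarray):
    def __new__(cls, *args, **kwargs):
        # Build an ordinary ndarray from the input, then view it as myArray
        return np.array(*args, **kwargs).view(cls)

    def __array_finalize__(self, obj):
        # Runs on explicit construction, view casting and new-from-template
        # (slicing, reshape, ...). Use self rather than obj: a view can have
        # a different number of dimensions than the array it came from.
        if self.ndim == 4:
            self.myShape = (self.shape[0]*self.shape[1], self.shape[2]*self.shape[3])
        else:
            self.myShape = self.shape

    def myStuff(self):
        # Assigning to self would only rebind a local name; return the result
        return self.reshape(self.myShape)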
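The question's `myStuff` (`self = self.reshape(self.myShape)`) has a further pitfall that subclassing does not fix: assigning to `self` inside a method only rebinds a local name, so the caller's array is left unchanged. A minimal sketch with a hypothetical `Demo` subclass:

```python
import numpy as np

class Demo(np.ndarray):
    def rebind(self):
        # Rebinds the local name only; the caller's object is untouched
        self = self.reshape(-1)

    def reshaped(self):
        # Return the reshaped view instead, and let the caller keep it
        return self.reshape(-1)

a = np.zeros((2, 3)).view(Demo)
a.rebind()
print(a.shape)             # (2, 3)
print(a.reshaped().shape)  # (6,)
```

The usual pattern is therefore `b = a.myStuff()`, in the same way as `b = a.reshape(...)`.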

For a more elaborate real-life use case, take a look here: https://github.com/arokem/nitime/blob/master/nitime/timeseries.py#L101

answered Sep 02 '25 06:09 by arokem