I'm trying to add functionality to the np.ndarray class.
I was hoping it would be as simple as the following:
```python
import numpy as np

class myArray(np.ndarray):
    def __init__(self, *args, **kwargs):
        self = np.array(*args, **kwargs)
        # super(myArray, self).__init__(*args, **kwargs)  <-- my first attempt didn't work
        if self.ndim == 4:
            self.myShape = (self.shape[0]*self.shape[1], self.shape[2]*self.shape[3])
        else:
            self.myShape = self.shape

    def myStuff(self):
        self = self.reshape(self.myShape)

a = np.zeros([2, 2, 2, 2])
myArray(a)
# TypeError: only length-1 arrays can be converted to Python scalars

a = range(10)
myArray(a)
# AttributeError: 'numpy.ndarray' object has no attribute 'myShape'
```
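Two separate things go wrong in the attempts above, and a small sketch makes both visible. `Broken` is a hypothetical stripped-down class used only for illustration. First, `np.ndarray.__new__` runs before `__init__` and reads its first positional argument as a *shape*, not as data. Second, assigning to `self` inside `__init__` only rebinds a local variable, and a plain `np.ndarray` (which is what `np.array` returns) cannot take new attributes.

```python
import numpy as np

class Broken(np.ndarray):
    # Hypothetical minimal version of the __init__-based attempt
    def __init__(self, *args, **kwargs):
        # Rebinding `self` only changes a local variable; the instance that
        # np.ndarray.__new__ already built is not modified.
        self = np.array(*args, **kwargs)

# np.ndarray.__new__ reads the first positional argument as a shape,
# so this creates an uninitialised array of shape (3,), not array(3).
b = Broken(3)
print(type(b).__name__, b.shape)

# A plain ndarray has no instance __dict__, so it cannot take new
# attributes; this is where the AttributeError in the question comes from.
plain = np.array(range(10))
raised = False
try:
    plain.myShape = plain.shape
except AttributeError as exc:
    raised = True
    print("AttributeError:", exc)
```

Instances of a Python-level subclass such as `Broken` do get a `__dict__`, which is why the attribute has to be set on the real subclass instance rather than on a fresh `np.array`.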
Please let me know if you need more information.
EDIT:
To give a bit more information about why I want to do this (someone suggested that simply writing a function might be more appropriate):
I want to add the following:

```python
A.newFun(B)
```

which would be the same as:

```python
import numpy as np

def newFun(A, B):
    oldShapeA = A.shape
    A = np.matrix(A.reshape([A.shape[0]*A.shape[1], A.shape[2]*A.shape[3]]))
    oldShapeB = B.shape
    B = np.matrix(B.reshape([1, -1]))
    out = A*B
    out = np.reshape(out.A, oldShapeA[:2] + oldShapeB)
    return out
```
I have left out a lot of checks, such as checking that the dimensions are correct, but hopefully you get the point.
Subclassing ndarray can be done, but it has some subtleties. These are explained at length in the NumPy manual.
I don't really follow what you're trying to do in the subclass, but it's worth considering whether subclassing is the right approach to the problem.
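For comparison, a minimal sketch of the plain-function route, assuming the goal is just the reshape that `myShape` describes in the question; `collapse_pairs` is a hypothetical name. It works on any array-like input, including plain ndarrays and the results of other NumPy functions, without having to keep a subclass alive through every operation.

```python
import numpy as np

def collapse_pairs(a):
    """Hypothetical helper: reshape a 4-D array (a0, a1, a2, a3) to the
    2-D shape (a0*a1, a2*a3); any other array is returned unchanged."""
    a = np.asarray(a)
    if a.ndim == 4:
        return a.reshape(a.shape[0] * a.shape[1], a.shape[2] * a.shape[3])
    return a

x = np.zeros([2, 3, 4, 5])
print(collapse_pairs(x).shape)  # (6, 20)
```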
Sub-classing np.ndarray requires a bit of finesse. The gory details are here: http://docs.scipy.org/doc/numpy/user/basics.subclassing.html
Specifically, I think this does what you wanted:
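```python
import numpy as np

class myArray(np.ndarray):
    def __new__(cls, *args, **kwargs):
        # Build an ordinary ndarray from the input, then view-cast it to
        # the subclass; the view triggers __array_finalize__.
        return np.array(*args, **kwargs).view(cls)

    def __array_finalize__(self, obj):
        # Called for every new myArray (view casting, slicing, reshape, ...).
        # Use self rather than obj: a slice of a 4-D array may have fewer dims.
        if self.ndim == 4:
            self.myShape = (self.shape[0]*self.shape[1], self.shape[2]*self.shape[3])
        else:
            self.myShape = self.shape

    def myStuff(self):
        # Assigning to self would only rebind a local name; return the result
        return self.reshape(self.myShape)
```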
To see this in a (more elaborate) real-life use-case, take a look here: https://github.com/arokem/nitime/blob/master/nitime/timeseries.py#L101
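A minimal sketch for checking that `__array_finalize__` keeps `myShape` in sync. It exercises the ways the NumPy subclassing guide says instances get created (view casting, and new-from-template such as slicing or reshaping), plus a ufunc. The class repeats the shape logic from the question, computing `myShape` from `self` rather than `obj` so that lower-dimensional slices work; the sample shapes are arbitrary.

```python
import numpy as np

class myArray(np.ndarray):
    def __new__(cls, *args, **kwargs):
        # Build a plain ndarray, then view-cast it to the subclass
        return np.array(*args, **kwargs).view(cls)

    def __array_finalize__(self, obj):
        # Runs for view casting and new-from-template (slicing, reshape,
        # ufunc outputs); derive myShape from self, not from obj.
        if self.ndim == 4:
            self.myShape = (self.shape[0] * self.shape[1], self.shape[2] * self.shape[3])
        else:
            self.myShape = self.shape

    def myStuff(self):
        return self.reshape(self.myShape)

a = myArray(np.zeros([2, 3, 4, 5]))
print(a.myShape)                     # (6, 20)
print(a.myStuff().shape)             # (6, 20)

s = a[0]                             # slicing: new-from-template, 3-D result
print(type(s).__name__, s.myShape)   # myArray (3, 4, 5)

r = a + 1                            # ufunc output is also a myArray
print(type(r).__name__, r.myShape)   # myArray (6, 20)

v = np.arange(10).view(myArray)      # view casting from a plain ndarray
print(v.myShape)                     # (10,)
```

Deriving `myShape` from `self` means the attribute always describes the array it is attached to, even when NumPy produces the instance indirectly.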