
Select rows from array that are greater than template

I have a 2D NumPy array of floats, a, with shape (10^6, 3). I want to find the rows in which every element is greater than the corresponding element of np.array([25.0, 25.0, 25.0]), and then output those rows. My code is as follows.

# Create an empty array
a_cut = np.empty(shape=(0, 3), dtype=float)

minimum = np.array([25.0, 25.0, 25.0])

for i in range(len(a)):
    if a[i,:].all() > minimum.all():
        a_cut = np.append(a_cut, a[i,:], axis=0)

However, the code is inefficient: after a few hours it still has not produced a result. Is there a way to speed up this loop?

Stephen Wong asked Sep 07 '25 23:09


1 Answer

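np.append allocates a new array and copies all of the existing data into it every time you call it, so appending row by row inside a loop does a quadratic amount of copying. It is basically the same as np.concatenate: use it very sparingly. The goal is to perform the entire operation in bulk.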
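To make the copying behaviour concrete, here is a minimal sketch of a single np.append call. The row values are made up for illustration, and the row is written as a 2D array of shape (1, 3) because np.append with axis=0 needs both inputs to have the same number of dimensions.

```python
import numpy as np

# Empty (0, 3) array, as in the question.
a_cut = np.empty(shape=(0, 3), dtype=float)

# Hypothetical row to append, shaped (1, 3) so it matches a_cut's dimensions.
row = np.array([[30.0, 40.0, 50.0]])

# np.append does not grow a_cut in place: it returns a brand-new array.
grown = np.append(a_cut, row, axis=0)

print(a_cut.shape, grown.shape)          # (0, 3) (1, 3)
print(np.shares_memory(a_cut, grown))    # False
```

Since each call copies everything collected so far, repeating it once per row of a million-row array is where the time goes.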

You can construct a mask:

mask = (a > minimum).all(axis=1)

Then select:

a_cut = a[mask, :]
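As a small worked example of the mask and the selection, assuming a four-row array with made-up values in place of the real (10^6, 3) data:

```python
import numpy as np

# Illustrative data only; the real array would have about a million rows.
a = np.array([
    [30.0, 40.0, 26.0],   # every value above 25 -> kept
    [30.0, 10.0, 50.0],   # one value too small  -> dropped
    [25.0, 60.0, 70.0],   # 25.0 is not > 25.0   -> dropped
    [99.0, 25.5, 25.1],   # every value above 25 -> kept
])
minimum = np.array([25.0, 25.0, 25.0])

# Broadcasting compares every row of a against minimum element by element;
# all(axis=1) then reduces each row of booleans to a single True/False.
mask = (a > minimum).all(axis=1)

# Boolean indexing on the first axis keeps only the rows where mask is True.
a_cut = a[mask, :]

print(mask)
print(a_cut)
```

Note that the comparison is strict, so a row containing exactly 25.0 is excluded; use >= if such rows should be kept.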

You may get a slight improvement from using indices instead of a boolean mask:

a_cut = a[np.flatnonzero(mask), :]

Indexing with fewer indices than there are dimensions applies the indices to the leading dimensions, so the trailing slice can be dropped:

a_cut = a[mask]
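The three selection forms should return the same rows. A quick sanity check, assuming random test data from a seeded generator (the array size and value range here are arbitrary choices, not taken from the question):

```python
import numpy as np

# Arbitrary test data: 1000 rows of floats in [0, 100).
rng = np.random.default_rng(0)
a = rng.uniform(0.0, 100.0, size=(1000, 3))
minimum = np.array([25.0, 25.0, 25.0])

mask = (a > minimum).all(axis=1)

by_mask_full = a[mask, :]                  # boolean mask plus explicit slice
by_index = a[np.flatnonzero(mask), :]      # integer row indices
by_mask_short = a[mask]                    # boolean mask, trailing slice omitted

# All three should hold exactly the same rows in the same order.
same = (np.array_equal(by_mask_full, by_index)
        and np.array_equal(by_mask_full, by_mask_short))
print(same)
```

Whether the integer-index form is actually faster depends on the data and the NumPy version, so it is worth timing on the real array before relying on it.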

The one liner is therefore:

a_cut = a[(a > minimum).all(1)]

Mad Physicist answered Sep 09 '25 21:09