 

Python - how to speed up a for loop creating a numpy array from another numpy array calculation

First off, apologies for the vague title, I couldn't think of an appropriate name for this issue.

I have 3 numpy arrays in the following format:

N = np.array([[13, 14, 15], [2, 5, 7], [4, 6, 8], ...])  # several hundred thousand rows long

e1 = [1, 0, 0]

e2 = [0, 1, 0]

The idea is to create a fourth array, 'v', which shall have the same dimensions as 'N', but will be given values based on an if statement. Here is what I currently have which should better explain the issue:

import numpy as np

v = np.zeros([len(N), 3])

for i in range(0, len(N)):
    if((N*e1)[i,0] != 0):
        v[i] = np.cross(N[i],e1)
    else:
        v[i] = np.cross(N[i],e2)

This code does what I need, but it takes much longer than anticipated (> 5 mins). Is there any form of list comprehension or similar concept I could use to make it more efficient?


2 Answers

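You can use numpy.where to replace the if-else and vectorize the process with broadcasting. Since e1 = [1, 0, 0], the condition (N*e1)[i,0] != 0 is simply N[i,0] != 0, so it can be evaluated for all rows at once. Here is an option with numpy.where: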

import numpy as np
np.where(np.repeat(N[:,0] != 0, 3).reshape(len(N), 3), np.cross(N, e1), np.cross(N, e2))
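The repeat/reshape step only gives the mask the same shape as the cross products; broadcasting can do that directly by adding a length-1 axis. A minimal sketch, assuming N is an integer array of shape (len(N), 3) (the small N below is made-up illustration data):

```python
import numpy as np

# Small made-up input in the question's format (rows of 3 ints)
N = np.array([[13, 14, 15], [0, 5, 7], [4, 6, 8], [0, 0, 3]])
e1 = np.array([1, 0, 0])
e2 = np.array([0, 1, 0])

# Shape (len(N), 1): True where the first component is non-zero.
# The length-1 trailing axis broadcasts across the 3 columns of the cross
# products, so no repeat/reshape (and no hard-coded row count) is needed.
mask = (N[:, 0] != 0)[:, np.newaxis]

v = np.where(mask, np.cross(N, e1), np.cross(N, e2))
print(v.tolist())  # [[0, 15, -14], [-7, 0, 0], [0, 8, -6], [-3, 0, 0]]
```

Note that both cross products are still computed for every row; np.where only chooses between them element by element.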

Some benchmarks here:

1) Data setup:

N = np.array([np.random.randint(0,10,3) for i in range(1000)])
N

#array([[3, 5, 0],
#       [5, 0, 8],
#       [4, 6, 0],
#       ..., 
#       [9, 4, 2],
#       [6, 9, 3],
#       [2, 9, 2]])

e1 = np.array([1, 0, 0])
e2 = np.array([0, 1, 0])

2) Timing:

def forloop():
    v = np.zeros([len(N), 3])

    for i in range(0, len(N)):
        if((N*e1)[i,0] != 0):
            v[i] = np.cross(N[i],e1)
        else:
            v[i] = np.cross(N[i],e2)
    return v

def forloop2():
    v = np.zeros([len(N), 3])

    # Only calculate this one time.
    my_product = N*e1

    for i in range(0, len(N)):
        if my_product[i,0] != 0:
            v[i] = np.cross(N[i],e1)
        else:
            v[i] = np.cross(N[i],e2)
    return v

%timeit forloop()
10 loops, best of 3: 25.5 ms per loop

%timeit forloop2()
100 loops, best of 3: 12.7 ms per loop    

%timeit np.where(np.repeat(N[:,0] != 0, 3).reshape(1000,3), np.cross(N, e1), np.cross(N, e2))
10000 loops, best of 3: 71.9 µs per loop

3) Result checking for all methods:

v1 = forloop()   

v2 = np.where(np.repeat(N[:,0] != 0, 3).reshape(1000,3), np.cross(N, e1), np.cross(N, e2))

v3 = forloop2()

(v3 == v1).all()
# True

(v1 == v2).all()
# True
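%timeit is an IPython magic, so the timing lines above only run in IPython or Jupyter. In a plain Python script, the standard-library timeit module gives a rough equivalent. A sketch, assuming random data shaped like the benchmark above; the function names loop_version and where_version are illustrative, not taken from either answer, and absolute timings will differ by machine:

```python
import timeit

import numpy as np

# Random test data shaped like the benchmark above: 1000 rows of 3 ints
N = np.random.randint(0, 10, size=(1000, 3))
e1 = np.array([1, 0, 0])
e2 = np.array([0, 1, 0])


def loop_version():
    # Per-row loop with the condition computed once, outside the loop
    v = np.zeros((len(N), 3))
    nonzero = N[:, 0] != 0
    for i in range(len(N)):
        v[i] = np.cross(N[i], e1) if nonzero[i] else np.cross(N[i], e2)
    return v


def where_version():
    # Vectorized: both cross products for every row, then select per row
    mask = (N[:, 0] != 0)[:, np.newaxis]
    return np.where(mask, np.cross(N, e1), np.cross(N, e2))


for fn in (loop_version, where_version):
    # timeit.repeat returns the total time of each run; keep the fastest run
    runs = timeit.repeat(fn, number=5, repeat=3)
    print(f"{fn.__name__}: {min(runs) / 5 * 1e3:.3f} ms per call")
```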
Psidom answered Apr 01 '26 08:04


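I'm not certain what you're trying to do, but I know why this specific code is so slow for you. The worst offender is (N*e1). That's a simple calculation, and it runs pretty fast with numpy, but you're executing it inside the loop, len(N) times! Each evaluation builds a whole new array the size of N just to read a single element from it, so the total work grows with the square of len(N).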

By pulling that calculation outside of the loop, I am able to run your code with len(N) == 1000000 in less than 15 seconds on my machine. Example below:

v = np.zeros([len(N), 3])    

# Only calculate this one time.
my_product = N*e1

for i in range(0, len(N)):
    if my_product[i,0] != 0:
        v[i] = np.cross(N[i],e1)
    else:
        v[i] = np.cross(N[i],e2)
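Going one step further: only column 0 of my_product is ever read, and since e1[0] == 1 that column is just N[:, 0]. The test can therefore be precomputed as one boolean per row instead of a full product. A minimal sketch with made-up sample data (use_e1 is an illustrative name, not from the answer):

```python
import numpy as np

N = np.array([[13, 14, 15], [0, 5, 7], [4, 6, 8]])
e1 = np.array([1, 0, 0])
e2 = np.array([0, 1, 0])

# Column 0 of N*e1 is N[:, 0] * e1[0], which equals N[:, 0] when e1[0] == 1
assert np.array_equal((N * e1)[:, 0], N[:, 0])

# One boolean per row, computed once before the loop
use_e1 = N[:, 0] != 0
print(use_e1.tolist())  # [True, False, True]

v = np.zeros([len(N), 3])
for i in range(len(N)):
    if use_e1[i]:
        v[i] = np.cross(N[i], e1)
    else:
        v[i] = np.cross(N[i], e2)
```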

The other answer demonstrates how to avoid the for loop and if statements for a lot of extra speed at the cost of somewhat less readable code.
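A middle ground between the loop and np.where is boolean-mask indexing: select the rows for each branch and compute each cross product only on those rows. This is a different technique from either answer; the helper name cross_by_first_component is hypothetical, and the sketch assumes N has shape (len(N), 3):

```python
import numpy as np


def cross_by_first_component(N, e1, e2):
    """Hypothetical helper: row-wise cross(N[i], e1) if N[i, 0] != 0, else cross(N[i], e2)."""
    N = np.asarray(N)
    mask = N[:, 0] != 0                 # rows that take the e1 branch
    v = np.zeros((len(N), 3))           # float output, like np.zeros in the question
    v[mask] = np.cross(N[mask], e1)     # cross product only on the e1 rows
    v[~mask] = np.cross(N[~mask], e2)   # cross product only on the e2 rows
    return v


N = np.array([[13, 14, 15], [0, 5, 7], [4, 6, 8]])
e1 = np.array([1, 0, 0])
e2 = np.array([0, 1, 0])
print(cross_by_first_component(N, e1, e2).tolist())
# [[0.0, 15.0, -14.0], [-7.0, 0.0, 0.0], [0.0, 8.0, -6.0]]
```

Unlike np.where, this skips the unused cross product for each row; whether that is measurably faster depends on the data and the split between the two branches.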

rileymcdowell answered Apr 01 '26 08:04


