
Vectorize a python loop over a numpy array

Tags: python, numpy

I need to speed up this loop, which is very slow. I don't know how to vectorize it, because each output value depends on the previous output value. Any suggestions?

import numpy as np

sig = np.random.randn(44100)
alpha = .9887
beta = .999

out = np.zeros_like(sig)

for n in range(1, len(sig)):
    if np.abs(sig[n]) >= out[n-1]:
        out[n] = alpha * out[n-1] + (1 - alpha) * np.abs( sig[n] )
    else:
        out[n] = beta * out[n-1]
Christopher Brown asked Jun 13 '26 21:06



1 Answer

Numba's just-in-time compiler should handle the indexing overhead you're facing quite well, because it compiles the function to native code the first time it is called:

In [1]: %cpaste
Pasting code; enter '--' alone on the line to stop or use Ctrl-D.
:import numpy as np
:
:sig = np.random.randn(44100)
:alpha = .9887
:beta = .999
:
:def nonvectorized(sig):
:    out = np.zeros_like(sig)
:
:    for n in range(1, len(sig)):
:        if np.abs(sig[n]) >= out[n-1]:
:            out[n] = alpha * out[n-1] + (1 - alpha) * np.abs( sig[n] )
:        else:
:            out[n] = beta * out[n-1]
:    return out
:--

In [2]: nonvectorized(sig)
Out[2]: 
array([ 0.        ,  0.01862503,  0.04124917, ...,  1.2979579 ,
        1.304247  ,  1.30294275])

In [3]: %timeit nonvectorized(sig)
10 loops, best of 3: 80.2 ms per loop

In [4]: from numba import jit

In [5]: vectorized = jit(nonvectorized)

In [6]: np.allclose(vectorized(sig), nonvectorized(sig))
Out[6]: True

In [7]: %timeit vectorized(sig)
1000 loops, best of 3: 249 µs per loop
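The same idea can be written with Numba's decorator form. This is only a minimal sketch: the function name `envelope` and the choice to pass `alpha` and `beta` as arguments are my own assumptions, not part of the session above. Passing them explicitly matters because Numba treats global variables as compile-time constants, so changing a global `alpha` after compilation would not affect the compiled function. The `try`/`except` fallback is also an assumption, added so that the snippet still runs (at plain Python speed) when Numba is not installed.

```python
import numpy as np

try:
    from numba import njit  # Numba's nopython-mode JIT decorator
except ImportError:
    # Assumed fallback: without Numba, run the plain Python loop unchanged.
    def njit(func):
        return func


@njit
def envelope(sig, alpha, beta):
    """Same recurrence as the question's loop (hypothetical name)."""
    out = np.zeros_like(sig)
    for n in range(1, len(sig)):
        x = abs(sig[n])
        if x >= out[n - 1]:
            # Input is at or above the previous output: blend it in.
            out[n] = alpha * out[n - 1] + (1 - alpha) * x
        else:
            # Input is below the previous output: let the output decay.
            out[n] = beta * out[n - 1]
    return out


sig = np.random.randn(44100)
out = envelope(sig, 0.9887, 0.999)
```

Calling `envelope(sig, 0.9887, 0.999)` a second time with the same argument types reuses the already compiled code.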

EDIT: as suggested in a comment, here are benchmarks for jit itself. jit(nonvectorized) only creates a lightweight wrapper, so it is a cheap operation.

In [8]: %timeit jit(nonvectorized)
10000 loops, best of 3: 45.3 µs per loop

The function itself is compiled during its first execution (hence "just-in-time"). That takes a while, but probably less than you might expect:

In [9]: %timeit jit(nonvectorized)(sig)
10 loops, best of 3: 169 ms per loop
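Outside IPython, the one-time compilation cost can be separated from the steady-state cost by timing the first and second calls individually. The sketch below is an assumption on my part (the name `follow`, the explicit `alpha`/`beta` arguments and the no-Numba fallback are not from the session above), and the numbers it prints will depend on the machine and the Numba version.

```python
import time

import numpy as np

try:
    from numba import njit  # Numba's nopython-mode JIT decorator
except ImportError:
    # Assumed fallback: without Numba both calls run as plain Python.
    def njit(func):
        return func


@njit
def follow(sig, alpha, beta):
    """Same recurrence as nonvectorized, with explicit parameters."""
    out = np.zeros_like(sig)
    for n in range(1, len(sig)):
        x = abs(sig[n])
        if x >= out[n - 1]:
            out[n] = alpha * out[n - 1] + (1 - alpha) * x
        else:
            out[n] = beta * out[n - 1]
    return out


sig = np.random.randn(44100)

t0 = time.perf_counter()
first = follow(sig, 0.9887, 0.999)   # includes compilation when Numba is present
t1 = time.perf_counter()
second = follow(sig, 0.9887, 0.999)  # reuses the compiled code
t2 = time.perf_counter()

print(f"first call:  {(t1 - t0) * 1e3:.1f} ms")
print(f"second call: {(t2 - t1) * 1e3:.1f} ms")
```

With Numba installed, the first figure should be dominated by compilation and the second by the loop itself.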
immerrr answered Jun 15 '26 11:06
