As stated in the title, I want to remove the parts of a 1D array that consist of consecutive zeros and whose length is equal to or above a threshold.
I produced the solution shown in the following MRE:
import numpy as np
THRESHOLD = 4
a = np.array((1,1,0,1,0,0,0,0,1,1,0,0,0,1,0,0,0,0,0,1))
print("Input: " + str(a))
# Find the indices of the parts that meet threshold requirement
gaps_above_threshold_inds = np.where(np.diff(np.nonzero(a)[0]) - 1 >= THRESHOLD)[0]
# Delete these parts from array
for idx in gaps_above_threshold_inds:
    a = np.delete(a, list(range(np.nonzero(a)[0][idx] + 1, np.nonzero(a)[0][idx + 1])))
print("Output: " + str(a))
Output:
Input: [1 1 0 1 0 0 0 0 1 1 0 0 0 1 0 0 0 0 0 1]
Output: [1 1 0 1 1 1 0 0 0 1 1]
Is there a less complicated and more efficient way to do this on a numpy array?
Based on @mozway's comments, I'm editing my question to provide some more information.
Basically, the problem domain is:
My goal is to remove the runs of zeros whose length is at or above the threshold, as I have already said.
Regarding my first concern about efficient NumPy handling, @mathfux's solution is really great and basically what I was looking for. That's why I accepted it.
However, @Jérôme Richard's approach answers my second question and provides a very high-performance solution, which is really useful if the dataset is extremely big.
Thanks for your great answers!
np.delete creates a new array every time it is called, which is very inefficient. A faster solution is to store all the values to keep in a mask (a boolean array) and then filter the input array at once. However, this will still likely require a pure-Python loop if done only with NumPy. A simpler and faster solution is to use Numba (or Cython). Here is an implementation:
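import numpy as np
import numba as nb
@nb.njit('int_[:](int_[:], int_)')
def filterZeros(arr, threshold):
    n = len(arr)
    res = np.empty(n, dtype=arr.dtype)
    count = 0  # length of the current run of zeros
    j = 0      # next write position in res
    for i in range(n):
        if arr[i] == 0:
            count += 1
        else:
            # A run of zeros just ended: rewind over it if it is long enough
            if count >= threshold:
                j -= count
            count = 0
        res[j] = arr[i]
        j += 1
    # Handle a run of zeros at the very end of the array
    if n > 0 and arr[n-1] == 0 and count >= threshold:
        j -= count
    return res[0:j]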
a = np.array((1,1,0,1,0,0,0,0,1,1,0,0,0,1,0,0,0,0,0,1))
a = filterZeros(a, 4)
print("Output: " + str(a))
Here are the results with a random binary array containing 100_000 items on my machine:
Reference implementation: 5982 ms
Mozway's solution: 23.4 ms
This implementation: 0.11 ms
Thus, this solution is about 54,381 times faster than the initial solution and 212 times faster than Mozway's. The code can be ~30% faster still by working in-place (destroying the input array) and by telling Numba that the array is contiguous in memory (using ::1 instead of : in the signature).
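The in-place, contiguous variant described in the previous paragraph could look roughly like the sketch below. The name filter_zeros_inplace is hypothetical, and the plain-Python fallback used when Numba is not installed is an assumption added only so the sketch runs anywhere; no timings are claimed for it. It keeps the answer's int_[:] return type and only declares the input as C-contiguous (int_[::1]).

```python
import numpy as np

try:
    import numba as nb
    # Eager compilation; the input is declared C-contiguous (::1) as suggested above
    jit_inplace = nb.njit('int_[:](int_[::1], int_)')
except ImportError:
    # Assumption: plain-Python fallback so the sketch still runs (slowly) without Numba
    def jit_inplace(func):
        return func


@jit_inplace
def filter_zeros_inplace(arr, threshold):
    """Hypothetical in-place variant of filterZeros: overwrites arr, returns a view."""
    n = len(arr)
    count = 0  # length of the current run of zeros
    j = 0      # write position; never exceeds the read position i
    for i in range(n):
        if arr[i] == 0:
            count += 1
        else:
            if count >= threshold:
                j -= count  # rewind over a run of zeros that is too long
            count = 0
        arr[j] = arr[i]
        j += 1
    if count >= threshold:  # count > 0 here only if the array ends with zeros
        j -= count
    return arr[:j]


a = np.array((1,1,0,1,0,0,0,0,1,1,0,0,0,1,0,0,0,0,0,1), dtype=np.intp)
print(filter_zeros_inplace(a.copy(), 4))
```

Because the function overwrites its argument, pass a copy (as above) if the original array is still needed afterwards.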
It's also possible to find the differences between the indices of nonzero items, fix the ones that exceed the threshold, and reconstruct the sequence from them.
import numpy as np
THRESHOLD = 4  # same value as in the question
def numpy_fix(a):
    # STEP 1. Find indices of nonzero items: [0 1 3 8 9 13 19]
    idx = np.flatnonzero(a)
    # STEP 2. Find differences along these indices (with a leading zero prepended): [0 1 2 5 1 4 6]
    df = np.diff(idx, prepend=0)
    # STEP 3. Replace differences larger than THRESHOLD (gaps of THRESHOLD or more zeros) by 1: [0 1 2 1 1 4 1]
    df[df>THRESHOLD] = 1
    # STEP 4. Given differences of indices, reconstruct the indices themselves: [0 1 3 4 5 9 10]
    cs = np.cumsum(df)
    z = np.zeros(cs[-1]+1, dtype=int)  # create an array of zeros
    z[cs] = 1  # put ones at the reconstructed indices
    return z
>>> numpy_fix(a)
array([1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1])
(Note that this is correct only if a has no leading or trailing zeros.)
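One possible way around that restriction, assuming a binary 0/1 input like the question's, is to surround the array with sentinel ones, apply the same four steps, and strip the sentinels afterwards. numpy_fix_padded is a hypothetical name, and unlike numpy_fix it takes the threshold as a parameter. Like numpy_fix, it writes 1 for every nonzero item, so it does not preserve other nonzero values.

```python
import numpy as np


def numpy_fix_padded(a, threshold):
    """Hypothetical variant of numpy_fix that also handles leading/trailing zeros.

    Assumes a 1D binary (0/1) input, like the question's array.
    """
    # Sentinel ones on both sides put every zero run between two nonzero items
    b = np.concatenate(([1], np.asarray(a), [1]))
    idx = np.flatnonzero(b)
    df = np.diff(idx, prepend=0)
    # A gap of g zeros gives a difference of g + 1, so "g >= threshold" means "df > threshold"
    df[df > threshold] = 1
    cs = np.cumsum(df)
    z = np.zeros(cs[-1] + 1, dtype=int)
    z[cs] = 1
    # Drop the two sentinel ones again
    return z[1:-1]


print(numpy_fix_padded(np.array([0, 0, 0, 0, 0, 1, 0]), 4))
```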
%timeit numpy_fix(np.tile(a, (1, 50000)))
39.3 ms ± 865 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
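For comparison, the mask-then-filter idea from the Numba answer can also be sketched in pure NumPy by locating where each zero run starts and ends, instead of looping. This is an alternative technique, not either answerer's code; filter_zero_runs is a hypothetical name, and it assumes a 1D input. It handles leading and trailing zeros and keeps the original nonzero values.

```python
import numpy as np


def filter_zero_runs(a, threshold):
    """Drop every run of zeros whose length is >= threshold (hypothetical helper).

    Builds a boolean keep-mask without a Python loop, then filters once.
    """
    a = np.asarray(a)
    is_zero = (a == 0).astype(np.int64)
    # Pad with 0 so every zero run has a rising edge (+1) and a falling edge (-1)
    edges = np.diff(np.concatenate(([0], is_zero, [0])))
    starts = np.flatnonzero(edges == 1)   # first index of each zero run
    ends = np.flatnonzero(edges == -1)    # index just past each zero run
    long_runs = (ends - starts) >= threshold
    # +1 where a long run starts, -1 where it ends: the cumulative sum is 1 inside long runs
    marks = np.zeros(len(a) + 1, dtype=np.int64)
    marks[starts[long_runs]] += 1
    marks[ends[long_runs]] -= 1
    keep = np.cumsum(marks[:-1]) == 0
    return a[keep]


a = np.array((1,1,0,1,0,0,0,0,1,1,0,0,0,1,0,0,0,0,0,1))
print(filter_zero_runs(a, 4))
```

Runs of zeros are always separated by at least one nonzero item, so the start and end indices are all distinct and plain fancy-index += is safe here (no repeated indices).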