I have a matrix with 383 million rows, and I need to filter it based on a list of values (index_to_remove). This function is called several times per iteration. Is there a faster alternative to:
def remove_from_result(matrix, index_to_remove, inv=True):
    return matrix[np.isin(matrix, index_to_remove, invert=inv)]
This is a compiled version of the set-based list comprehension solution by @Matt Messersmith. It is basically a replacement for the slower np.isin method. I had some problems with the case where index_to_remove is a scalar value, and implemented a separate version for that case.
Code
import numpy as np
import numba as nb

@nb.njit(parallel=True)
def in1d_vec_nb(matrix, index_to_remove):
    # both matrix and index_to_remove have to be numpy arrays
    # if index_to_remove is a list with mixed dtypes, this function will fail
    out = np.empty(matrix.shape[0], dtype=nb.boolean)
    index_to_remove_set = set(index_to_remove)

    for i in nb.prange(matrix.shape[0]):
        if matrix[i] in index_to_remove_set:
            out[i] = False
        else:
            out[i] = True
    return out

@nb.njit(parallel=True)
def in1d_scal_nb(matrix, index_to_remove):
    # both matrix and index_to_remove have to be numpy arrays
    # if index_to_remove is a list with mixed dtypes, this function will fail
    out = np.empty(matrix.shape[0], dtype=nb.boolean)
    for i in nb.prange(matrix.shape[0]):
        if matrix[i] == index_to_remove:
            out[i] = False
        else:
            out[i] = True
    return out

def isin_nb(matrix_in, index_to_remove):
    # both matrix_in and index_to_remove have to be np.ndarrays,
    # even if index_to_remove is actually a single number
    shape = matrix_in.shape
    if index_to_remove.shape == ():
        res = in1d_scal_nb(matrix_in.reshape(-1), index_to_remove.take(0))
    else:
        res = in1d_vec_nb(matrix_in.reshape(-1), index_to_remove)

    return res.reshape(shape)
Example
data = np.array([[80, 1, 12], [160, 2, 12], [240, 3, 12], [80, 4, 11]])
test_elts = np.array((80))
data[isin_nb(data[:, 0], test_elts), :]
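For reference, this example should keep only the rows whose first column is not equal to 80 (the scalar path via in1d_scal_nb), i.e.:

array([[160,   2,  12],
       [240,   3,  12]])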
Timings
test_elts = np.arange(12345)
data = np.arange(1000 * 1000)

# The first call has a compilation overhead of about 300 ms,
# which is not included in the timings below.
# remove_from_result: 52 ms
# isin_nb:            1.59 ms
The runtime of your filtering function appears to be linear with respect to the size of your input matrix. Note that filtering with a list comprehension backed by a set is definitely linear, and your function runs roughly twice as fast as such a list comprehension filter on the same input on my machine. You can also see that if you increase the input size by a factor of X, runtime also increases by a factor of X:
In [84]: test_elts = np.arange(12345)
In [85]: test_elts_set = set(test_elts)
In [86]: %timeit remove_from_result(np.arange(1000*1000), test_elts)
10 loops, best of 3: 81.5 ms per loop
In [87]: %timeit [x for x in np.arange(1000*1000) if x not in test_elts_set]
1 loop, best of 3: 201 ms per loop
In [88]: %timeit remove_from_result(np.arange(1000*1000*2), test_elts)
10 loops, best of 3: 191 ms per loop
In [89]: %timeit [x for x in np.arange(1000*1000*2) if x not in test_elts_set]
1 loop, best of 3: 430 ms per loop
In [90]: %timeit remove_from_result(np.arange(1000*1000*10), test_elts)
1 loop, best of 3: 916 ms per loop
In [91]: %timeit [x for x in np.arange(1000*1000*10) if x not in test_elts_set]
1 loop, best of 3: 2.04 s per loop
In [92]: %timeit remove_from_result(np.arange(1000*1000*100), test_elts)
1 loop, best of 3: 12.4 s per loop
In [93]: %timeit [x for x in np.arange(1000*1000*100) if x not in test_elts_set]
1 loop, best of 3: 26.4 s per loop
For filtering unstructured data, that's as fast as you can go in terms of algorithmic complexity, since you have to touch each element at least once: you can't do better than linear time. A couple of things that might still help improve performance:
If you have access to something like pyspark (which you can get by using EMR on AWS if you're willing to pay a few bucks), you could do this much faster. The problem is embarrassingly parallel: you can split your input into K chunks, give each worker the items that need to be filtered plus one chunk, have each worker filter its chunk, and then collect/merge at the end. You could also try multiprocessing, as sketched below, but you'll have to be careful about memory (multiprocessing is similar to C's fork(): it spawns subprocesses, each of which clones your current memory space).
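As a rough illustration of the multiprocessing route (a minimal sketch, not your original code; the worker count and use of np.isin per chunk are my own assumptions), each worker filters its own chunk and the surviving pieces are concatenated at the end:

import numpy as np
from multiprocessing import Pool

def _filter_chunk(args):
    # worker: keep only elements of the chunk that are NOT in index_to_remove
    chunk, index_to_remove = args
    return chunk[np.isin(chunk, index_to_remove, invert=True)]

def remove_parallel(matrix, index_to_remove, n_workers=4):
    # split the flat array into roughly equal chunks, filter each in a
    # separate process, then stitch the surviving elements back together
    chunks = np.array_split(matrix, n_workers)
    with Pool(n_workers) as pool:
        parts = pool.map(_filter_chunk, [(c, index_to_remove) for c in chunks])
    return np.concatenate(parts)

if __name__ == "__main__":
    data = np.arange(1000 * 1000)
    test_elts = np.arange(12345)
    print(remove_parallel(data, test_elts).shape)  # (987655,)

Keep in mind that each chunk gets pickled and copied to a subprocess, so for a 383-million-row matrix the memory and serialization overhead may eat into the speedup.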
If your data has some structure (for example, it's sorted), you can be smarter about it and get sublinear algorithmic complexity for the search itself. For instance, if you need to remove a relatively small number of items from a large, sorted array, you could just do a binary search for each item to remove. This runs in O(m log n) time, where m is the number of items to remove and n is the size of the large array. If m is small compared to n, you're in business, since you'll be close to O(log n). There are even more clever ways to handle this particular situation, but I chose this one because it's easy to explain (see the sketch below). If you know anything about your data's distribution, you might be able to do better than linear time.
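A minimal sketch of that idea, assuming the large array is 1-D and sorted (the function name remove_sorted is mine, not from any library). np.searchsorted does one binary search per removed value, so the lookups are O(m log n); materializing the filtered copy is still O(n), but it avoids any per-element set membership tests:

import numpy as np

def remove_sorted(sorted_matrix, index_to_remove):
    # sorted_matrix must be a sorted 1-D array; index_to_remove should be small (m << n)
    # two binary searches per value give the [left, right) range of matching entries
    left = np.searchsorted(sorted_matrix, index_to_remove, side="left")
    right = np.searchsorted(sorted_matrix, index_to_remove, side="right")
    # build a boolean keep-mask and blank out every matched range
    keep = np.ones(sorted_matrix.shape[0], dtype=bool)
    for lo, hi in zip(left, right):
        keep[lo:hi] = False
    return sorted_matrix[keep]

data = np.arange(1000 * 1000)                 # already sorted
test_elts = np.arange(12345)
print(remove_sorted(data, test_elts).shape)   # (987655,)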
HTH.