numpy 'isin' performance improvement

I have a matrix with 383 million rows, and I need to filter it based on a list of values (index_to_remove). This function is called several times per iteration. Is there a faster alternative to:

import numpy as np

def remove_from_result(matrix, index_to_remove, inv=True):
    return matrix[np.isin(matrix, index_to_remove, invert=inv)]
asked by Ward


2 Answers

A faster implementation

This is a compiled version that uses a set, like the list-comprehension solution by @Matt Messersmith. It is basically a drop-in replacement for the slower np.isin call. I had some problems with the case where index_to_remove is a scalar value, so I implemented a separate version for that case.

Code

import numpy as np
import numba as nb

@nb.njit(parallel=True)
def in1d_vec_nb(matrix, index_to_remove):
    # matrix and index_to_remove have to be numpy arrays;
    # if index_to_remove is a list with different dtypes this
    # function will fail
    out = np.empty(matrix.shape[0], dtype=nb.boolean)
    index_to_remove_set = set(index_to_remove)

    for i in nb.prange(matrix.shape[0]):
        if matrix[i] in index_to_remove_set:
            out[i] = False
        else:
            out[i] = True

    return out

@nb.njit(parallel=True)
def in1d_scal_nb(matrix, index_to_remove):
    # matrix has to be a numpy array;
    # index_to_remove is a single scalar value here
    out = np.empty(matrix.shape[0], dtype=nb.boolean)
    for i in nb.prange(matrix.shape[0]):
        if matrix[i] == index_to_remove:
            out[i] = False
        else:
            out[i] = True

    return out


def isin_nb(matrix_in, index_to_remove):
    # both matrix_in and index_to_remove have to be np.ndarray,
    # even if index_to_remove is actually a single number
    shape = matrix_in.shape
    if index_to_remove.shape == ():
        res = in1d_scal_nb(matrix_in.reshape(-1), index_to_remove.take(0))
    else:
        res = in1d_vec_nb(matrix_in.reshape(-1), index_to_remove)

    return res.reshape(shape)

Example

data = np.array([[80, 1, 12], [160, 2, 12], [240, 3, 12], [80, 4, 11]])
test_elts = np.array(80)  # 0-d array, so the scalar code path is used

# keep only the rows whose first column is not in test_elts
data[isin_nb(data[:, 0], test_elts), :]
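If you want to keep the original call signature, a small wrapper along these lines should work (the name remove_from_result_nb is mine, it relies on the definitions above, and it assumes the invert=True case, since isin_nb already returns the "keep" mask):

def remove_from_result_nb(matrix, index_to_remove):
    # index_to_remove is converted to an np.ndarray, as isin_nb requires
    return matrix[isin_nb(matrix, np.asarray(index_to_remove))]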

Timings

test_elts = np.arange(12345)
data = np.arange(1000 * 1000)

# The first call has a compilation overhead of about 300 ms,
# which is not included in the timings.
# remove_from_result:     52 ms
# isin_nb:                1.59 ms
answered by max9111


The runtime of your filtering function appears to be linear w.r.t. the size of your input matrix. Filtering with a list comprehension backed by a set is definitely linear, and your function runs roughly twice as fast as such a list-comprehension filter on the same input on my machine. You can also see that if you increase the input size by a factor of X, the runtime also increases by a factor of X:

In [84]: test_elts = np.arange(12345)

In [85]: test_elts_set = set(test_elts)

In [86]: %timeit remove_from_result(np.arange(1000*1000), test_elts)
10 loops, best of 3: 81.5 ms per loop

In [87]: %timeit [x for x in np.arange(1000*1000) if x not in test_elts_set]
1 loop, best of 3: 201 ms per loop

In [88]: %timeit remove_from_result(np.arange(1000*1000*2), test_elts)
10 loops, best of 3: 191 ms per loop

In [89]: %timeit [x for x in np.arange(1000*1000*2) if x not in test_elts_set]
1 loop, best of 3: 430 ms per loop

In [90]: %timeit remove_from_result(np.arange(1000*1000*10), test_elts)
1 loop, best of 3: 916 ms per loop

In [91]: %timeit [x for x in np.arange(1000*1000*10) if x not in test_elts_set]
1 loop, best of 3: 2.04 s per loop

In [92]: %timeit remove_from_result(np.arange(1000*1000*100), test_elts)
1 loop, best of 3: 12.4 s per loop

In [93]: %timeit [x for x in np.arange(1000*1000*100) if x not in test_elts_set]
1 loop, best of 3: 26.4 s per loop

For filtering unstructured data, that's as fast as you can go in terms of algorithmic complexity, since you'll have to touch each element once. You can't do better than linear time. A couple of things that might help improve performance:

  1. If you have access to something like pyspark (which you can get by using EMR on AWS if you're willing to pay a few bucks), you could do this much faster. The problem is embarrassingly parallel: split the input into K chunks, give each worker the set of items to be filtered plus one chunk, have each worker filter its chunk, and then collect/merge at the end. You could also try multiprocessing, but you'll have to be careful about memory (multiprocessing is similar to C's fork(): it spawns subprocesses, and each one clones your current memory space). A rough sketch of the multiprocessing variant follows after this list.

  2. If your data has some structure (e.g. it's sorted), you can be smarter about it and get sublinear algorithmic complexity. For instance, if you need to remove a relatively small number of items from a large, sorted array, you can binary-search for each item to remove. This runs in O(m log n) time, where m is the number of items to remove and n is the size of the large array. If m is small relative to n, you're in business, since you're then close to O(log n). There are even more clever ways to handle this particular situation, but I chose this one because it's easy to explain (see the searchsorted sketch after this list). If you know anything about your data's distribution, you might be able to do better than linear time.
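A minimal sketch of the chunk-and-filter idea from point 1, using the standard-library multiprocessing module instead of pyspark (the helper names, chunk count and worker count are illustrative choices, not part of the original suggestion):

import numpy as np
from multiprocessing import Pool

def _filter_chunk(args):
    # Each worker applies the same np.isin filter to its own chunk.
    chunk, index_to_remove = args
    return chunk[np.isin(chunk, index_to_remove, invert=True)]

def remove_parallel(matrix, index_to_remove, n_workers=4, n_chunks=16):
    # Split the flat array into chunks, filter them in parallel, then merge.
    chunks = np.array_split(matrix, n_chunks)
    with Pool(n_workers) as pool:
        parts = pool.map(_filter_chunk, [(c, index_to_remove) for c in chunks])
    return np.concatenate(parts)

if __name__ == "__main__":
    data = np.arange(1000 * 1000)
    test_elts = np.arange(12345)
    print(remove_parallel(data, test_elts).shape)  # (987655,)

Note that the chunks and results are pickled between processes, so the copy overhead can eat into the parallel speedup; this mainly pays off when the filtering itself dominates.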
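And for point 2, a sketch of the binary-search idea on a sorted array using np.searchsorted (remove_sorted is my name for it; the searches are O(m log n), although the final boolean selection is still one cheap linear pass):

import numpy as np

def remove_sorted(sorted_matrix, index_to_remove):
    # Binary-search the left/right position of each value to remove
    # and mark the matching slice as dropped.
    keep = np.ones(sorted_matrix.shape[0], dtype=bool)
    for val in np.atleast_1d(index_to_remove):
        lo = np.searchsorted(sorted_matrix, val, side='left')
        hi = np.searchsorted(sorted_matrix, val, side='right')
        keep[lo:hi] = False
    return sorted_matrix[keep]

data = np.arange(1000 * 1000)      # already sorted
test_elts = np.arange(12345)
filtered = remove_sorted(data, test_elts)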

HTH.

answered by Matt Messersmith


