 

Identifying and removing duplicate columns/rows in sparse binary matrix in PyTorch

Let's suppose we have a binary matrix A with shape n x m. I want to identify duplicate rows (or columns) in the matrix, i.e. cases where another index along the same dimension has the same elements in the same positions.

It's very important not to convert this matrix into a dense representation, since the real matrices I'm using are quite large and difficult to handle in terms of memory.

Using PyTorch for the implementation:

# This is just a toy sparse binary matrix with n = 10 and m = 100
A = torch.randint(0, 2, (10, 100), dtype=torch.float32).to_sparse()

Intuitively, we can multiply this matrix by its transpose, producing a new m x m matrix that contains, at entry (i, j), the number of positions along dimension 0 where both column i and column j have a 1.

B = A.T @ A # In PyTorch, this operation will also produce a sparse representation
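To make this concrete, here is a minimal sketch (a hypothetical 3 x 3 toy matrix, not from the original post) showing that each entry of B counts the positions where two columns overlap:

```python
import torch

# Columns 0 and 2 are identical; column 1 differs.
A = torch.tensor([
    [1., 0., 1.],
    [1., 1., 1.],
    [0., 1., 0.],
]).to_sparse()

# Sparse Gram matrix; coalesce() so indices()/values() are well-defined.
B = (A.T @ A).coalesce()
dense_B = B.to_dense()

# B[0, 2] equals the column sums of columns 0 and 2 (both are 2),
# which is exactly the signal used to detect duplicates.
print(dense_B)
```

Note that for identical binary columns i and j, B[i, j] equals both column sums, which is the property the answer below exploits.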

At this point, I've tried to combine these values, comparing them with A.sum(0):

num_elements = A.sum(0)
duplicate_rows = torch.logical_and(
   num_elements[B.indices()[0]] == num_elements[B.indices()[1]],
   num_elements[B.indices()[0]] == B.values()
)

But this did not work!

I think the solution can be written using only operations on PyTorch sparse tensors (without Python loops and so on), which should also be a benefit in terms of performance.

daqh — asked Oct 25 '25
1 Answer

I've found a solution that relies only on the torch sparse representation and is very efficient in terms of both computation and memory consumption:

# A is the sparse (COO) binary matrix

B = (A.T @ A).coalesce() # or A @ A.T, depending on the dimension we deduplicate
num_elements = A.sum(0).to_dense() # number of 1s in each column

duplicates = torch.logical_and(
   B.indices()[0] < B.indices()[1], # consider only entries above the main diagonal
   torch.logical_and(
      B.values() == num_elements[B.indices()[0]], # overlap equals column i's count
      B.values() == num_elements[B.indices()[1]], # overlap equals column j's count
   )
)
duplicate_indices = B.indices()[1, duplicates].unique()
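As a sanity check, here is a self-contained sketch on a hypothetical 3 x 4 toy matrix (not from the original answer) where columns 0 and 2 are identical, so column 2 should be flagged:

```python
import torch

A = torch.tensor([
    [1., 0., 1., 0.],
    [1., 1., 1., 0.],
    [0., 1., 0., 1.],
]).to_sparse()

B = (A.T @ A).coalesce()
num_elements = A.sum(0).to_dense()

duplicates = torch.logical_and(
    B.indices()[0] < B.indices()[1],  # upper triangle: each pair checked once
    torch.logical_and(
        B.values() == num_elements[B.indices()[0]],
        B.values() == num_elements[B.indices()[1]],
    ),
)
# Keep the second index of each duplicate pair, so the first occurrence survives.
duplicate_indices = B.indices()[1, duplicates].unique()
print(duplicate_indices)  # tensor([2])
```

One caveat worth keeping in mind: all-zero columns produce no entries in B, so duplicate all-zero columns are not detected by this scheme.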

At this point we can use duplicate_indices (an index tensor, not a boolean mask) to remove the duplicate columns.

unique_indices = A.indices()[:,
   ~torch.isin(
      A.indices()[1],
      duplicate_indices
   )
]

unique_indices contains the COO indices of the filtered matrix; since A is binary, the values are all 1s, so the sparse tensor can be rebuilt from the indices alone.
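A minimal sketch of the rebuild step (the toy A and duplicate_indices below are assumptions for illustration, not from the original answer):

```python
import torch

# Columns 0 and 2 are identical; column 2 was flagged as the duplicate.
A = torch.tensor([
    [1., 0., 1.],
    [1., 1., 1.],
    [0., 1., 0.],
]).to_sparse().coalesce()
duplicate_indices = torch.tensor([2])

# Keep only entries whose column index is not a flagged duplicate.
keep = ~torch.isin(A.indices()[1], duplicate_indices)
unique_indices = A.indices()[:, keep]

# Rebuild a proper sparse tensor; the matrix is binary, so all values are 1.
A_filtered = torch.sparse_coo_tensor(
    unique_indices,
    torch.ones(unique_indices.shape[1]),
    size=A.shape,
)
print(A_filtered.to_dense())
```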


Additionally, we can normalize the result so that the remaining column indices form a contiguous range:

_, unique_indices[1] = torch.unique(unique_indices[1], return_inverse=True)
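To illustrate the remapping (with a hypothetical tensor of surviving column ids, not from the original answer): torch.unique with return_inverse=True maps each id to its position among the sorted unique values, yielding a contiguous 0..k-1 range:

```python
import torch

# Surviving column ids after filtering, with gaps left by removed columns.
col_ids = torch.tensor([0, 0, 1, 4, 4, 7])

# return_inverse gives, for each element, its index among the unique values.
_, compact = torch.unique(col_ids, return_inverse=True)
print(compact)  # tensor([0, 0, 1, 2, 2, 3])
```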
daqh — answered Oct 26 '25