How to randomly set a fixed number of elements in each row of a tensor in PyTorch

Tags:

pytorch

I was wondering whether there is a more efficient alternative to the code below that avoids the "for" loop in the 4th line (the list comprehension that builds sample).

import torch
n, d = 37700, 7842
k = 4
sample = torch.cat([torch.randperm(d)[:k] for _ in range(n)]).view(n, k)
mask = torch.zeros(n, d, dtype=torch.bool)
mask.scatter_(dim=1, index=sample, value=True)

Basically, what I am trying to do is create an n-by-d boolean mask tensor in which each row has exactly k True elements at random positions.
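A quick way to confirm that a mask meets this requirement is to count the True entries in each row. The sketch below wraps the loop-based construction from the question in a helper and adds a checker; the names loop_mask and has_k_per_row are illustrative, not from the original, and the sizes are kept small so it runs instantly.

```python
import torch


def loop_mask(n: int, d: int, k: int) -> torch.Tensor:
    """Loop-based construction from the question: k distinct random columns per row."""
    # One random permutation per row; its first k entries are distinct column indices.
    sample = torch.stack([torch.randperm(d)[:k] for _ in range(n)])
    mask = torch.zeros(n, d, dtype=torch.bool)
    # Write True at the sampled (row, column) positions.
    mask.scatter_(dim=1, index=sample, value=True)
    return mask


def has_k_per_row(mask: torch.Tensor, k: int) -> bool:
    """Return True if every row of a boolean mask contains exactly k True values."""
    return bool((mask.sum(dim=1) == k).all())


mask = loop_mask(n=5, d=10, k=4)
print(mask.shape)                # torch.Size([5, 10])
print(has_k_per_row(mask, k=4))  # True
```

The same checker can be applied to any loop-free replacement to make sure it still produces exactly k True values per row.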

asked Nov 26 '25 09:11 by sisaman



1 Answer

Here's a way to do this with no loop. Let's start with a random matrix whose elements are all drawn i.i.d., in this case uniformly on [0, 1). Then we find the k-th smallest value in each row and set every element less than or equal to it to True and the rest to False:

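rand_mat = torch.rand(n, d)
# topk with largest=False returns the k smallest values per row in ascending order,
# so the last column is each row's k-th smallest value (kept 2-D for broadcasting).
k_th_quant = torch.topk(rand_mat, k, largest=False).values[:, -1:]
mask = rand_mat <= k_th_quant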
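One caveat with comparing values: if another element in a row ties with that row's k-th smallest value, the <= comparison marks it as well and the row ends up with more than k True entries. Floating-point ties from torch.rand are unlikely but not impossible. A variant of the same idea (my own adaptation, not part of the original answer) scatters the indices returned by torch.topk instead of thresholding on values, so each row always gets exactly k True entries, because topk returns k distinct positions. The function name topk_index_mask is hypothetical.

```python
import torch


def topk_index_mask(n: int, d: int, k: int) -> torch.Tensor:
    """Loop-free n x d boolean mask with exactly k True values per row.

    Adapted from the answer: rather than comparing against the k-th smallest
    value, scatter the k positions that topk selects, so ties cannot add extras.
    """
    rand_mat = torch.rand(n, d)
    # Column indices of the k smallest entries in each row (distinct within a row).
    idx = torch.topk(rand_mat, k, dim=1, largest=False).indices
    mask = torch.zeros(n, d, dtype=torch.bool)
    mask.scatter_(dim=1, index=idx, value=True)
    return mask


mask = topk_index_mask(n=5, d=10, k=4)
print(mask.sum(dim=1))  # tensor([4, 4, 4, 4, 4])
```

Since the positions of the k smallest of d i.i.d. uniform draws form a uniformly random k-subset of the columns, this samples the same distribution as the question's randperm loop.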

No loop needed :) On my CPU this ran about 2.16x faster than the code you attached.
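The speedup will depend on hardware and tensor sizes, so it is worth timing on your own machine. The sketch below is a rough wall-clock comparison using time.perf_counter; the function names, the best_time helper, and the default sizes are hypothetical choices of mine. Also note that at the question's sizes (n = 37700, d = 7842) the answer's rand_mat alone takes about 1.18 GB in float32 (37700 × 7842 × 4 bytes), so memory can matter as much as speed; the example uses smaller sizes.

```python
import time

import torch


def loop_version(n: int, d: int, k: int) -> torch.Tensor:
    # Question's approach: one randperm per row.
    sample = torch.cat([torch.randperm(d)[:k] for _ in range(n)]).view(n, k)
    mask = torch.zeros(n, d, dtype=torch.bool)
    mask.scatter_(dim=1, index=sample, value=True)
    return mask


def quantile_version(n: int, d: int, k: int) -> torch.Tensor:
    # Answer's approach: threshold each row at its k-th smallest value.
    rand_mat = torch.rand(n, d)
    k_th_quant = torch.topk(rand_mat, k, largest=False).values[:, -1:]
    return rand_mat <= k_th_quant


def best_time(fn, *args, repeats: int = 3) -> float:
    """Best wall-clock time in seconds over `repeats` calls (hypothetical helper)."""
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        times.append(time.perf_counter() - start)
    return min(times)


if __name__ == "__main__":
    n, d, k = 2000, 1000, 4  # much smaller than the question's sizes
    t_loop = best_time(loop_version, n, d, k)
    t_vec = best_time(quantile_version, n, d, k)
    print(f"loop: {t_loop:.4f}s  loop-free: {t_vec:.4f}s  ratio: {t_loop / t_vec:.2f}x")
```

Taking the minimum over a few repeats reduces noise from one-off effects such as first-call allocation; for more careful measurements, torch.utils.benchmark is an option.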

answered Nov 28 '25 04:11 by Gil Pinsky

