 

PyTorch: modify a tensor with a list of indices

Tags:

python

pytorch

Suppose I have a list of indices and wish to modify an existing tensor at those positions. Currently the only way I can do this is with a for loop, as follows. Is there a faster, more efficient way?

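import torch

torch.manual_seed(0)
a = torch.randn(5, 3)
# torch.tensor (lowercase) accepts a dtype argument; the torch.Tensor constructor does not
idx = torch.tensor([[1, 2], [3, 2]], dtype=torch.long)
for i, j in idx:
    a[i, j] = 1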

I initially assumed that gather or index_select would go some way toward answering this question, but from the documentation neither seems to be the answer.
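To see why index_select does not fit directly, note that it picks whole slices along one dimension rather than individual (row, column) pairs, while gather takes its row positions implicitly from the shape of the index tensor. A minimal sketch, reusing the 5x3 tensor and index pairs from the question (the names rows, elems and via_gather are just for illustration):

```python
import torch

torch.manual_seed(0)
a = torch.randn(5, 3)
idx = torch.tensor([[1, 2], [3, 2]], dtype=torch.long)

# index_select along dim 0 returns entire rows 1 and 3, not single elements
rows = torch.index_select(a, 0, idx[:, 0])
print(rows.shape)  # torch.Size([2, 3])

# What the loop actually touches is one element per index pair
elems = a[idx[:, 0], idx[:, 1]]
print(elems.shape)  # torch.Size([2])

# gather can be made to work, but only after selecting the rows first:
# row k of `rows` is a[idx[k, 0]], and gather then picks column idx[k, 1] from it
via_gather = rows.gather(1, idx[:, 1:2]).squeeze(1)
```

So gather and index_select can be chained to read the values, but neither one alone maps a list of coordinate pairs to elements.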

In my particular case, a is a 5-dimensional tensor and idx is an N x 5 tensor, so I'd expect the output (after subscripting with something like a[idx]) to be a tensor of shape (N,).
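For the 5-dimensional case, one approach (a sketch; the shape of a and N = 4 are made up for illustration) is to split idx into its five columns and pass them as a tuple, so that advanced indexing pairs the columns up element-wise and returns one value per row of idx:

```python
import torch

torch.manual_seed(0)
a = torch.randn(2, 3, 4, 5, 6)  # a 5-dimensional tensor (shape chosen arbitrarily)
N = 4
# Build N random index rows, one valid coordinate per dimension -> shape (N, 5)
idx = torch.stack([torch.randint(0, s, (N,)) for s in a.shape], dim=1)

# idx.unbind(1) yields a tuple of 5 tensors of shape (N,), one per dimension of a
out = a[idx.unbind(1)]
print(out.shape)  # torch.Size([4])

# Same values as indexing one coordinate row at a time
expected = torch.stack([a[tuple(row.tolist())] for row in idx])
assert torch.equal(out, expected)
```

Note that a plain a[idx] with a LongTensor would index only the first dimension, which is why the columns have to be separated first.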

Answer

Thanks to @Shai below, the answer I was seeking was a[idx.t().chunk(chunks=2, dim=0)], taken from this SO answer.
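The chunk-based expression works because idx.t() has one row per dimension, and chunk splits it into a tuple of per-dimension index tensors. Two details are worth checking: chunks has to equal the number of columns of idx (2 in the 5x3 example, 5 in the 5-dimensional case), and each piece keeps a leading dimension of size 1, so the result has shape (1, N) rather than (N,). A small check on the 5x3 example, with tuple(idx.t()) shown as an equivalent form that returns shape (N,):

```python
import torch

torch.manual_seed(0)
a = torch.randn(5, 3)
idx = torch.tensor([[1, 2], [3, 2]], dtype=torch.long)

# idx.t() has shape (2, N); chunking along dim 0 gives a tuple of two (1, N) tensors
parts = idx.t().chunk(chunks=idx.shape[1], dim=0)
picked = a[parts]
print(picked.shape)  # torch.Size([1, 2])

# Iterating over idx.t() yields one (N,) tensor per dimension instead
flat = a[tuple(idx.t())]
assert torch.equal(picked.squeeze(0), flat)
```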

sachinruk asked Dec 31 '25 22:12



1 Answer

It's quite simple:

a[idx[:,0], idx[:,1]] = 1
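To confirm the one-liner does the same thing as the loop in the question, here is a minimal check (the names looped and vectorized are illustrative). The two index tensors idx[:, 0] and idx[:, 1] are paired element-wise, so each (row, column) pair is written once:

```python
import torch

torch.manual_seed(0)
a = torch.randn(5, 3)
idx = torch.tensor([[1, 2], [3, 2]], dtype=torch.long)

# Loop version from the question, applied to a copy
looped = a.clone()
for i, j in idx:
    looped[i, j] = 1

# Vectorized version from the answer, applied to another copy
vectorized = a.clone()
vectorized[idx[:, 0], idx[:, 1]] = 1

assert torch.equal(looped, vectorized)
```

One caveat: if idx contains the same position more than once with different values to write, PyTorch does not guarantee which write wins. With a constant value such as 1 this does not matter.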

You can find a more general solution in this thread.
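A sketch of one general form for any number of dimensions (not necessarily what the linked thread describes; the shapes here are hypothetical): unbind idx into one index tensor per dimension and assign through that tuple, or call Tensor.index_put_ with the same tuple, whose accumulate=True option sums values at repeated positions instead of overwriting them:

```python
import torch

a = torch.zeros(2, 3, 4, 5, 6)
# Three index rows, one coordinate per dimension of a; the last row repeats the first
idx = torch.tensor([[0, 1, 2, 3, 4],
                    [1, 2, 3, 4, 5],
                    [0, 1, 2, 3, 4]])

coords = idx.unbind(1)  # tuple of 5 tensors, each of shape (N,)

# Overwrite: every listed position becomes 1
a[coords] = 1

# Accumulate: the repeated position receives 1 + 1
b = torch.zeros_like(a)
b.index_put_(coords, torch.ones(idx.shape[0]), accumulate=True)

print(a[0, 1, 2, 3, 4].item(), b[0, 1, 2, 3, 4].item())  # 1.0 2.0
```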

Shai answered Jan 03 '26 11:01
