I am trying to update very specific indices of a multidimensional tensor in PyTorch, and I am not sure how to access the correct indices. I can do this in a very straightforward way in NumPy:
import numpy as np
#set up the array containing the data
data = 100*np.ones((10,10,2))
data[5:,:,:] = 0
#select the data points that I want to update
idxs = np.nonzero(data.sum(2))
#generate the updates that I am going to do
updates = np.random.randint(5,size=(idxs[0].shape[0],2))
#update the data
data[idxs[0],idxs[1],:] = updates
I need to implement this in PyTorch, but I am not sure how. It seems like I need the scatter function, but that only works along a single dimension instead of the multiple dimensions I need. How can I do this?
These operations work exactly the same way in PyTorch, except for torch.nonzero, which by default returns a tensor of size [z, n] (where z is the number of non-zero elements and n the number of dimensions) instead of a tuple of n tensors of size [z] (as NumPy does). That behaviour can be changed by setting as_tuple=True.
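To make the difference between the two return formats concrete, here is a small sketch on a toy 2-D tensor (the tensor t is just an illustrative example, not part of the question). The tuple form is what you can plug straight into an indexing expression, exactly like the result of np.nonzero.

```python
import torch

# Toy 2-D tensor with three non-zero entries at (0, 1), (1, 0) and (2, 1)
t = torch.tensor([[0, 3],
                  [5, 0],
                  [0, 7]])

# Default: one row per non-zero element, one column per dimension -> shape [z, n]
coords = torch.nonzero(t)            # holds [[0, 1], [1, 0], [2, 1]]

# as_tuple=True: n tensors of shape [z], like np.nonzero
rows, cols = torch.nonzero(t, as_tuple=True)   # rows [0, 1, 2], cols [1, 0, 1]

# The tuple form can be used directly for advanced indexing
values = t[rows, cols]               # holds [3, 5, 7]
```

The columns of coords carry the same information as the tuple, so coords.unbind(1) would give you the tuple form as well if you already have the [z, n] tensor.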
Other than that, you can translate the code directly to PyTorch, but you need to make sure that the types match, because assigning through index tensors does not convert types: you cannot assign a tensor of type torch.long (the default of torch.randint) to a tensor of type torch.float (the default of torch.ones). In this case, data is probably meant to have type torch.long:
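If you would rather keep data as a float tensor, a minimal sketch of the alternative is to cast the updates to the destination dtype before assigning. The try/except is only there to show the mismatch; in the PyTorch versions I have used, the uncast assignment raises a RuntimeError complaining that the source and destination dtypes of the index put do not match, but treat that exact message as an assumption.

```python
import torch

data = 100 * torch.ones((10, 10, 2))            # float32 by default
data[5:, :, :] = 0
idxs = torch.nonzero(data.sum(2), as_tuple=True)
updates = torch.randint(5, size=(idxs[0].shape[0], 2))   # int64 by default

# Assigning int64 values into a float32 tensor through index tensors is
# expected to fail with a dtype-mismatch RuntimeError (see lead-in).
try:
    data[idxs[0], idxs[1], :] = updates
except RuntimeError as err:
    print("dtype mismatch:", err)

# Casting the source to the destination dtype makes the assignment work
data[idxs[0], idxs[1], :] = updates.to(data.dtype)
```

Which way to go depends on what data represents: if it only ever holds integers, declaring it as torch.long (as in the answer) is the simpler choice.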
import torch
#set up the array containing the data
data = 100*torch.ones((10,10,2), dtype=torch.long)
data[5:,:,:] = 0
#select the data points that I want to update
idxs = torch.nonzero(data.sum(2), as_tuple=True)
#generate the updates that I am going to do
updates = torch.randint(5,size=(idxs[0].shape[0],2))
#update the data
data[idxs[0],idxs[1],:] = updates
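As an aside, PyTorch (like NumPy) also accepts a boolean mask over the leading dimensions, which skips the explicit nonzero call. This is a different technique from the one above, shown here as a sketch: data[mask] selects the [z, 2] slice in the same row-major order that nonzero produces, so both variants should write the updates to the same places. The names mask, via_idx and via_mask are just illustrative.

```python
import torch

data = 100 * torch.ones((10, 10, 2), dtype=torch.long)
data[5:, :, :] = 0

# Boolean mask over the first two dimensions: True where the pair is non-zero
mask = data.sum(2) != 0                            # shape [10, 10]
updates = torch.randint(5, size=(int(mask.sum()), 2))

# Variant 1: index tensors from nonzero, as in the answer above
via_idx = data.clone()
rows, cols = torch.nonzero(mask, as_tuple=True)
via_idx[rows, cols, :] = updates

# Variant 2: index with the mask directly; via_mask[mask] has shape [z, 2]
via_mask = data.clone()
via_mask[mask] = updates

print(torch.equal(via_idx, via_mask))              # True
```

Both variants use the same updates tensor, so the comparison checks only that the two indexing styles address the same elements in the same order.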