I have a PyTorch tensor X of size m x n and a list of nonnegative integers num_repeats of length n (assume sum(num_repeats) > 0). Inside a forward() method, I want to create a tensor X_dup of size m x sum(num_repeats) where column i of X is repeated num_repeats[i] times. The tensor X_dup is used downstream in the forward() method, so the gradient needs to be backpropagated correctly.
All solutions I could come up with require in-place operations (creating a new tensor and populating it by iterating over num_repeats), but if I understand correctly this won't preserve the gradient (please correct me if I'm wrong, I'm new to the whole PyTorch thing).
Provided you're using PyTorch >= 1.1.0, you can use torch.repeat_interleave.
repeat_tensor = torch.tensor(num_repeats).to(X.device, torch.int64)
X_dup = torch.repeat_interleave(X, repeat_tensor, dim=1)
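To check that gradients flow back to X, here is a minimal self-contained sketch. The values of X and num_repeats are made up for illustration, and the sum() loss is only a stand-in for whatever the rest of forward() computes.

```python
import torch

# Hypothetical example data: a 2 x 3 tensor and one repeat count per column.
X = torch.tensor([[0.0, 1.0, 2.0],
                  [3.0, 4.0, 5.0]], requires_grad=True)
num_repeats = [2, 0, 3]  # column 1 is dropped because its count is 0

# Build the repeat counts on the same device as X, as int64.
repeat_tensor = torch.tensor(num_repeats, dtype=torch.int64, device=X.device)

# Repeat column i of X num_repeats[i] times along dim=1.
X_dup = torch.repeat_interleave(X, repeat_tensor, dim=1)

# Stand-in scalar loss; backward() populates X.grad.
X_dup.sum().backward()

print(X_dup.shape)  # torch.Size([2, 5])
print(X.grad)       # each entry equals the repeat count of its column
```

With this loss, each entry of X.grad equals the number of times its column was copied, so a column with a repeat count of 0 gets a zero gradient.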