I have a PyTorch tensor X of size m x n and a list of nonnegative integers num_repeats of length n (assume sum(num_repeats) > 0). Inside a forward() method, I want to create a tensor X_dup of size m x sum(num_repeats) in which column i of X is repeated num_repeats[i] times. X_dup is used downstream in the forward() method, so the gradient needs to be backpropagated through it correctly.
All solutions I could come up with require in-place operations (creating a new tensor and populating it by iterating over num_repeats), but if I understand correctly this won't preserve the gradient (please correct me if I'm wrong; I'm new to the whole PyTorch thing).
Provided you're using PyTorch >= 1.1.0, you can use torch.repeat_interleave, which is differentiable:
repeat_tensor = torch.tensor(num_repeats, dtype=torch.int64, device=X.device)
X_dup = torch.repeat_interleave(X, repeat_tensor, dim=1)
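A minimal self-contained sketch (with hypothetical example values for X and num_repeats) showing the shape and that gradients flow back, with each column's gradient summed over its copies:

```python
import torch

# Hypothetical example: m=3, n=4
X = torch.randn(3, 4, requires_grad=True)
num_repeats = [2, 0, 1, 3]  # a zero simply drops that column

repeat_tensor = torch.tensor(num_repeats, dtype=torch.int64, device=X.device)
X_dup = torch.repeat_interleave(X, repeat_tensor, dim=1)
print(X_dup.shape)  # torch.Size([3, 6]) == (m, sum(num_repeats))

# Backprop works: the gradient of column i is the sum over its copies,
# so with an all-ones upstream gradient, column i receives num_repeats[i].
X_dup.sum().backward()
print(X.grad[0])  # tensor([2., 0., 1., 3.])
```

Note that a column with num_repeats[i] == 0 simply does not appear in X_dup and receives zero gradient, which matches the "nonnegative integers" requirement in the question.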