Repeat specific columns of a tensor in PyTorch

Tags: python, pytorch

I have a PyTorch tensor X of size m x n and a list of nonnegative integers num_repeats of length n (assume sum(num_repeats) > 0). Inside a forward() method, I want to create a tensor X_dup of size m x sum(num_repeats) in which column i of X is repeated num_repeats[i] times. X_dup is used later in the same forward() method, so gradients must backpropagate through it correctly. Every solution I could come up with requires in-place operations (creating a new tensor and populating it while iterating over num_repeats), but if I understand correctly, this won't preserve the gradient (please correct me if I'm wrong; I'm new to PyTorch).
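For reference, here is a minimal sketch of the loop-based approach the question describes, using a hypothetical helper name repeat_columns_loop and assuming a floating-point input. Writing slices of a tensor that requires grad into a freshly allocated tensor is recorded by autograd (as a CopySlices node), so in this sketch the gradient does reach X; the Python loop is simply likely to be slower than a single vectorized call for large n.

```python
import torch
from typing import Sequence


def repeat_columns_loop(X: torch.Tensor, num_repeats: Sequence[int]) -> torch.Tensor:
    # Hypothetical helper: preallocate the output, then fill it column by column.
    X_dup = X.new_zeros((X.shape[0], sum(num_repeats)))
    col = 0
    for i, r in enumerate(num_repeats):
        for _ in range(r):
            # In-place slice assignment; autograd records it, so grads flow to X.
            X_dup[:, col] = X[:, i]
            col += 1
    return X_dup


X = torch.arange(6.0).reshape(2, 3).requires_grad_()
X_dup = repeat_columns_loop(X, [2, 0, 3])  # column 1 is dropped (repeated 0 times)
X_dup.sum().backward()
print(X_dup.shape)  # torch.Size([2, 5])
print(X.grad)       # each entry counts how often its column appears in X_dup
```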

H.Rappeport asked Aug 31 '25 20:08


1 Answer

Provided you're using PyTorch >= 1.1.0, you can use torch.repeat_interleave. It is differentiable, so gradients flow back to X through X_dup.

import torch
# The repeats tensor must be int64 and live on the same device as X.
repeat_tensor = torch.tensor(num_repeats, dtype=torch.int64, device=X.device)
X_dup = torch.repeat_interleave(X, repeat_tensor, dim=1)
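A minimal sketch of how the repeat_interleave call above might sit inside a forward() method, with a gradient check. The module name RepeatColumns and the buffer name repeats are hypothetical; registering the repeats as a buffer is one assumed way to keep them on the same device as the input when the module is moved with .to(device).

```python
import torch
from torch import nn


class RepeatColumns(nn.Module):
    """Hypothetical module: repeats column i of the input num_repeats[i] times."""

    def __init__(self, num_repeats):
        super().__init__()
        # Buffers follow the module across devices, so the repeats tensor
        # stays on the same device as the input after module.to(device).
        self.register_buffer("repeats", torch.tensor(num_repeats, dtype=torch.int64))

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # Out-of-place op: autograd tracks it, no manual copying needed.
        return torch.repeat_interleave(X, self.repeats, dim=1)


layer = RepeatColumns([2, 0, 3])
X = torch.arange(6.0).reshape(2, 3).requires_grad_()
X_dup = layer(X)
X_dup.sum().backward()
print(X_dup.shape)  # torch.Size([2, 5])
print(X.grad)       # each entry counts how often its column appears in X_dup
```

A zero in num_repeats simply drops that column from X_dup, and the corresponding entries of X.grad are zero.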
jodag answered Sep 04 '25 00:09