Let's say that I have a tensor
t = torch.tensor([1,2,3,4,5])
I want to split it using a same-sized tensor of indices that tells me, for each element, which split it should go into.
indices = torch.tensor([0,1,1,0,2])
So that the final result is
splits
[tensor([1,4]), tensor([2,3]), tensor([5])]
Is there a neat way to do this in PyTorch?
EDIT : In general there will be more than 2 or 3 splits.
One could do it using argsort in the general case:
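    import torch

    def mask_split(tensor, indices):
        # Sort positions so that elements with the same index become adjacent
        sorter = torch.argsort(indices)
        # Number of elements in each group, in ascending index order
        _, counts = torch.unique(indices, return_counts=True)
        return torch.split(tensor[sorter], counts.tolist())

    mask_split(t, indices)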
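One caveat worth checking: torch.argsort is not guaranteed to be stable by default, so elements inside a group may not keep their original order. Below is a minimal sketch of a variant that uses torch.sort with stable=True instead. The name stable_split is made up for illustration, and this assumes a PyTorch version where torch.sort accepts the stable argument.

```python
import torch

def stable_split(tensor, indices):
    # A stable sort keeps elements that share an index in their original order.
    sorted_indices, sorter = torch.sort(indices, stable=True)
    # sorted_indices is already sorted, so consecutive runs are the groups.
    _, counts = torch.unique_consecutive(sorted_indices, return_counts=True)
    return torch.split(tensor[sorter], counts.tolist())

t = torch.tensor([1, 2, 3, 4, 5])
indices = torch.tensor([0, 1, 1, 0, 2])
splits = stable_split(t, indices)
print(splits)
```

torch.split returns a tuple rather than a list, so wrap it in list(...) if you need a list.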
Though it might be better to use @flawr's answer if this is your real use case (a list comprehension might also be faster, as it does not require sorting), something like this:
    def mask_split(tensor, indices):
        unique = torch.unique(indices)
        return [tensor[indices == i] for i in unique]
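If you already know how many splits there should be, you could skip torch.unique and loop over the range directly; groups that never occur then come back as empty tensors instead of being dropped. This is only a sketch: split_into and num_splits are hypothetical names, and it assumes the indices are integers in the range 0 to num_splits - 1.

```python
import torch

def split_into(tensor, indices, num_splits):
    # One boolean mask per group; a group with no members yields an empty tensor.
    return [tensor[indices == i] for i in range(num_splits)]

t = torch.tensor([1, 2, 3, 4, 5])
indices = torch.tensor([0, 1, 1, 0, 2])
# Ask for 4 groups even though only labels 0, 1 and 2 occur.
splits = split_into(t, indices, num_splits=4)
print(splits)
```

Note that this makes one pass over indices per group, so it may get slow when the number of groups is large.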
That is indeed possible using logical indexing; you just have to make sure that the index "mask" is made of boolean values, so in the two-split case (indices 0 and 1)
    splits = t[indices > 0], t[indices < 1]
or alternatively you can first cast your indices tensor to a boolean dtype.
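To illustrate why the dtype matters, here is a small sketch for the two-split case. The indices below contain only 0 and 1, which is an assumption made for this example. An integer tensor used as an index picks elements by position, while a boolean tensor acts as a mask.

```python
import torch

t = torch.tensor([1, 2, 3, 4, 5])
indices = torch.tensor([0, 1, 1, 0, 1])  # two groups only, assumed for illustration

# Integer indexing gathers t[0], t[1], t[1], t[0], t[1]; it is not a mask.
gathered = t[indices]

# Casting to bool turns the same tensor into a mask; ~ inverts it.
mask = indices.bool()
selected, rest = t[mask], t[~mask]
print(gathered, selected, rest)
```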