
Split a torch tensor using a same-sized tensor of indices

Let's say that I have a tensor

t = torch.tensor([1,2,3,4,5])

I want to split it using a same-sized tensor of indices that tells me, for each element, which split it should go into.

indices = torch.tensor([0,1,1,0,2])

So that the final result is

splits
[tensor([1,4]), tensor([2,3]), tensor([5])]

Is there a neat way to do this in PyTorch?

EDIT: In general there will be more than 2 or 3 splits.

Undead asked Oct 17 '25 11:10



2 Answers

One could do it using argsort for the general case:

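def mask_split(tensor, indices):
    # Sort positions by group id so each group's elements become contiguous
    sorter = torch.argsort(indices)
    # Number of elements per group, in ascending order of group id
    _, counts = torch.unique(indices, return_counts=True)
    return torch.split(tensor[sorter], counts.tolist())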


mask_split(t, indices)
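
One caveat with the argsort approach: a default sort is not guaranteed to keep elements with equal indices in their original order, so the order inside each split may vary. A minimal sketch that requests a stable sort through `torch.sort(..., stable=True)` follows; the name `stable_mask_split` is made up for this illustration, and it assumes a PyTorch version that supports the `stable` keyword.

```python
import torch

def stable_mask_split(tensor, indices):
    # A stable sort keeps elements with equal group ids in their original order
    _, sorter = torch.sort(indices, stable=True)
    # Number of elements per group, in ascending order of group id
    _, counts = torch.unique(indices, return_counts=True)
    # torch.split returns a tuple; convert to a list to match the question
    return list(torch.split(tensor[sorter], counts.tolist()))

t = torch.tensor([1, 2, 3, 4, 5])
indices = torch.tensor([0, 1, 1, 0, 2])
splits = stable_mask_split(t, indices)
print(splits)  # [tensor([1, 4]), tensor([2, 3]), tensor([5])]
```

The splits come out in ascending order of group id, and ids that never occur in `indices` simply produce no split.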

Though it might be better to use @flawr's answer if this is your real use case. A list comprehension might also be faster, as it does not require sorting. Something like this:

def mask_split(tensor, indices):
    unique = torch.unique(indices)
    return [tensor[indices == i] for i in unique]
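
If it helps to keep track of which split belongs to which index value, the same masking idea can return a dictionary instead of a list. This is only a sketch of one possible variant; the name `mask_split_by_id` is hypothetical and not part of either answer.

```python
import torch

def mask_split_by_id(tensor, indices):
    # torch.unique returns the distinct group ids in ascending order
    unique = torch.unique(indices)
    # One boolean mask per group id; masking keeps the original element order
    return {int(i): tensor[indices == i] for i in unique}

t = torch.tensor([1, 2, 3, 4, 5])
indices = torch.tensor([0, 1, 1, 0, 2])
groups = mask_split_by_id(t, indices)
print(groups)  # {0: tensor([1, 4]), 1: tensor([2, 3]), 2: tensor([5])}
```

Each mask scans the whole `indices` tensor, so the cost grows with the number of distinct ids; for many groups the sort-based version may scale better.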
Szymon Maszke answered Oct 19 '25 23:10



That is indeed possible using logical indexing. You just have to make sure that the index "mask" is made of boolean values, so in your case

splits = t[indices > 0], t[indices < 1]

Alternatively, you can first cast your indices tensor to a boolean dtype.
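
A small sketch of the boolean cast, assuming only two groups so that the indices are either 0 or 1 (the `indices` values below are chosen for illustration and differ from the question's three-group example):

```python
import torch

t = torch.tensor([1, 2, 3, 4, 5])
# Hypothetical two-group indices (0 or 1 only), for illustration
indices = torch.tensor([0, 1, 1, 0, 1])

# .bool() maps nonzero values to True and zero to False
mask = indices.bool()
# (elements where the index is nonzero, elements where it is zero)
splits = t[mask], t[~mask]
print(splits)  # (tensor([2, 3, 5]), tensor([1, 4]))
```

With more than two groups the cast would merge every nonzero index into one split, so this form only fits the two-way case.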

flawr answered Oct 19 '25 23:10
