What is a smart way to do a batched gather?

Tags: python, pytorch

I have two tensors, A and B, with shapes (n, m, k) and (n, m) respectively. n is the batch size, m is the number of items in each batch, and k is the feature size.

Each element of B is an index less than m (specifically B = torch.randint(high=m, size=(n, m))).

I want to compute [A[i][B[i]] for i in range(n)] in a smarter way.

Is there a better way in PyTorch to do this without a for loop?

안규리 asked Oct 18 '25 10:10

1 Answer

You can use advanced (integer array) indexing with a broadcast batch index:

a[torch.arange(n)[:, None], b]

An example:

>>> import torch
>>> n, m, k = 3, 2, 5
>>> a = torch.arange(30).view(n, m, k)
>>> b = torch.randint(high=m, size=(n,m))  # random; one possible draw is shown below

# first indexer (of shape (n, 1))
>>> torch.arange(n)[:, None]

tensor([[0],
        [1],
        [2]])

# second indexer
>>> b

tensor([[1, 0],
        [0, 1],
        [1, 1]])

The indexers have shapes (3, 1) and (3, 2) respectively, so they are broadcast to (3, 2), which effectively gives

tensor([[0, 0],
        [1, 1],
        [2, 2]])

and

tensor([[1, 0],
        [0, 1],
        [1, 1]])

which reads: for the first batch (i = 0), take the (k,) vector at index 1 of a[0] and then the (k,) vector at index 0 of a[0]. That fills one (m, k) block of the output, and the same lookup is done for each of the n batches,

to get

>>> a[torch.arange(n)[:, None], b]

tensor([[[ 5,  6,  7,  8,  9],
         [ 0,  1,  2,  3,  4]],

        [[10, 11, 12, 13, 14],
         [15, 16, 17, 18, 19]],

        [[25, 26, 27, 28, 29],
         [25, 26, 27, 28, 29]]])
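
The broadcasting step described above can be made explicit. The following is a minimal sketch, assuming the same a and b as in the example (b is hard-coded here rather than drawn at random); the names rows, rows_b and cols_b are illustrative only:

```python
import torch

n, m, k = 3, 2, 5
a = torch.arange(n * m * k).view(n, m, k)
# Fixed indices matching the example above (instead of a random draw).
b = torch.tensor([[1, 0], [0, 1], [1, 1]])

rows = torch.arange(n)[:, None]  # batch index, shape (n, 1)
# Explicitly broadcast both indexers to the common shape (n, m).
rows_b, cols_b = torch.broadcast_tensors(rows, b)

# out[i, j] is the (k,) vector a[rows_b[i, j], cols_b[i, j]].
out = a[rows_b, cols_b]

# Indexing with the unexpanded (n, 1) batch index gives the same result,
# because PyTorch broadcasts the indexers implicitly.
assert out.shape == (n, m, k)
assert torch.equal(out, a[rows, b])
```

In practice the explicit broadcast is unnecessary; it is shown only to make the pairing of batch index and item index visible.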

Comparing with the list comprehension from the question:

>>> [a[i][b[i]] for i in range(n)]

[tensor([[5, 6, 7, 8, 9],
         [0, 1, 2, 3, 4]]),
 tensor([[10, 11, 12, 13, 14],
         [15, 16, 17, 18, 19]]),
 tensor([[25, 26, 27, 28, 29],
         [25, 26, 27, 28, 29]])]
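
As a self-contained check, the sketch below wraps the indexing expression in a helper and compares it against the loop from the question. It also shows torch.gather as one possible alternative, since the question title mentions gather. The helper name batched_select and the variables fast, loop and gathered are hypothetical names chosen for illustration, and the sizes are arbitrary:

```python
import torch


def batched_select(a, b):
    """Hypothetical helper: returns out with out[i, j] = a[i, b[i, j]].

    a has shape (n, m, k); b holds integer indices into dim 1 of a
    and has shape (n, m).
    """
    n = a.shape[0]
    rows = torch.arange(n, device=a.device)[:, None]  # shape (n, 1)
    return a[rows, b]


torch.manual_seed(0)
n, m, k = 4, 3, 5
a = torch.randn(n, m, k)
b = torch.randint(high=m, size=(n, m))

fast = batched_select(a, b)

# Reference: the list comprehension from the question, stacked into one tensor.
loop = torch.stack([a[i][b[i]] for i in range(n)])

# Alternative: torch.gather along dim 1. gather needs an index with the
# same number of dimensions as the input, so b is expanded over k.
idx = b[:, :, None].expand(-1, -1, k)  # shape (n, m, k)
gathered = torch.gather(a, 1, idx)

assert torch.equal(fast, loop)
assert torch.equal(fast, gathered)
```

The indexing form avoids materialising an (n, m, k) index tensor's worth of logic in user code, which is why it tends to read more simply than the gather form; which one is faster may depend on the shapes and device.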
Mustafa Aydın answered Oct 21 '25 00:10
