I have two tensors, A and B, with shapes (n, m, k) and (n, m) respectively. n is the batch size, m is the number of elements per batch, and k is the feature size.
Each element of B is an index less than m (specifically B = torch.randint(high=m, size=(n, m))).
I want to implement [A[i][B[i]] for i in range(n)] in a smarter way.
Is there a better way to do this in PyTorch without a for loop?
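For reference, here is the loop version spelled out as runnable code (the function name and the concrete sizes are just for illustration; it stacks the per-batch results into one (n, m, k) tensor):

import torch

def batched_row_select_loop(A, B):
    # Loop baseline: for each batch i, pick the rows of A[i] given by B[i],
    # then stack the (m, k) results into a single (n, m, k) tensor.
    return torch.stack([A[i][B[i]] for i in range(A.shape[0])])

n, m, k = 4, 6, 3
A = torch.randn(n, m, k)
B = torch.randint(high=m, size=(n, m))
print(batched_row_select_loop(A, B).shape)  # torch.Size([4, 6, 3])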
You can use
a[torch.arange(n)[:, None], b]
An example:
>>> n, m, k = 3, 2, 5
>>> a = torch.arange(30).view(n, m, k)
>>> b = torch.randint(high=m, size=(n,m))
# first indexer (of shape (n, 1))
>>> torch.arange(n)[:, None]
tensor([[0],
        [1],
        [2]])
# second indexer
>>> b
tensor([[1, 0],
        [0, 1],
        [1, 1]])
The indexers have shapes (3, 1) and (3, 2) respectively, so they are broadcast to (3, 2), effectively giving
tensor([[0, 0],
        [1, 1],
        [2, 2]])
and
tensor([[1, 0],
        [0, 1],
        [1, 1]])
which says: for the first batch, take the 1st (k,) row of a[0] and then the 0th (k,) row of a[0]. Each pair of index rows fills an (m, k) block of the output this way, repeated for each of the n batches,
to get
>>> a[torch.arange(n)[:, None], b]
tensor([[[ 5,  6,  7,  8,  9],
         [ 0,  1,  2,  3,  4]],
        [[10, 11, 12, 13, 14],
         [15, 16, 17, 18, 19]],
        [[25, 26, 27, 28, 29],
         [25, 26, 27, 28, 29]]])
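If you want to see the broadcasting explicitly, torch.broadcast_tensors expands both indexers to the common (n, m) shape; this is only a sanity check and not needed for the indexing itself:

>>> i0, i1 = torch.broadcast_tensors(torch.arange(n)[:, None], b)
>>> i0
tensor([[0, 0],
        [1, 1],
        [2, 2]])
>>> torch.equal(a[i0, i1], a[torch.arange(n)[:, None], b])
True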
Comparing with the list comprehension:
>>> [a[i][b[i]] for i in range(n)]
[tensor([[5, 6, 7, 8, 9],
         [0, 1, 2, 3, 4]]),
 tensor([[10, 11, 12, 13, 14],
         [15, 16, 17, 18, 19]]),
 tensor([[25, 26, 27, 28, 29],
         [25, 26, 27, 28, 29]])]
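A quick check confirms the two approaches give the same values. If you prefer, torch.gather also works by expanding b along the feature dimension; this is just a sketch of the equivalent call, not a benchmark:

>>> torch.equal(a[torch.arange(n)[:, None], b],
...             torch.stack([a[i][b[i]] for i in range(n)]))
True
>>> # gather along dim 1: out[i, j, :] = a[i, b[i, j], :]
>>> torch.equal(torch.gather(a, 1, b.unsqueeze(-1).expand(n, m, k)),
...             a[torch.arange(n)[:, None], b])
True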