 

Check if each element of a tensor is contained in a list

Say I have a tensor A and a container of values vals. Is there a clean way to return a Boolean tensor of the same shape as A, where each element indicates whether the corresponding element of A is contained in vals? E.g.:

import torch

A = torch.tensor([[1,2,3],
                  [4,5,6]])
vals = [1,5]
# Desired output
torch.tensor([[True,False,False],
              [False,True,False]])
iacob asked Oct 16 '25 at 16:10



2 Answers

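Using torch.isin (available since PyTorch 1.10) is the most convenient way. Its second argument is documented as a tensor or scalar, so convert the list first: torch.isin(A, torch.tensor(vals))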
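A minimal runnable sketch of the torch.isin approach on the question's example. Converting vals with dtype=A.dtype is a precaution assumed here, not something the answer requires; invert=True is the documented keyword for getting the complement mask:

```python
import torch

A = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])
vals = [1, 5]

# torch.isin expects a tensor (or scalar) for test_elements, so convert
# the list; matching A's dtype avoids mixed-dtype comparisons.
test = torch.tensor(vals, dtype=A.dtype)

# Bool tensor with the same shape as A: True where the element is in vals.
mask = torch.isin(A, test)
# mask equals torch.tensor([[True, False, False], [False, True, False]])

# invert=True returns the complement: True where the element is NOT in vals.
not_in = torch.isin(A, test, invert=True)
```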

huyph answered Oct 18 '25 at 17:10



You can also achieve this by summing one elementwise comparison per value in a generator expression:

sum(A == v for v in vals).bool()
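A self-contained sketch of the summing approach, iterating over vals as named in the question. The function names isin_sum and isin_broadcast are hypothetical; the second one swaps in a broadcasting technique that is not part of this answer. It avoids the Python-level loop at the cost of a temporary holding A.numel() * len(vals) booleans:

```python
import torch

def isin_sum(A: torch.Tensor, vals: list) -> torch.Tensor:
    # Each A == v is a bool mask; sum() starts from the int 0, so the total is
    # an integer count of matches per element. .bool() maps nonzero to True.
    # Caveat: with an empty vals, sum() returns the plain int 0, not a tensor.
    return sum(A == v for v in vals).bool()

def isin_broadcast(A: torch.Tensor, vals: list) -> torch.Tensor:
    # Alternative (not from the answer): A.unsqueeze(-1) has shape (*A.shape, 1)
    # and broadcasts against the 1-D tensor of vals, giving (*A.shape, len(vals)).
    # any(dim=-1) collapses the last axis back to A's shape.
    v = torch.tensor(vals, dtype=A.dtype)
    return (A.unsqueeze(-1) == v).any(dim=-1)

A = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])
vals = [1, 5]
print(isin_sum(A, vals))
print(isin_broadcast(A, vals))
```

Unlike isin_sum, the broadcasting variant also handles an empty vals, returning an all-False mask.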
iacob answered Oct 18 '25 at 16:10
