How to convert a one-hot vector to a label index and back in PyTorch?

How do I transform a vector of labels into a one-hot encoding and back in PyTorch?

I copied the solution here because I had to read through an entire forum discussion to find it, instead of getting a simple answer from a quick search.

Gulzar asked Dec 06 '25 17:12



2 Answers

From the PyTorch forums:

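import torch
import numpy as np


labels = torch.randint(0, 10, (10,))  # 10 random class indices in [0, 10)

# labels --> one-hot (num_classes is inferred as labels.max() + 1 when omitted)
one_hot = torch.nn.functional.one_hot(labels)
# one-hot --> labels: the position of the single 1 in each row
labels_again = torch.argmax(one_hot, dim=1)

# verify that the round trip recovers the original labels
np.testing.assert_array_equal(labels.numpy(), labels_again.numpy())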
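The same round trip also works for labels with extra dimensions (for example, a batch of per-pixel label maps), because one_hot appends the class dimension last and argmax can reduce over that last dimension. Below is a minimal sketch, assuming integer class labels; the helper names to_one_hot and from_one_hot are hypothetical and not part of PyTorch:

```python
import torch
import torch.nn.functional as F


def to_one_hot(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical helper: integer labels of any shape -> one-hot with a trailing class dim."""
    # F.one_hot expects an int64 tensor and appends a new last dimension of size num_classes.
    return F.one_hot(labels.long(), num_classes=num_classes)


def from_one_hot(one_hot: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: one-hot (or per-class scores) -> label indices."""
    # argmax over the last dimension returns the index of the 1 in each one-hot vector.
    return one_hot.argmax(dim=-1)


# A batch of 2 "label maps" of size 3x4 with classes 0..4.
labels = torch.randint(0, 5, (2, 3, 4))
one_hot = to_one_hot(labels, num_classes=5)
print(one_hot.shape)  # torch.Size([2, 3, 4, 5])

labels_again = from_one_hot(one_hot)
print(torch.equal(labels, labels_again))  # True
```

If the one-hot tensor needs to be floating point (for example, to combine it with predicted probabilities), call .float() on it; argmax behaves the same on float tensors.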
Gulzar answered Dec 08 '25 05:12



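I can't comment on the accepted answer, so I'll add this here: if your target does not contain every class (e.g., because you train in batches), pass the total number of classes via the num_classes argument. Otherwise one_hot infers the number of classes as one more than the largest label present, so the width of the encoding can change from batch to batch: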

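import torch

target = torch.tensor([0, 2, 3])  # example batch; classes 1, 4, 5 and 6 are absent
# labels --> one-hot with a fixed width of 7 classes
one_hot = torch.nn.functional.one_hot(target, num_classes=7)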
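To see why the argument matters: when num_classes is omitted (its default is -1), one_hot infers the width as the largest label plus one, so two batches from the same problem can produce encodings of different widths. A small sketch with two made-up batches from an assumed 7-class problem:

```python
import torch
import torch.nn.functional as F

# Two made-up batches from a 7-class problem; neither contains every class.
batch_a = torch.tensor([0, 2, 3])
batch_b = torch.tensor([1, 6, 6])

# Without num_classes, the width is inferred as max(label) + 1, so it varies per batch.
print(F.one_hot(batch_a).shape)  # torch.Size([3, 4])
print(F.one_hot(batch_b).shape)  # torch.Size([3, 7])

# With num_classes fixed, every batch gets the same width.
print(F.one_hot(batch_a, num_classes=7).shape)  # torch.Size([3, 7])
print(F.one_hot(batch_b, num_classes=7).shape)  # torch.Size([3, 7])

# Converting back with argmax is unaffected by the extra all-zero columns.
print(F.one_hot(batch_a, num_classes=7).argmax(dim=1))  # tensor([0, 2, 3])
```

A fixed width matters whenever the one-hot tensor is compared or combined with model outputs, which always have one column per class.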
swageta answered Dec 08 '25 06:12



