
What is the difference between src and tgt in nn.Transformer for PyTorch?

The docs say to create a transformer model like this:

import torch
from torch import nn

transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
src = torch.rand((10, 32, 512))  # (source length, batch size, d_model)
tgt = torch.rand((20, 32, 512))  # What is tgt??
out = transformer_model(src, tgt)

What is tgt meant to be? Should tgt be the same as src?

Dylan Kerler asked Sep 03 '25 05:09


1 Answer

The transformer consists of two components, the encoder and the decoder. src is the input to the encoder and tgt is the input to the decoder.

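For example, in a machine translation task that translates English sentences to French, src is the English token sequence and tgt is the French token sequence, so tgt is generally not the same as src. Note that nn.Transformer expects both to be already embedded, with d_model features per token (512 in the question's snippet), rather than raw token ids.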
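As a rough illustration of the translation case, here is a minimal sketch of how src and tgt might be prepared for one training step. The vocabulary sizes, sequence lengths, random token ids, and the names src_embed, tgt_embed, tgt_in and tgt_out are assumptions made up for this example. It also assumes the usual teacher-forcing setup, where the decoder is fed the target shifted by one position, and it leaves out positional encodings, which nn.Transformer does not add for you.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical sizes, chosen only to keep the example small.
SRC_VOCAB, TGT_VOCAB, D_MODEL = 1000, 1200, 64
S, T, N = 10, 7, 4  # source length, decoder input length, batch size

# nn.Transformer works on embedded sequences, so ids are embedded first.
src_embed = nn.Embedding(SRC_VOCAB, D_MODEL)
tgt_embed = nn.Embedding(TGT_VOCAB, D_MODEL)
model = nn.Transformer(
    d_model=D_MODEL, nhead=4,
    num_encoder_layers=2, num_decoder_layers=2,
    dim_feedforward=128,
)

# Random stand-ins for tokenized sentences, shape (length, batch).
src_ids = torch.randint(0, SRC_VOCAB, (S, N))      # e.g. English token ids
tgt_ids = torch.randint(0, TGT_VOCAB, (T + 1, N))  # e.g. French token ids

# Teacher forcing (an assumption of this sketch): the decoder reads
# tgt_ids[:-1] and is trained to predict tgt_ids[1:].
tgt_in = tgt_ids[:-1]
tgt_out = tgt_ids[1:]

# Causal mask so position i in the decoder cannot attend to positions > i.
tgt_mask = torch.triu(torch.full((T, T), float("-inf")), diagonal=1)

out = model(src_embed(src_ids), tgt_embed(tgt_in), tgt_mask=tgt_mask)
print(out.shape)  # (T, N, D_MODEL): one vector per decoder position
```

The output has the same length as tgt, not src. In a real model a final linear layer would map each output vector to target-vocabulary logits, which are then compared against tgt_out.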

emily answered Sep 05 '25 01:09