The docs say to create a transformer model like this:
```python
import torch
import torch.nn as nn

transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
src = torch.rand((10, 32, 512))
tgt = torch.rand((20, 32, 512))  # What is tgt??
out = transformer_model(src, tgt)
```
What is tgt meant to be? Should tgt be the same as src?
The Transformer consists of two components, the encoder and the decoder: src is the input to the encoder and tgt is the input to the decoder.
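To make the encoder/decoder split concrete, here is a minimal sketch showing that calling the model on (src, tgt) is the same as encoding src and then decoding tgt against the encoder output (the "memory"). It uses the `encoder` and `decoder` submodules of `nn.Transformer`; the small sizes (d_model=64, 2 layers each) are assumptions chosen only so it runs quickly, not values from the question. Note that src and tgt may have different lengths, and the output always has the shape of tgt.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Small illustrative sizes (assumed); the question used the default d_model=512.
model = nn.Transformer(d_model=64, nhead=4, num_encoder_layers=2,
                       num_decoder_layers=2, dim_feedforward=128)
model.eval()  # disable dropout so both computations below are deterministic

src = torch.rand(10, 32, 64)  # (source length S, batch N, d_model): encoder input
tgt = torch.rand(20, 32, 64)  # (target length T, batch N, d_model): decoder input

with torch.no_grad():
    out = model(src, tgt)
    # The same computation spelled out: encode src, then decode tgt against it.
    memory = model.encoder(src)              # (10, 32, 64)
    out_manual = model.decoder(tgt, memory)  # (20, 32, 64)

print(out.shape)  # torch.Size([20, 32, 64])
print(torch.allclose(out, out_manual, atol=1e-6))  # True
```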
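For example, in a machine translation task that translates English sentences into French, src is the English sentence and tgt is the French sentence. Both usually start out as token ids and are passed through an embedding layer first, because nn.Transformer itself expects tensors of shape (sequence length, batch size, d_model) rather than raw ids. So tgt is generally not the same as src.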
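A minimal sketch of the translation case, under stated assumptions: `Seq2SeqTransformer`, `SRC_VOCAB`, `TGT_VOCAB` and the sizes are hypothetical, random ids stand in for tokenized sentences, and positional encoding is omitted for brevity (a real model would add one). It shows the common teacher-forcing setup: the decoder input tgt is the French sequence without its last token, the labels are the same sequence without its first token, and a causal `tgt_mask` from `generate_square_subsequent_mask` stops each position from attending to later French tokens.

```python
import torch
import torch.nn as nn

# Hypothetical sizes, chosen small for illustration only.
SRC_VOCAB = 1000  # English vocabulary size (assumed)
TGT_VOCAB = 1200  # French vocabulary size (assumed)
D_MODEL = 64

class Seq2SeqTransformer(nn.Module):
    """Hypothetical wrapper: token ids -> embeddings -> nn.Transformer -> logits.
    Positional encoding is deliberately omitted to keep the sketch short."""
    def __init__(self):
        super().__init__()
        self.src_embed = nn.Embedding(SRC_VOCAB, D_MODEL)  # English ids -> vectors
        self.tgt_embed = nn.Embedding(TGT_VOCAB, D_MODEL)  # French ids -> vectors
        self.transformer = nn.Transformer(d_model=D_MODEL, nhead=4,
                                          num_encoder_layers=2, num_decoder_layers=2,
                                          dim_feedforward=128)
        self.generator = nn.Linear(D_MODEL, TGT_VOCAB)  # vectors -> French logits

    def forward(self, src_ids, tgt_ids):
        # src_ids: (S, N) English ids; tgt_ids: (T, N) French ids
        tgt_mask = self.transformer.generate_square_subsequent_mask(tgt_ids.size(0))
        out = self.transformer(self.src_embed(src_ids), self.tgt_embed(tgt_ids),
                               tgt_mask=tgt_mask)
        return self.generator(out)  # (T, N, TGT_VOCAB)

model = Seq2SeqTransformer()
src_ids = torch.randint(0, SRC_VOCAB, (10, 32))     # 10 English tokens, batch of 32
french_ids = torch.randint(0, TGT_VOCAB, (21, 32))  # 21 French tokens (random stand-ins)
tgt_in = french_ids[:-1]   # decoder input (tgt): target shifted right
tgt_out = french_ids[1:]   # labels: the next French token at each position
logits = model(src_ids, tgt_in)
loss = nn.functional.cross_entropy(logits.reshape(-1, TGT_VOCAB), tgt_out.reshape(-1))
print(logits.shape)  # torch.Size([20, 32, 1200])
```

At inference time there is no French sentence yet, so tgt typically starts as just a start-of-sequence token and grows one predicted token at a time.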