Typo and fix the input of labels to cross_entropy (#7841)

The current version raised errors: `F.cross_entropy` takes `(input, target)`, so the logits must come first; `logits` was misspelled as `logitd`; and the extra `.unsqueeze(0)` gave the labels the wrong shape. These changes fixed it for me. Hope this is helpful!
This commit is contained in:
Katarina Slama 2020-10-15 16:36:31 -07:00 committed by GitHub
parent a5a8eeb772
commit dfa4c26bc0


@@ -109,9 +109,9 @@ The following is equivalent to the previous example:
 .. code-block:: python

     from torch.nn import functional as F
-    labels = torch.tensor([1,0]).unsqueeze(0)
+    labels = torch.tensor([1,0])
     outputs = model(input_ids, attention_mask=attention_mask)
-    loss = F.cross_entropy(labels, outputs.logitd)
+    loss = F.cross_entropy(outputs.logits, labels)
     loss.backward()
     optimizer.step()
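To see why the argument order matters, here is a minimal standalone sketch (the shapes and values are illustrative, not from the original example): `F.cross_entropy` expects the raw logits first, with shape `(N, C)`, and integer class indices second, with shape `(N,)` — which is also why the `.unsqueeze(0)` on the labels had to go.

```python
import torch
from torch.nn import functional as F

# Dummy logits for a batch of 2 examples and 2 classes, shape (N, C) = (2, 2).
logits = torch.tensor([[2.0, 0.5], [0.1, 1.5]])
# One class index per example, shape (N,) = (2,) -- no unsqueeze needed.
labels = torch.tensor([1, 0])

# Correct order: logits (input) first, labels (target) second.
loss = F.cross_entropy(logits, labels)

# The result is a scalar tensor suitable for loss.backward().
print(loss.dim(), loss.item() > 0)
```

Swapping the arguments, as in the old code, makes PyTorch treat the labels as logits and fails (or silently computes nonsense), since the target must be integer class indices.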