mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Typo and fix the input of labels to cross_entropy (#7841)
The current version caused some errors. These changes fixed them for me. Hope this is helpful!
This commit is contained in:
parent
a5a8eeb772
commit
dfa4c26bc0
1 changed file with 2 additions and 2 deletions
@@ -109,9 +109,9 @@ The following is equivalent to the previous example:

 .. code-block:: python

   from torch.nn import functional as F
-  labels = torch.tensor([1,0]).unsqueeze(0)
+  labels = torch.tensor([1,0])
   outputs = model(input_ids, attention_mask=attention_mask)
-  loss = F.cross_entropy(labels, outputs.logitd)
+  loss = F.cross_entropy(outputs.logits, labels)
   loss.backward()
   optimizer.step()
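The reason for the fix: `torch.nn.functional.cross_entropy` takes its arguments as `(input, target)`, where `input` is the logits tensor of shape `(N, C)` and `target` holds class indices of shape `(N,)`. The old doc line passed them in the reverse order (and misspelled `logits` as `logitd`). As a minimal illustration of that argument contract, here is a dependency-free sketch of the same computation in plain Python (the `cross_entropy` helper below is a stand-in written for this example, not the PyTorch implementation):

```python
import math

def cross_entropy(logits, labels):
    """Stand-in for F.cross_entropy's calling convention:
    logits come FIRST, shaped (N, C); labels second, shaped (N,),
    holding integer class indices. Returns the mean loss."""
    total = 0.0
    for row, target in zip(logits, labels):
        m = max(row)  # subtract the max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        # negative log-softmax evaluated at the target class
        total += log_sum_exp - row[target]
    return total / len(labels)

# Two samples, two classes, mirroring the labels [1, 0] from the diff:
logits = [[0.2, 1.5], [2.0, -0.5]]
labels = [1, 0]
loss = cross_entropy(logits, labels)
```

Swapping the arguments, as the old doc line did, would index the loss with logit values rather than class indices, which is why it errored out.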