mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Update README.md
This commit is contained in:
parent
cf1c88e092
commit
5ff0d6d7d0
1 changed files with 2 additions and 5 deletions
|
|
@ -14,12 +14,9 @@ retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="ex
|
|||
model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)
|
||||
|
||||
input_dict = tokenizer.prepare_seq2seq_batch("How many people live in Paris?", "In Paris, there are 10 million people.", return_tensors="pt")
|
||||
outputs = model(input_ids=input_dict["input_ids"], labels=input_dict["labels"])
|
||||
|
||||
# outputs.loss should give 76.1230
|
||||
|
||||
generated = model.generate(input_ids=input_dict["input_ids"])
|
||||
generated_string = tokenizer.batch_decode(generated, skip_special_tokens=True)
|
||||
print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])
|
||||
|
||||
# generated_string should give 270,000 -> not quite correct the answer, but it also only uses a dummy index
|
||||
# generated_string should give 270,000,000 -> a bit too many I think
|
||||
```
|
||||
|
|
|
|||
Loading…
Reference in a new issue