mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-15 21:01:19 +00:00
Change REALM checkpoint to new ones (#15439)
* Change REALM checkpoint to new ones * Last checkpoint missing
This commit is contained in:
parent
7e56ba2864
commit
3385ca2582
6 changed files with 102 additions and 102 deletions
|
|
@ -21,14 +21,14 @@ from ...utils import logging
|
|||
logger = logging.get_logger(__name__)
|
||||
|
||||
REALM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
"qqaatw/realm-cc-news-pretrained-embedder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-embedder/resolve/main/config.json",
|
||||
"qqaatw/realm-cc-news-pretrained-encoder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-encoder/resolve/main/config.json",
|
||||
"qqaatw/realm-cc-news-pretrained-scorer": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-scorer/resolve/main/config.json",
|
||||
"qqaatw/realm-cc-news-pretrained-openqa": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-openqa/aresolve/main/config.json",
|
||||
"qqaatw/realm-orqa-nq-openqa": "https://huggingface.co/qqaatw/realm-orqa-nq-openqa/resolve/main/config.json",
|
||||
"qqaatw/realm-orqa-nq-reader": "https://huggingface.co/qqaatw/realm-orqa-nq-reader/resolve/main/config.json",
|
||||
"qqaatw/realm-orqa-wq-openqa": "https://huggingface.co/qqaatw/realm-orqa-wq-openqa/resolve/main/config.json",
|
||||
"qqaatw/realm-orqa-wq-reader": "https://huggingface.co/qqaatw/realm-orqa-wq-reader/resolve/main/config.json",
|
||||
"google/realm-cc-news-pretrained-embedder": "https://huggingface.co/google/realm-cc-news-pretrained-embedder/resolve/main/config.json",
|
||||
"google/realm-cc-news-pretrained-encoder": "https://huggingface.co/google/realm-cc-news-pretrained-encoder/resolve/main/config.json",
|
||||
"google/realm-cc-news-pretrained-scorer": "https://huggingface.co/google/realm-cc-news-pretrained-scorer/resolve/main/config.json",
|
||||
"google/realm-cc-news-pretrained-openqa": "https://huggingface.co/google/realm-cc-news-pretrained-openqa/aresolve/main/config.json",
|
||||
"google/realm-orqa-nq-openqa": "https://huggingface.co/google/realm-orqa-nq-openqa/resolve/main/config.json",
|
||||
"google/realm-orqa-nq-reader": "https://huggingface.co/google/realm-orqa-nq-reader/resolve/main/config.json",
|
||||
"google/realm-orqa-wq-openqa": "https://huggingface.co/google/realm-orqa-wq-openqa/resolve/main/config.json",
|
||||
"google/realm-orqa-wq-reader": "https://huggingface.co/google/realm-orqa-wq-reader/resolve/main/config.json",
|
||||
# See all REALM models at https://huggingface.co/models?filter=realm
|
||||
}
|
||||
|
||||
|
|
@ -46,7 +46,7 @@ class RealmConfig(PretrainedConfig):
|
|||
|
||||
It is used to instantiate an REALM model according to the specified arguments, defining the model architecture.
|
||||
Instantiating a configuration with the defaults will yield a similar configuration to that of the REALM
|
||||
[realm-cc-news-pretrained](https://huggingface.co/qqaatw/realm-cc-news-pretrained-embedder) architecture.
|
||||
[realm-cc-news-pretrained](https://huggingface.co/google/realm-cc-news-pretrained-embedder) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
|
@ -112,7 +112,7 @@ class RealmConfig(PretrainedConfig):
|
|||
>>> # Initializing a REALM realm-cc-news-pretrained-* style configuration
|
||||
>>> configuration = RealmConfig()
|
||||
|
||||
>>> # Initializing a model from the qqaatw/realm-cc-news-pretrained-embedder style configuration
|
||||
>>> # Initializing a model from the google/realm-cc-news-pretrained-embedder style configuration
|
||||
>>> model = RealmEmbedder(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
|
|
|
|||
|
|
@ -43,21 +43,21 @@ from .configuration_realm import RealmConfig
|
|||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
_EMBEDDER_CHECKPOINT_FOR_DOC = "qqaatw/realm-cc-news-pretrained-embedder"
|
||||
_ENCODER_CHECKPOINT_FOR_DOC = "qqaatw/realm-cc-news-pretrained-encoder"
|
||||
_SCORER_CHECKPOINT_FOR_DOC = "qqaatw/realm-cc-news-pretrained-scorer"
|
||||
_EMBEDDER_CHECKPOINT_FOR_DOC = "google/realm-cc-news-pretrained-embedder"
|
||||
_ENCODER_CHECKPOINT_FOR_DOC = "google/realm-cc-news-pretrained-encoder"
|
||||
_SCORER_CHECKPOINT_FOR_DOC = "google/realm-cc-news-pretrained-scorer"
|
||||
_CONFIG_FOR_DOC = "RealmConfig"
|
||||
_TOKENIZER_FOR_DOC = "RealmTokenizer"
|
||||
|
||||
REALM_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"qqaatw/realm-cc-news-pretrained-embedder",
|
||||
"qqaatw/realm-cc-news-pretrained-encoder",
|
||||
"qqaatw/realm-cc-news-pretrained-scorer",
|
||||
"qqaatw/realm-cc-news-pretrained-openqa",
|
||||
"qqaatw/realm-orqa-nq-openqa",
|
||||
"qqaatw/realm-orqa-nq-reader",
|
||||
"qqaatw/realm-orqa-wq-openqa",
|
||||
"qqaatw/realm-orqa-wq-reader",
|
||||
"google/realm-cc-news-pretrained-embedder",
|
||||
"google/realm-cc-news-pretrained-encoder",
|
||||
"google/realm-cc-news-pretrained-scorer",
|
||||
"google/realm-cc-news-pretrained-openqa",
|
||||
"google/realm-orqa-nq-openqa",
|
||||
"google/realm-orqa-nq-reader",
|
||||
"google/realm-orqa-wq-openqa",
|
||||
"google/realm-orqa-wq-reader",
|
||||
# See all REALM models at https://huggingface.co/models?filter=realm
|
||||
]
|
||||
|
||||
|
|
@ -1180,8 +1180,8 @@ class RealmEmbedder(RealmPreTrainedModel):
|
|||
>>> from transformers import RealmTokenizer, RealmEmbedder
|
||||
>>> import torch
|
||||
|
||||
>>> tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-cc-news-pretrained-embedder")
|
||||
>>> model = RealmEmbedder.from_pretrained("qqaatw/realm-cc-news-pretrained-embedder")
|
||||
>>> tokenizer = RealmTokenizer.from_pretrained("google/realm-cc-news-pretrained-embedder")
|
||||
>>> model = RealmEmbedder.from_pretrained("google/realm-cc-news-pretrained-embedder")
|
||||
|
||||
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
||||
>>> outputs = model(**inputs)
|
||||
|
|
@ -1293,8 +1293,8 @@ class RealmScorer(RealmPreTrainedModel):
|
|||
>>> import torch
|
||||
>>> from transformers import RealmTokenizer, RealmScorer
|
||||
|
||||
>>> tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-cc-news-pretrained-scorer")
|
||||
>>> model = RealmScorer.from_pretrained("qqaatw/realm-cc-news-pretrained-scorer", num_candidates=2)
|
||||
>>> tokenizer = RealmTokenizer.from_pretrained("google/realm-cc-news-pretrained-scorer")
|
||||
>>> model = RealmScorer.from_pretrained("google/realm-cc-news-pretrained-scorer", num_candidates=2)
|
||||
|
||||
>>> # batch_size = 2, num_candidates = 2
|
||||
>>> input_texts = ["How are you?", "What is the item in the picture?"]
|
||||
|
|
@ -1433,9 +1433,9 @@ class RealmKnowledgeAugEncoder(RealmPreTrainedModel):
|
|||
>>> import torch
|
||||
>>> from transformers import RealmTokenizer, RealmKnowledgeAugEncoder
|
||||
|
||||
>>> tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-cc-news-pretrained-encoder")
|
||||
>>> tokenizer = RealmTokenizer.from_pretrained("google/realm-cc-news-pretrained-encoder")
|
||||
>>> model = RealmKnowledgeAugEncoder.from_pretrained(
|
||||
... "qqaatw/realm-cc-news-pretrained-encoder", num_candidates=2
|
||||
... "google/realm-cc-news-pretrained-encoder", num_candidates=2
|
||||
... )
|
||||
|
||||
>>> # batch_size = 2, num_candidates = 2
|
||||
|
|
@ -1761,9 +1761,9 @@ class RealmForOpenQA(RealmPreTrainedModel):
|
|||
>>> import torch
|
||||
>>> from transformers import RealmForOpenQA, RealmRetriever, RealmTokenizer
|
||||
|
||||
>>> retriever = RealmRetriever.from_pretrained("qqaatw/realm-orqa-nq-openqa")
|
||||
>>> tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-orqa-nq-openqa")
|
||||
>>> model = RealmForOpenQA.from_pretrained("qqaatw/realm-orqa-nq-openqa", retriever=retriever)
|
||||
>>> retriever = RealmRetriever.from_pretrained("google/realm-orqa-nq-openqa")
|
||||
>>> tokenizer = RealmTokenizer.from_pretrained("google/realm-orqa-nq-openqa")
|
||||
>>> model = RealmForOpenQA.from_pretrained("google/realm-orqa-nq-openqa", retriever=retriever)
|
||||
|
||||
>>> question = "Who is the pioneer in modern computer science?"
|
||||
>>> question_ids = tokenizer([question], return_tensors="pt")
|
||||
|
|
|
|||
|
|
@ -31,37 +31,37 @@ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
|
|||
|
||||
PRETRAINED_VOCAB_FILES_MAP = {
|
||||
"vocab_file": {
|
||||
"qqaatw/realm-cc-news-pretrained-embedder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-embedder/resolve/main/vocab.txt",
|
||||
"qqaatw/realm-cc-news-pretrained-encoder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-encoder/resolve/main/vocab.txt",
|
||||
"qqaatw/realm-cc-news-pretrained-scorer": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-scorer/resolve/main/vocab.txt",
|
||||
"qqaatw/realm-cc-news-pretrained-openqa": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-openqa/aresolve/main/vocab.txt",
|
||||
"qqaatw/realm-orqa-nq-openqa": "https://huggingface.co/qqaatw/realm-orqa-nq-openqa/resolve/main/vocab.txt",
|
||||
"qqaatw/realm-orqa-nq-reader": "https://huggingface.co/qqaatw/realm-orqa-nq-reader/resolve/main/vocab.txt",
|
||||
"qqaatw/realm-orqa-wq-openqa": "https://huggingface.co/qqaatw/realm-orqa-wq-openqa/resolve/main/vocab.txt",
|
||||
"qqaatw/realm-orqa-wq-reader": "https://huggingface.co/qqaatw/realm-orqa-wq-reader/resolve/main/vocab.txt",
|
||||
"google/realm-cc-news-pretrained-embedder": "https://huggingface.co/google/realm-cc-news-pretrained-embedder/resolve/main/vocab.txt",
|
||||
"google/realm-cc-news-pretrained-encoder": "https://huggingface.co/google/realm-cc-news-pretrained-encoder/resolve/main/vocab.txt",
|
||||
"google/realm-cc-news-pretrained-scorer": "https://huggingface.co/google/realm-cc-news-pretrained-scorer/resolve/main/vocab.txt",
|
||||
"google/realm-cc-news-pretrained-openqa": "https://huggingface.co/google/realm-cc-news-pretrained-openqa/aresolve/main/vocab.txt",
|
||||
"google/realm-orqa-nq-openqa": "https://huggingface.co/google/realm-orqa-nq-openqa/resolve/main/vocab.txt",
|
||||
"google/realm-orqa-nq-reader": "https://huggingface.co/google/realm-orqa-nq-reader/resolve/main/vocab.txt",
|
||||
"google/realm-orqa-wq-openqa": "https://huggingface.co/google/realm-orqa-wq-openqa/resolve/main/vocab.txt",
|
||||
"google/realm-orqa-wq-reader": "https://huggingface.co/google/realm-orqa-wq-reader/resolve/main/vocab.txt",
|
||||
}
|
||||
}
|
||||
|
||||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
||||
"qqaatw/realm-cc-news-pretrained-embedder": 512,
|
||||
"qqaatw/realm-cc-news-pretrained-encoder": 512,
|
||||
"qqaatw/realm-cc-news-pretrained-scorer": 512,
|
||||
"qqaatw/realm-cc-news-pretrained-openqa": 512,
|
||||
"qqaatw/realm-orqa-nq-openqa": 512,
|
||||
"qqaatw/realm-orqa-nq-reader": 512,
|
||||
"qqaatw/realm-orqa-wq-openqa": 512,
|
||||
"qqaatw/realm-orqa-wq-reader": 512,
|
||||
"google/realm-cc-news-pretrained-embedder": 512,
|
||||
"google/realm-cc-news-pretrained-encoder": 512,
|
||||
"google/realm-cc-news-pretrained-scorer": 512,
|
||||
"google/realm-cc-news-pretrained-openqa": 512,
|
||||
"google/realm-orqa-nq-openqa": 512,
|
||||
"google/realm-orqa-nq-reader": 512,
|
||||
"google/realm-orqa-wq-openqa": 512,
|
||||
"google/realm-orqa-wq-reader": 512,
|
||||
}
|
||||
|
||||
PRETRAINED_INIT_CONFIGURATION = {
|
||||
"qqaatw/realm-cc-news-pretrained-embedder": {"do_lower_case": True},
|
||||
"qqaatw/realm-cc-news-pretrained-encoder": {"do_lower_case": True},
|
||||
"qqaatw/realm-cc-news-pretrained-scorer": {"do_lower_case": True},
|
||||
"qqaatw/realm-cc-news-pretrained-openqa": {"do_lower_case": True},
|
||||
"qqaatw/realm-orqa-nq-openqa": {"do_lower_case": True},
|
||||
"qqaatw/realm-orqa-nq-reader": {"do_lower_case": True},
|
||||
"qqaatw/realm-orqa-wq-openqa": {"do_lower_case": True},
|
||||
"qqaatw/realm-orqa-wq-reader": {"do_lower_case": True},
|
||||
"google/realm-cc-news-pretrained-embedder": {"do_lower_case": True},
|
||||
"google/realm-cc-news-pretrained-encoder": {"do_lower_case": True},
|
||||
"google/realm-cc-news-pretrained-scorer": {"do_lower_case": True},
|
||||
"google/realm-cc-news-pretrained-openqa": {"do_lower_case": True},
|
||||
"google/realm-orqa-nq-openqa": {"do_lower_case": True},
|
||||
"google/realm-orqa-nq-reader": {"do_lower_case": True},
|
||||
"google/realm-orqa-wq-openqa": {"do_lower_case": True},
|
||||
"google/realm-orqa-wq-reader": {"do_lower_case": True},
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -252,7 +252,7 @@ class RealmTokenizer(PreTrainedTokenizer):
|
|||
>>> # batch_size = 2, num_candidates = 2
|
||||
>>> text = [["Hello world!", "Nice to meet you!"], ["The cute cat.", "The adorable dog."]]
|
||||
|
||||
>>> tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-cc-news-pretrained-encoder")
|
||||
>>> tokenizer = RealmTokenizer.from_pretrained("google/realm-cc-news-pretrained-encoder")
|
||||
>>> tokenized_text = tokenizer.batch_encode_candidates(text, max_length=10, return_tensors="pt")
|
||||
```"""
|
||||
|
||||
|
|
|
|||
|
|
@ -32,47 +32,47 @@ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.jso
|
|||
|
||||
PRETRAINED_VOCAB_FILES_MAP = {
|
||||
"vocab_file": {
|
||||
"qqaatw/realm-cc-news-pretrained-embedder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-embedder/resolve/main/vocab.txt",
|
||||
"qqaatw/realm-cc-news-pretrained-encoder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-encoder/resolve/main/vocab.txt",
|
||||
"qqaatw/realm-cc-news-pretrained-scorer": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-scorer/resolve/main/vocab.txt",
|
||||
"qqaatw/realm-cc-news-pretrained-openqa": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-openqa/aresolve/main/vocab.txt",
|
||||
"qqaatw/realm-orqa-nq-openqa": "https://huggingface.co/qqaatw/realm-orqa-nq-openqa/resolve/main/vocab.txt",
|
||||
"qqaatw/realm-orqa-nq-reader": "https://huggingface.co/qqaatw/realm-orqa-nq-reader/resolve/main/vocab.txt",
|
||||
"qqaatw/realm-orqa-wq-openqa": "https://huggingface.co/qqaatw/realm-orqa-wq-openqa/resolve/main/vocab.txt",
|
||||
"qqaatw/realm-orqa-wq-reader": "https://huggingface.co/qqaatw/realm-orqa-wq-reader/resolve/main/vocab.txt",
|
||||
"google/realm-cc-news-pretrained-embedder": "https://huggingface.co/google/realm-cc-news-pretrained-embedder/resolve/main/vocab.txt",
|
||||
"google/realm-cc-news-pretrained-encoder": "https://huggingface.co/google/realm-cc-news-pretrained-encoder/resolve/main/vocab.txt",
|
||||
"google/realm-cc-news-pretrained-scorer": "https://huggingface.co/google/realm-cc-news-pretrained-scorer/resolve/main/vocab.txt",
|
||||
"google/realm-cc-news-pretrained-openqa": "https://huggingface.co/google/realm-cc-news-pretrained-openqa/aresolve/main/vocab.txt",
|
||||
"google/realm-orqa-nq-openqa": "https://huggingface.co/google/realm-orqa-nq-openqa/resolve/main/vocab.txt",
|
||||
"google/realm-orqa-nq-reader": "https://huggingface.co/google/realm-orqa-nq-reader/resolve/main/vocab.txt",
|
||||
"google/realm-orqa-wq-openqa": "https://huggingface.co/google/realm-orqa-wq-openqa/resolve/main/vocab.txt",
|
||||
"google/realm-orqa-wq-reader": "https://huggingface.co/google/realm-orqa-wq-reader/resolve/main/vocab.txt",
|
||||
},
|
||||
"tokenizer_file": {
|
||||
"qqaatw/realm-cc-news-pretrained-embedder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-embedder/resolve/main/tokenizer.jsont",
|
||||
"qqaatw/realm-cc-news-pretrained-encoder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-encoder/resolve/main/tokenizer.json",
|
||||
"qqaatw/realm-cc-news-pretrained-scorer": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-scorer/resolve/main/tokenizer.json",
|
||||
"qqaatw/realm-cc-news-pretrained-openqa": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-openqa/aresolve/main/tokenizer.json",
|
||||
"qqaatw/realm-orqa-nq-openqa": "https://huggingface.co/qqaatw/realm-orqa-nq-openqa/resolve/main/tokenizer.json",
|
||||
"qqaatw/realm-orqa-nq-reader": "https://huggingface.co/qqaatw/realm-orqa-nq-reader/resolve/main/tokenizer.json",
|
||||
"qqaatw/realm-orqa-wq-openqa": "https://huggingface.co/qqaatw/realm-orqa-wq-openqa/resolve/main/tokenizer.json",
|
||||
"qqaatw/realm-orqa-wq-reader": "https://huggingface.co/qqaatw/realm-orqa-wq-reader/resolve/main/tokenizer.json",
|
||||
"google/realm-cc-news-pretrained-embedder": "https://huggingface.co/google/realm-cc-news-pretrained-embedder/resolve/main/tokenizer.jsont",
|
||||
"google/realm-cc-news-pretrained-encoder": "https://huggingface.co/google/realm-cc-news-pretrained-encoder/resolve/main/tokenizer.json",
|
||||
"google/realm-cc-news-pretrained-scorer": "https://huggingface.co/google/realm-cc-news-pretrained-scorer/resolve/main/tokenizer.json",
|
||||
"google/realm-cc-news-pretrained-openqa": "https://huggingface.co/google/realm-cc-news-pretrained-openqa/aresolve/main/tokenizer.json",
|
||||
"google/realm-orqa-nq-openqa": "https://huggingface.co/google/realm-orqa-nq-openqa/resolve/main/tokenizer.json",
|
||||
"google/realm-orqa-nq-reader": "https://huggingface.co/google/realm-orqa-nq-reader/resolve/main/tokenizer.json",
|
||||
"google/realm-orqa-wq-openqa": "https://huggingface.co/google/realm-orqa-wq-openqa/resolve/main/tokenizer.json",
|
||||
"google/realm-orqa-wq-reader": "https://huggingface.co/google/realm-orqa-wq-reader/resolve/main/tokenizer.json",
|
||||
},
|
||||
}
|
||||
|
||||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
||||
"qqaatw/realm-cc-news-pretrained-embedder": 512,
|
||||
"qqaatw/realm-cc-news-pretrained-encoder": 512,
|
||||
"qqaatw/realm-cc-news-pretrained-scorer": 512,
|
||||
"qqaatw/realm-cc-news-pretrained-openqa": 512,
|
||||
"qqaatw/realm-orqa-nq-openqa": 512,
|
||||
"qqaatw/realm-orqa-nq-reader": 512,
|
||||
"qqaatw/realm-orqa-wq-openqa": 512,
|
||||
"qqaatw/realm-orqa-wq-reader": 512,
|
||||
"google/realm-cc-news-pretrained-embedder": 512,
|
||||
"google/realm-cc-news-pretrained-encoder": 512,
|
||||
"google/realm-cc-news-pretrained-scorer": 512,
|
||||
"google/realm-cc-news-pretrained-openqa": 512,
|
||||
"google/realm-orqa-nq-openqa": 512,
|
||||
"google/realm-orqa-nq-reader": 512,
|
||||
"google/realm-orqa-wq-openqa": 512,
|
||||
"google/realm-orqa-wq-reader": 512,
|
||||
}
|
||||
|
||||
PRETRAINED_INIT_CONFIGURATION = {
|
||||
"qqaatw/realm-cc-news-pretrained-embedder": {"do_lower_case": True},
|
||||
"qqaatw/realm-cc-news-pretrained-encoder": {"do_lower_case": True},
|
||||
"qqaatw/realm-cc-news-pretrained-scorer": {"do_lower_case": True},
|
||||
"qqaatw/realm-cc-news-pretrained-openqa": {"do_lower_case": True},
|
||||
"qqaatw/realm-orqa-nq-openqa": {"do_lower_case": True},
|
||||
"qqaatw/realm-orqa-nq-reader": {"do_lower_case": True},
|
||||
"qqaatw/realm-orqa-wq-openqa": {"do_lower_case": True},
|
||||
"qqaatw/realm-orqa-wq-reader": {"do_lower_case": True},
|
||||
"google/realm-cc-news-pretrained-embedder": {"do_lower_case": True},
|
||||
"google/realm-cc-news-pretrained-encoder": {"do_lower_case": True},
|
||||
"google/realm-cc-news-pretrained-scorer": {"do_lower_case": True},
|
||||
"google/realm-cc-news-pretrained-openqa": {"do_lower_case": True},
|
||||
"google/realm-orqa-nq-openqa": {"do_lower_case": True},
|
||||
"google/realm-orqa-nq-reader": {"do_lower_case": True},
|
||||
"google/realm-orqa-wq-openqa": {"do_lower_case": True},
|
||||
"google/realm-orqa-wq-reader": {"do_lower_case": True},
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -200,7 +200,7 @@ class RealmTokenizerFast(PreTrainedTokenizerFast):
|
|||
>>> # batch_size = 2, num_candidates = 2
|
||||
>>> text = [["Hello world!", "Nice to meet you!"], ["The cute cat.", "The adorable dog."]]
|
||||
|
||||
>>> tokenizer = RealmTokenizerFast.from_pretrained("qqaatw/realm-cc-news-pretrained-encoder")
|
||||
>>> tokenizer = RealmTokenizerFast.from_pretrained("google/realm-cc-news-pretrained-encoder")
|
||||
>>> tokenized_text = tokenizer.batch_encode_candidates(text, max_length=10, return_tensors="pt")
|
||||
```"""
|
||||
|
||||
|
|
|
|||
|
|
@ -358,7 +358,7 @@ class RealmModelTest(ModelTesterMixin, unittest.TestCase):
|
|||
input_ids, token_type_ids, input_mask, scorer_encoder_inputs = inputs[0:4]
|
||||
config.return_dict = True
|
||||
|
||||
tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-orqa-nq-openqa")
|
||||
tokenizer = RealmTokenizer.from_pretrained("google/realm-orqa-nq-openqa")
|
||||
|
||||
# RealmKnowledgeAugEncoder training
|
||||
model = RealmKnowledgeAugEncoder(config)
|
||||
|
|
@ -411,27 +411,27 @@ class RealmModelTest(ModelTesterMixin, unittest.TestCase):
|
|||
|
||||
@slow
|
||||
def test_embedder_from_pretrained(self):
|
||||
model = RealmEmbedder.from_pretrained("qqaatw/realm-cc-news-pretrained-embedder")
|
||||
model = RealmEmbedder.from_pretrained("google/realm-cc-news-pretrained-embedder")
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@slow
|
||||
def test_encoder_from_pretrained(self):
|
||||
model = RealmKnowledgeAugEncoder.from_pretrained("qqaatw/realm-cc-news-pretrained-encoder")
|
||||
model = RealmKnowledgeAugEncoder.from_pretrained("google/realm-cc-news-pretrained-encoder")
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@slow
|
||||
def test_open_qa_from_pretrained(self):
|
||||
model = RealmForOpenQA.from_pretrained("qqaatw/realm-orqa-nq-openqa")
|
||||
model = RealmForOpenQA.from_pretrained("google/realm-orqa-nq-openqa")
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@slow
|
||||
def test_reader_from_pretrained(self):
|
||||
model = RealmReader.from_pretrained("qqaatw/realm-orqa-nq-reader")
|
||||
model = RealmReader.from_pretrained("google/realm-orqa-nq-reader")
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@slow
|
||||
def test_scorer_from_pretrained(self):
|
||||
model = RealmScorer.from_pretrained("qqaatw/realm-cc-news-pretrained-scorer")
|
||||
model = RealmScorer.from_pretrained("google/realm-cc-news-pretrained-scorer")
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
|
|
@ -441,7 +441,7 @@ class RealmModelIntegrationTest(unittest.TestCase):
|
|||
def test_inference_embedder(self):
|
||||
retriever_projected_size = 128
|
||||
|
||||
model = RealmEmbedder.from_pretrained("qqaatw/realm-cc-news-pretrained-embedder")
|
||||
model = RealmEmbedder.from_pretrained("google/realm-cc-news-pretrained-embedder")
|
||||
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
|
||||
output = model(input_ids)[0]
|
||||
|
||||
|
|
@ -457,7 +457,7 @@ class RealmModelIntegrationTest(unittest.TestCase):
|
|||
vocab_size = 30522
|
||||
|
||||
model = RealmKnowledgeAugEncoder.from_pretrained(
|
||||
"qqaatw/realm-cc-news-pretrained-encoder", num_candidates=num_candidates
|
||||
"google/realm-cc-news-pretrained-encoder", num_candidates=num_candidates
|
||||
)
|
||||
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]])
|
||||
relevance_score = torch.tensor([[0.3, 0.7]], dtype=torch.float32)
|
||||
|
|
@ -476,11 +476,11 @@ class RealmModelIntegrationTest(unittest.TestCase):
|
|||
|
||||
config = RealmConfig()
|
||||
|
||||
tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-orqa-nq-openqa")
|
||||
retriever = RealmRetriever.from_pretrained("qqaatw/realm-orqa-nq-openqa")
|
||||
tokenizer = RealmTokenizer.from_pretrained("google/realm-orqa-nq-openqa")
|
||||
retriever = RealmRetriever.from_pretrained("google/realm-orqa-nq-openqa")
|
||||
|
||||
model = RealmForOpenQA.from_pretrained(
|
||||
"qqaatw/realm-orqa-nq-openqa",
|
||||
"google/realm-orqa-nq-openqa",
|
||||
retriever=retriever,
|
||||
config=config,
|
||||
)
|
||||
|
|
@ -503,7 +503,7 @@ class RealmModelIntegrationTest(unittest.TestCase):
|
|||
@slow
|
||||
def test_inference_reader(self):
|
||||
config = RealmConfig(reader_beam_size=2, max_span_width=3)
|
||||
model = RealmReader.from_pretrained("qqaatw/realm-orqa-nq-reader", config=config)
|
||||
model = RealmReader.from_pretrained("google/realm-orqa-nq-reader", config=config)
|
||||
|
||||
concat_input_ids = torch.arange(10).view((2, 5))
|
||||
concat_token_type_ids = torch.tensor([[0, 0, 1, 1, 1], [0, 0, 1, 1, 1]], dtype=torch.int64)
|
||||
|
|
@ -532,7 +532,7 @@ class RealmModelIntegrationTest(unittest.TestCase):
|
|||
def test_inference_scorer(self):
|
||||
num_candidates = 2
|
||||
|
||||
model = RealmScorer.from_pretrained("qqaatw/realm-cc-news-pretrained-scorer", num_candidates=num_candidates)
|
||||
model = RealmScorer.from_pretrained("google/realm-cc-news-pretrained-scorer", num_candidates=num_candidates)
|
||||
|
||||
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
|
||||
candidate_input_ids = torch.tensor([[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]])
|
||||
|
|
|
|||
|
|
@ -180,6 +180,6 @@ class RealmRetrieverTest(TestCase):
|
|||
mock_hf_hub_download.return_value = os.path.join(
|
||||
os.path.join(self.tmpdirname, "realm_block_records"), _REALM_BLOCK_RECORDS_FILENAME
|
||||
)
|
||||
retriever = RealmRetriever.from_pretrained("qqaatw/realm-cc-news-pretrained-openqa")
|
||||
retriever = RealmRetriever.from_pretrained("google/realm-cc-news-pretrained-openqa")
|
||||
|
||||
self.assertEqual(retriever.block_records[0], b"This is the first record")
|
||||
|
|
|
|||
Loading…
Reference in a new issue