diff --git a/src/transformers/models/realm/configuration_realm.py b/src/transformers/models/realm/configuration_realm.py index 576294551..a0ff2bb9b 100644 --- a/src/transformers/models/realm/configuration_realm.py +++ b/src/transformers/models/realm/configuration_realm.py @@ -21,14 +21,14 @@ from ...utils import logging logger = logging.get_logger(__name__) REALM_PRETRAINED_CONFIG_ARCHIVE_MAP = { - "qqaatw/realm-cc-news-pretrained-embedder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-embedder/resolve/main/config.json", - "qqaatw/realm-cc-news-pretrained-encoder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-encoder/resolve/main/config.json", - "qqaatw/realm-cc-news-pretrained-scorer": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-scorer/resolve/main/config.json", - "qqaatw/realm-cc-news-pretrained-openqa": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-openqa/aresolve/main/config.json", - "qqaatw/realm-orqa-nq-openqa": "https://huggingface.co/qqaatw/realm-orqa-nq-openqa/resolve/main/config.json", - "qqaatw/realm-orqa-nq-reader": "https://huggingface.co/qqaatw/realm-orqa-nq-reader/resolve/main/config.json", - "qqaatw/realm-orqa-wq-openqa": "https://huggingface.co/qqaatw/realm-orqa-wq-openqa/resolve/main/config.json", - "qqaatw/realm-orqa-wq-reader": "https://huggingface.co/qqaatw/realm-orqa-wq-reader/resolve/main/config.json", + "google/realm-cc-news-pretrained-embedder": "https://huggingface.co/google/realm-cc-news-pretrained-embedder/resolve/main/config.json", + "google/realm-cc-news-pretrained-encoder": "https://huggingface.co/google/realm-cc-news-pretrained-encoder/resolve/main/config.json", + "google/realm-cc-news-pretrained-scorer": "https://huggingface.co/google/realm-cc-news-pretrained-scorer/resolve/main/config.json", + "google/realm-cc-news-pretrained-openqa": "https://huggingface.co/google/realm-cc-news-pretrained-openqa/resolve/main/config.json", + "google/realm-orqa-nq-openqa": 
"https://huggingface.co/google/realm-orqa-nq-openqa/resolve/main/config.json", + "google/realm-orqa-nq-reader": "https://huggingface.co/google/realm-orqa-nq-reader/resolve/main/config.json", + "google/realm-orqa-wq-openqa": "https://huggingface.co/google/realm-orqa-wq-openqa/resolve/main/config.json", + "google/realm-orqa-wq-reader": "https://huggingface.co/google/realm-orqa-wq-reader/resolve/main/config.json", # See all REALM models at https://huggingface.co/models?filter=realm } @@ -46,7 +46,7 @@ class RealmConfig(PretrainedConfig): It is used to instantiate an REALM model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the REALM - [realm-cc-news-pretrained](https://huggingface.co/qqaatw/realm-cc-news-pretrained-embedder) architecture. + [realm-cc-news-pretrained](https://huggingface.co/google/realm-cc-news-pretrained-embedder) architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. 
@@ -112,7 +112,7 @@ class RealmConfig(PretrainedConfig): >>> # Initializing a REALM realm-cc-news-pretrained-* style configuration >>> configuration = RealmConfig() - >>> # Initializing a model from the qqaatw/realm-cc-news-pretrained-embedder style configuration + >>> # Initializing a model from the google/realm-cc-news-pretrained-embedder style configuration >>> model = RealmEmbedder(configuration) >>> # Accessing the model configuration diff --git a/src/transformers/models/realm/modeling_realm.py b/src/transformers/models/realm/modeling_realm.py index 95e941f9e..165e62c0e 100644 --- a/src/transformers/models/realm/modeling_realm.py +++ b/src/transformers/models/realm/modeling_realm.py @@ -43,21 +43,21 @@ from .configuration_realm import RealmConfig logger = logging.get_logger(__name__) -_EMBEDDER_CHECKPOINT_FOR_DOC = "qqaatw/realm-cc-news-pretrained-embedder" -_ENCODER_CHECKPOINT_FOR_DOC = "qqaatw/realm-cc-news-pretrained-encoder" -_SCORER_CHECKPOINT_FOR_DOC = "qqaatw/realm-cc-news-pretrained-scorer" +_EMBEDDER_CHECKPOINT_FOR_DOC = "google/realm-cc-news-pretrained-embedder" +_ENCODER_CHECKPOINT_FOR_DOC = "google/realm-cc-news-pretrained-encoder" +_SCORER_CHECKPOINT_FOR_DOC = "google/realm-cc-news-pretrained-scorer" _CONFIG_FOR_DOC = "RealmConfig" _TOKENIZER_FOR_DOC = "RealmTokenizer" REALM_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "qqaatw/realm-cc-news-pretrained-embedder", - "qqaatw/realm-cc-news-pretrained-encoder", - "qqaatw/realm-cc-news-pretrained-scorer", - "qqaatw/realm-cc-news-pretrained-openqa", - "qqaatw/realm-orqa-nq-openqa", - "qqaatw/realm-orqa-nq-reader", - "qqaatw/realm-orqa-wq-openqa", - "qqaatw/realm-orqa-wq-reader", + "google/realm-cc-news-pretrained-embedder", + "google/realm-cc-news-pretrained-encoder", + "google/realm-cc-news-pretrained-scorer", + "google/realm-cc-news-pretrained-openqa", + "google/realm-orqa-nq-openqa", + "google/realm-orqa-nq-reader", + "google/realm-orqa-wq-openqa", + "google/realm-orqa-wq-reader", # See all REALM models at 
https://huggingface.co/models?filter=realm ] @@ -1180,8 +1180,8 @@ class RealmEmbedder(RealmPreTrainedModel): >>> from transformers import RealmTokenizer, RealmEmbedder >>> import torch - >>> tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-cc-news-pretrained-embedder") - >>> model = RealmEmbedder.from_pretrained("qqaatw/realm-cc-news-pretrained-embedder") + >>> tokenizer = RealmTokenizer.from_pretrained("google/realm-cc-news-pretrained-embedder") + >>> model = RealmEmbedder.from_pretrained("google/realm-cc-news-pretrained-embedder") >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") >>> outputs = model(**inputs) @@ -1293,8 +1293,8 @@ class RealmScorer(RealmPreTrainedModel): >>> import torch >>> from transformers import RealmTokenizer, RealmScorer - >>> tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-cc-news-pretrained-scorer") - >>> model = RealmScorer.from_pretrained("qqaatw/realm-cc-news-pretrained-scorer", num_candidates=2) + >>> tokenizer = RealmTokenizer.from_pretrained("google/realm-cc-news-pretrained-scorer") + >>> model = RealmScorer.from_pretrained("google/realm-cc-news-pretrained-scorer", num_candidates=2) >>> # batch_size = 2, num_candidates = 2 >>> input_texts = ["How are you?", "What is the item in the picture?"] @@ -1433,9 +1433,9 @@ class RealmKnowledgeAugEncoder(RealmPreTrainedModel): >>> import torch >>> from transformers import RealmTokenizer, RealmKnowledgeAugEncoder - >>> tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-cc-news-pretrained-encoder") + >>> tokenizer = RealmTokenizer.from_pretrained("google/realm-cc-news-pretrained-encoder") >>> model = RealmKnowledgeAugEncoder.from_pretrained( - ... "qqaatw/realm-cc-news-pretrained-encoder", num_candidates=2 + ... "google/realm-cc-news-pretrained-encoder", num_candidates=2 ... 
) >>> # batch_size = 2, num_candidates = 2 @@ -1761,9 +1761,9 @@ class RealmForOpenQA(RealmPreTrainedModel): >>> import torch >>> from transformers import RealmForOpenQA, RealmRetriever, RealmTokenizer - >>> retriever = RealmRetriever.from_pretrained("qqaatw/realm-orqa-nq-openqa") - >>> tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-orqa-nq-openqa") - >>> model = RealmForOpenQA.from_pretrained("qqaatw/realm-orqa-nq-openqa", retriever=retriever) + >>> retriever = RealmRetriever.from_pretrained("google/realm-orqa-nq-openqa") + >>> tokenizer = RealmTokenizer.from_pretrained("google/realm-orqa-nq-openqa") + >>> model = RealmForOpenQA.from_pretrained("google/realm-orqa-nq-openqa", retriever=retriever) >>> question = "Who is the pioneer in modern computer science?" >>> question_ids = tokenizer([question], return_tensors="pt") diff --git a/src/transformers/models/realm/tokenization_realm.py b/src/transformers/models/realm/tokenization_realm.py index 571cc7c19..3ddfd19e5 100644 --- a/src/transformers/models/realm/tokenization_realm.py +++ b/src/transformers/models/realm/tokenization_realm.py @@ -31,37 +31,37 @@ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} PRETRAINED_VOCAB_FILES_MAP = { "vocab_file": { - "qqaatw/realm-cc-news-pretrained-embedder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-embedder/resolve/main/vocab.txt", - "qqaatw/realm-cc-news-pretrained-encoder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-encoder/resolve/main/vocab.txt", - "qqaatw/realm-cc-news-pretrained-scorer": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-scorer/resolve/main/vocab.txt", - "qqaatw/realm-cc-news-pretrained-openqa": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-openqa/aresolve/main/vocab.txt", - "qqaatw/realm-orqa-nq-openqa": "https://huggingface.co/qqaatw/realm-orqa-nq-openqa/resolve/main/vocab.txt", - "qqaatw/realm-orqa-nq-reader": "https://huggingface.co/qqaatw/realm-orqa-nq-reader/resolve/main/vocab.txt", - 
"qqaatw/realm-orqa-wq-openqa": "https://huggingface.co/qqaatw/realm-orqa-wq-openqa/resolve/main/vocab.txt", - "qqaatw/realm-orqa-wq-reader": "https://huggingface.co/qqaatw/realm-orqa-wq-reader/resolve/main/vocab.txt", + "google/realm-cc-news-pretrained-embedder": "https://huggingface.co/google/realm-cc-news-pretrained-embedder/resolve/main/vocab.txt", + "google/realm-cc-news-pretrained-encoder": "https://huggingface.co/google/realm-cc-news-pretrained-encoder/resolve/main/vocab.txt", + "google/realm-cc-news-pretrained-scorer": "https://huggingface.co/google/realm-cc-news-pretrained-scorer/resolve/main/vocab.txt", + "google/realm-cc-news-pretrained-openqa": "https://huggingface.co/google/realm-cc-news-pretrained-openqa/aresolve/main/vocab.txt", + "google/realm-orqa-nq-openqa": "https://huggingface.co/google/realm-orqa-nq-openqa/resolve/main/vocab.txt", + "google/realm-orqa-nq-reader": "https://huggingface.co/google/realm-orqa-nq-reader/resolve/main/vocab.txt", + "google/realm-orqa-wq-openqa": "https://huggingface.co/google/realm-orqa-wq-openqa/resolve/main/vocab.txt", + "google/realm-orqa-wq-reader": "https://huggingface.co/google/realm-orqa-wq-reader/resolve/main/vocab.txt", } } PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { - "qqaatw/realm-cc-news-pretrained-embedder": 512, - "qqaatw/realm-cc-news-pretrained-encoder": 512, - "qqaatw/realm-cc-news-pretrained-scorer": 512, - "qqaatw/realm-cc-news-pretrained-openqa": 512, - "qqaatw/realm-orqa-nq-openqa": 512, - "qqaatw/realm-orqa-nq-reader": 512, - "qqaatw/realm-orqa-wq-openqa": 512, - "qqaatw/realm-orqa-wq-reader": 512, + "google/realm-cc-news-pretrained-embedder": 512, + "google/realm-cc-news-pretrained-encoder": 512, + "google/realm-cc-news-pretrained-scorer": 512, + "google/realm-cc-news-pretrained-openqa": 512, + "google/realm-orqa-nq-openqa": 512, + "google/realm-orqa-nq-reader": 512, + "google/realm-orqa-wq-openqa": 512, + "google/realm-orqa-wq-reader": 512, } PRETRAINED_INIT_CONFIGURATION = { - 
"qqaatw/realm-cc-news-pretrained-embedder": {"do_lower_case": True}, - "qqaatw/realm-cc-news-pretrained-encoder": {"do_lower_case": True}, - "qqaatw/realm-cc-news-pretrained-scorer": {"do_lower_case": True}, - "qqaatw/realm-cc-news-pretrained-openqa": {"do_lower_case": True}, - "qqaatw/realm-orqa-nq-openqa": {"do_lower_case": True}, - "qqaatw/realm-orqa-nq-reader": {"do_lower_case": True}, - "qqaatw/realm-orqa-wq-openqa": {"do_lower_case": True}, - "qqaatw/realm-orqa-wq-reader": {"do_lower_case": True}, + "google/realm-cc-news-pretrained-embedder": {"do_lower_case": True}, + "google/realm-cc-news-pretrained-encoder": {"do_lower_case": True}, + "google/realm-cc-news-pretrained-scorer": {"do_lower_case": True}, + "google/realm-cc-news-pretrained-openqa": {"do_lower_case": True}, + "google/realm-orqa-nq-openqa": {"do_lower_case": True}, + "google/realm-orqa-nq-reader": {"do_lower_case": True}, + "google/realm-orqa-wq-openqa": {"do_lower_case": True}, + "google/realm-orqa-wq-reader": {"do_lower_case": True}, } @@ -252,7 +252,7 @@ class RealmTokenizer(PreTrainedTokenizer): >>> # batch_size = 2, num_candidates = 2 >>> text = [["Hello world!", "Nice to meet you!"], ["The cute cat.", "The adorable dog."]] - >>> tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-cc-news-pretrained-encoder") + >>> tokenizer = RealmTokenizer.from_pretrained("google/realm-cc-news-pretrained-encoder") >>> tokenized_text = tokenizer.batch_encode_candidates(text, max_length=10, return_tensors="pt") ```""" diff --git a/src/transformers/models/realm/tokenization_realm_fast.py b/src/transformers/models/realm/tokenization_realm_fast.py index e78dc4f99..7f55a72d0 100644 --- a/src/transformers/models/realm/tokenization_realm_fast.py +++ b/src/transformers/models/realm/tokenization_realm_fast.py @@ -32,47 +32,47 @@ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.jso PRETRAINED_VOCAB_FILES_MAP = { "vocab_file": { - "qqaatw/realm-cc-news-pretrained-embedder": 
"https://huggingface.co/qqaatw/realm-cc-news-pretrained-embedder/resolve/main/vocab.txt", - "qqaatw/realm-cc-news-pretrained-encoder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-encoder/resolve/main/vocab.txt", - "qqaatw/realm-cc-news-pretrained-scorer": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-scorer/resolve/main/vocab.txt", - "qqaatw/realm-cc-news-pretrained-openqa": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-openqa/aresolve/main/vocab.txt", - "qqaatw/realm-orqa-nq-openqa": "https://huggingface.co/qqaatw/realm-orqa-nq-openqa/resolve/main/vocab.txt", - "qqaatw/realm-orqa-nq-reader": "https://huggingface.co/qqaatw/realm-orqa-nq-reader/resolve/main/vocab.txt", - "qqaatw/realm-orqa-wq-openqa": "https://huggingface.co/qqaatw/realm-orqa-wq-openqa/resolve/main/vocab.txt", - "qqaatw/realm-orqa-wq-reader": "https://huggingface.co/qqaatw/realm-orqa-wq-reader/resolve/main/vocab.txt", + "google/realm-cc-news-pretrained-embedder": "https://huggingface.co/google/realm-cc-news-pretrained-embedder/resolve/main/vocab.txt", + "google/realm-cc-news-pretrained-encoder": "https://huggingface.co/google/realm-cc-news-pretrained-encoder/resolve/main/vocab.txt", + "google/realm-cc-news-pretrained-scorer": "https://huggingface.co/google/realm-cc-news-pretrained-scorer/resolve/main/vocab.txt", + "google/realm-cc-news-pretrained-openqa": "https://huggingface.co/google/realm-cc-news-pretrained-openqa/aresolve/main/vocab.txt", + "google/realm-orqa-nq-openqa": "https://huggingface.co/google/realm-orqa-nq-openqa/resolve/main/vocab.txt", + "google/realm-orqa-nq-reader": "https://huggingface.co/google/realm-orqa-nq-reader/resolve/main/vocab.txt", + "google/realm-orqa-wq-openqa": "https://huggingface.co/google/realm-orqa-wq-openqa/resolve/main/vocab.txt", + "google/realm-orqa-wq-reader": "https://huggingface.co/google/realm-orqa-wq-reader/resolve/main/vocab.txt", }, "tokenizer_file": { - "qqaatw/realm-cc-news-pretrained-embedder": 
"https://huggingface.co/qqaatw/realm-cc-news-pretrained-embedder/resolve/main/tokenizer.jsont", - "qqaatw/realm-cc-news-pretrained-encoder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-encoder/resolve/main/tokenizer.json", - "qqaatw/realm-cc-news-pretrained-scorer": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-scorer/resolve/main/tokenizer.json", - "qqaatw/realm-cc-news-pretrained-openqa": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-openqa/aresolve/main/tokenizer.json", - "qqaatw/realm-orqa-nq-openqa": "https://huggingface.co/qqaatw/realm-orqa-nq-openqa/resolve/main/tokenizer.json", - "qqaatw/realm-orqa-nq-reader": "https://huggingface.co/qqaatw/realm-orqa-nq-reader/resolve/main/tokenizer.json", - "qqaatw/realm-orqa-wq-openqa": "https://huggingface.co/qqaatw/realm-orqa-wq-openqa/resolve/main/tokenizer.json", - "qqaatw/realm-orqa-wq-reader": "https://huggingface.co/qqaatw/realm-orqa-wq-reader/resolve/main/tokenizer.json", + "google/realm-cc-news-pretrained-embedder": "https://huggingface.co/google/realm-cc-news-pretrained-embedder/resolve/main/tokenizer.jsont", + "google/realm-cc-news-pretrained-encoder": "https://huggingface.co/google/realm-cc-news-pretrained-encoder/resolve/main/tokenizer.json", + "google/realm-cc-news-pretrained-scorer": "https://huggingface.co/google/realm-cc-news-pretrained-scorer/resolve/main/tokenizer.json", + "google/realm-cc-news-pretrained-openqa": "https://huggingface.co/google/realm-cc-news-pretrained-openqa/aresolve/main/tokenizer.json", + "google/realm-orqa-nq-openqa": "https://huggingface.co/google/realm-orqa-nq-openqa/resolve/main/tokenizer.json", + "google/realm-orqa-nq-reader": "https://huggingface.co/google/realm-orqa-nq-reader/resolve/main/tokenizer.json", + "google/realm-orqa-wq-openqa": "https://huggingface.co/google/realm-orqa-wq-openqa/resolve/main/tokenizer.json", + "google/realm-orqa-wq-reader": "https://huggingface.co/google/realm-orqa-wq-reader/resolve/main/tokenizer.json", }, } 
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { - "qqaatw/realm-cc-news-pretrained-embedder": 512, - "qqaatw/realm-cc-news-pretrained-encoder": 512, - "qqaatw/realm-cc-news-pretrained-scorer": 512, - "qqaatw/realm-cc-news-pretrained-openqa": 512, - "qqaatw/realm-orqa-nq-openqa": 512, - "qqaatw/realm-orqa-nq-reader": 512, - "qqaatw/realm-orqa-wq-openqa": 512, - "qqaatw/realm-orqa-wq-reader": 512, + "google/realm-cc-news-pretrained-embedder": 512, + "google/realm-cc-news-pretrained-encoder": 512, + "google/realm-cc-news-pretrained-scorer": 512, + "google/realm-cc-news-pretrained-openqa": 512, + "google/realm-orqa-nq-openqa": 512, + "google/realm-orqa-nq-reader": 512, + "google/realm-orqa-wq-openqa": 512, + "google/realm-orqa-wq-reader": 512, } PRETRAINED_INIT_CONFIGURATION = { - "qqaatw/realm-cc-news-pretrained-embedder": {"do_lower_case": True}, - "qqaatw/realm-cc-news-pretrained-encoder": {"do_lower_case": True}, - "qqaatw/realm-cc-news-pretrained-scorer": {"do_lower_case": True}, - "qqaatw/realm-cc-news-pretrained-openqa": {"do_lower_case": True}, - "qqaatw/realm-orqa-nq-openqa": {"do_lower_case": True}, - "qqaatw/realm-orqa-nq-reader": {"do_lower_case": True}, - "qqaatw/realm-orqa-wq-openqa": {"do_lower_case": True}, - "qqaatw/realm-orqa-wq-reader": {"do_lower_case": True}, + "google/realm-cc-news-pretrained-embedder": {"do_lower_case": True}, + "google/realm-cc-news-pretrained-encoder": {"do_lower_case": True}, + "google/realm-cc-news-pretrained-scorer": {"do_lower_case": True}, + "google/realm-cc-news-pretrained-openqa": {"do_lower_case": True}, + "google/realm-orqa-nq-openqa": {"do_lower_case": True}, + "google/realm-orqa-nq-reader": {"do_lower_case": True}, + "google/realm-orqa-wq-openqa": {"do_lower_case": True}, + "google/realm-orqa-wq-reader": {"do_lower_case": True}, } @@ -200,7 +200,7 @@ class RealmTokenizerFast(PreTrainedTokenizerFast): >>> # batch_size = 2, num_candidates = 2 >>> text = [["Hello world!", "Nice to meet you!"], ["The cute cat.", "The adorable 
dog."]] - >>> tokenizer = RealmTokenizerFast.from_pretrained("qqaatw/realm-cc-news-pretrained-encoder") + >>> tokenizer = RealmTokenizerFast.from_pretrained("google/realm-cc-news-pretrained-encoder") >>> tokenized_text = tokenizer.batch_encode_candidates(text, max_length=10, return_tensors="pt") ```""" diff --git a/tests/test_modeling_realm.py b/tests/test_modeling_realm.py index 3126c7186..d9d0ede90 100644 --- a/tests/test_modeling_realm.py +++ b/tests/test_modeling_realm.py @@ -358,7 +358,7 @@ class RealmModelTest(ModelTesterMixin, unittest.TestCase): input_ids, token_type_ids, input_mask, scorer_encoder_inputs = inputs[0:4] config.return_dict = True - tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-orqa-nq-openqa") + tokenizer = RealmTokenizer.from_pretrained("google/realm-orqa-nq-openqa") # RealmKnowledgeAugEncoder training model = RealmKnowledgeAugEncoder(config) @@ -411,27 +411,27 @@ class RealmModelTest(ModelTesterMixin, unittest.TestCase): @slow def test_embedder_from_pretrained(self): - model = RealmEmbedder.from_pretrained("qqaatw/realm-cc-news-pretrained-embedder") + model = RealmEmbedder.from_pretrained("google/realm-cc-news-pretrained-embedder") self.assertIsNotNone(model) @slow def test_encoder_from_pretrained(self): - model = RealmKnowledgeAugEncoder.from_pretrained("qqaatw/realm-cc-news-pretrained-encoder") + model = RealmKnowledgeAugEncoder.from_pretrained("google/realm-cc-news-pretrained-encoder") self.assertIsNotNone(model) @slow def test_open_qa_from_pretrained(self): - model = RealmForOpenQA.from_pretrained("qqaatw/realm-orqa-nq-openqa") + model = RealmForOpenQA.from_pretrained("google/realm-orqa-nq-openqa") self.assertIsNotNone(model) @slow def test_reader_from_pretrained(self): - model = RealmReader.from_pretrained("qqaatw/realm-orqa-nq-reader") + model = RealmReader.from_pretrained("google/realm-orqa-nq-reader") self.assertIsNotNone(model) @slow def test_scorer_from_pretrained(self): - model = 
RealmScorer.from_pretrained("qqaatw/realm-cc-news-pretrained-scorer") + model = RealmScorer.from_pretrained("google/realm-cc-news-pretrained-scorer") self.assertIsNotNone(model) @@ -441,7 +441,7 @@ class RealmModelIntegrationTest(unittest.TestCase): def test_inference_embedder(self): retriever_projected_size = 128 - model = RealmEmbedder.from_pretrained("qqaatw/realm-cc-news-pretrained-embedder") + model = RealmEmbedder.from_pretrained("google/realm-cc-news-pretrained-embedder") input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]]) output = model(input_ids)[0] @@ -457,7 +457,7 @@ class RealmModelIntegrationTest(unittest.TestCase): vocab_size = 30522 model = RealmKnowledgeAugEncoder.from_pretrained( - "qqaatw/realm-cc-news-pretrained-encoder", num_candidates=num_candidates + "google/realm-cc-news-pretrained-encoder", num_candidates=num_candidates ) input_ids = torch.tensor([[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]]) relevance_score = torch.tensor([[0.3, 0.7]], dtype=torch.float32) @@ -476,11 +476,11 @@ class RealmModelIntegrationTest(unittest.TestCase): config = RealmConfig() - tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-orqa-nq-openqa") - retriever = RealmRetriever.from_pretrained("qqaatw/realm-orqa-nq-openqa") + tokenizer = RealmTokenizer.from_pretrained("google/realm-orqa-nq-openqa") + retriever = RealmRetriever.from_pretrained("google/realm-orqa-nq-openqa") model = RealmForOpenQA.from_pretrained( - "qqaatw/realm-orqa-nq-openqa", + "google/realm-orqa-nq-openqa", retriever=retriever, config=config, ) @@ -503,7 +503,7 @@ class RealmModelIntegrationTest(unittest.TestCase): @slow def test_inference_reader(self): config = RealmConfig(reader_beam_size=2, max_span_width=3) - model = RealmReader.from_pretrained("qqaatw/realm-orqa-nq-reader", config=config) + model = RealmReader.from_pretrained("google/realm-orqa-nq-reader", config=config) concat_input_ids = torch.arange(10).view((2, 5)) concat_token_type_ids = torch.tensor([[0, 0, 1, 1, 1], [0, 0, 1, 1, 1]], 
dtype=torch.int64) @@ -532,7 +532,7 @@ class RealmModelIntegrationTest(unittest.TestCase): def test_inference_scorer(self): num_candidates = 2 - model = RealmScorer.from_pretrained("qqaatw/realm-cc-news-pretrained-scorer", num_candidates=num_candidates) + model = RealmScorer.from_pretrained("google/realm-cc-news-pretrained-scorer", num_candidates=num_candidates) input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]]) candidate_input_ids = torch.tensor([[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]]) diff --git a/tests/test_retrieval_realm.py b/tests/test_retrieval_realm.py index 2813f31a3..3ffefef16 100644 --- a/tests/test_retrieval_realm.py +++ b/tests/test_retrieval_realm.py @@ -180,6 +180,6 @@ class RealmRetrieverTest(TestCase): mock_hf_hub_download.return_value = os.path.join( os.path.join(self.tmpdirname, "realm_block_records"), _REALM_BLOCK_RECORDS_FILENAME ) - retriever = RealmRetriever.from_pretrained("qqaatw/realm-cc-news-pretrained-openqa") + retriever = RealmRetriever.from_pretrained("google/realm-cc-news-pretrained-openqa") self.assertEqual(retriever.block_records[0], b"This is the first record")