From ed71c21d6afcbfa2d8e5bb03acbb88ae0e0ea56a Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Wed, 9 Sep 2020 10:22:59 +0200 Subject: [PATCH] =?UTF-8?q?[from=5Fpretrained]=20Allow=20tokenizer=5Ftype?= =?UTF-8?q?=20=E2=89=A0=20model=5Ftype=20(#6995)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/transformers/configuration_utils.py | 1 + src/transformers/testing_utils.py | 1 + src/transformers/tokenization_auto.py | 11 +++++++++++ tests/test_tokenization_auto.py | 16 +++++++++++++++- 4 files changed, 28 insertions(+), 1 deletion(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index c1a9fc0c7..5f68bf12f 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -190,6 +190,7 @@ class PretrainedConfig(object): self.num_labels = kwargs.pop("num_labels", 2) # Tokenizer arguments TODO: eventually tokenizer and models should share the same config + self.tokenizer_class = kwargs.pop("tokenizer_class", None) self.prefix = kwargs.pop("prefix", None) self.bos_token_id = kwargs.pop("bos_token_id", None) self.pad_token_id = kwargs.pop("pad_token_id", None) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index af2333c17..9944aa9d3 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -15,6 +15,7 @@ from .file_utils import _tf_available, _torch_available, _torch_tpu_available SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy" DUMMY_UNKWOWN_IDENTIFIER = "julien-c/dummy-unknown" +DUMMY_DIFF_TOKENIZER_IDENTIFIER = "julien-c/dummy-diff-tokenizer" # Used to test Auto{Config, Model, Tokenizer} model_type detection. 
diff --git a/src/transformers/tokenization_auto.py b/src/transformers/tokenization_auto.py index 7dc44e6bf..8925539ef 100644 --- a/src/transformers/tokenization_auto.py +++ b/src/transformers/tokenization_auto.py @@ -222,6 +222,17 @@ class AutoTokenizer: return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) use_fast = kwargs.pop("use_fast", False) + + if config.tokenizer_class is not None: + if use_fast and not config.tokenizer_class.endswith("Fast"): + tokenizer_class_candidate = f"{config.tokenizer_class}Fast" + else: + tokenizer_class_candidate = config.tokenizer_class + tokenizer_class = globals().get(tokenizer_class_candidate) + if tokenizer_class is None: + raise ValueError(f"Tokenizer class {tokenizer_class_candidate} does not exist or is not currently imported.") + return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) + for config_class, (tokenizer_class_py, tokenizer_class_fast) in TOKENIZER_MAPPING.items(): if isinstance(config, config_class): if tokenizer_class_fast and use_fast: diff --git a/tests/test_tokenization_auto.py b/tests/test_tokenization_auto.py index 54bfb2e13..524a22824 100644 --- a/tests/test_tokenization_auto.py +++ b/tests/test_tokenization_auto.py @@ -27,7 +27,13 @@ from transformers import ( RobertaTokenizer, RobertaTokenizerFast, ) -from transformers.testing_utils import DUMMY_UNKWOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER # noqa: F401 +from transformers.configuration_auto import AutoConfig +from transformers.configuration_roberta import RobertaConfig +from transformers.testing_utils import ( + DUMMY_DIFF_TOKENIZER_IDENTIFIER, + DUMMY_UNKWOWN_IDENTIFIER, + SMALL_MODEL_IDENTIFIER, +) from transformers.tokenization_auto import TOKENIZER_MAPPING @@ -56,6 +62,14 @@ class AutoTokenizerTest(unittest.TestCase): self.assertIsInstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast)) self.assertEqual(tokenizer.vocab_size, 20) + def test_tokenizer_from_tokenizer_class(self): + config = 
AutoConfig.from_pretrained(DUMMY_DIFF_TOKENIZER_IDENTIFIER) + self.assertIsInstance(config, RobertaConfig) + # Check that tokenizer_type ≠ model_type + tokenizer = AutoTokenizer.from_pretrained(DUMMY_DIFF_TOKENIZER_IDENTIFIER, config=config) + self.assertIsInstance(tokenizer, (BertTokenizer, BertTokenizerFast)) + self.assertEqual(tokenizer.vocab_size, 12) + def test_tokenizer_identifier_with_correct_config(self): for tokenizer_class in [BertTokenizer, BertTokenizerFast, AutoTokenizer]: tokenizer = tokenizer_class.from_pretrained("wietsedv/bert-base-dutch-cased")