mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
[from_pretrained] Allow tokenizer_type ≠ model_type (#6995)
This commit is contained in:
parent
03e363f9ae
commit
ed71c21d6a
4 changed files with 28 additions and 1 deletion
|
|
@ -190,6 +190,7 @@ class PretrainedConfig(object):
|
|||
self.num_labels = kwargs.pop("num_labels", 2)
|
||||
|
||||
# Tokenizer arguments TODO: eventually tokenizer and models should share the same config
|
||||
self.tokenizer_class = kwargs.pop("tokenizer_class", None)
|
||||
self.prefix = kwargs.pop("prefix", None)
|
||||
self.bos_token_id = kwargs.pop("bos_token_id", None)
|
||||
self.pad_token_id = kwargs.pop("pad_token_id", None)
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from .file_utils import _tf_available, _torch_available, _torch_tpu_available
|
|||
|
||||
SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
|
||||
DUMMY_UNKWOWN_IDENTIFIER = "julien-c/dummy-unknown"
|
||||
DUMMY_DIFF_TOKENIZER_IDENTIFIER = "julien-c/dummy-diff-tokenizer"
|
||||
# Used to test Auto{Config, Model, Tokenizer} model_type detection.
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -222,6 +222,17 @@ class AutoTokenizer:
|
|||
return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
|
||||
use_fast = kwargs.pop("use_fast", False)
|
||||
|
||||
if config.tokenizer_class is not None:
|
||||
if use_fast and not config.tokenizer_class.endswith("Fast"):
|
||||
tokenizer_class_candidate = f"{config.tokenizer_class}Fast"
|
||||
else:
|
||||
tokenizer_class_candidate = config.tokenizer_class
|
||||
tokenizer_class = globals().get(tokenizer_class_candidate)
|
||||
if tokenizer_class is None:
|
||||
raise ValueError("Tokenizer class {} does not exist or is not currently imported.")
|
||||
return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
|
||||
for config_class, (tokenizer_class_py, tokenizer_class_fast) in TOKENIZER_MAPPING.items():
|
||||
if isinstance(config, config_class):
|
||||
if tokenizer_class_fast and use_fast:
|
||||
|
|
|
|||
|
|
@ -27,7 +27,13 @@ from transformers import (
|
|||
RobertaTokenizer,
|
||||
RobertaTokenizerFast,
|
||||
)
|
||||
from transformers.testing_utils import DUMMY_UNKWOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER # noqa: F401
|
||||
from transformers.configuration_auto import AutoConfig
|
||||
from transformers.configuration_roberta import RobertaConfig
|
||||
from transformers.testing_utils import (
|
||||
DUMMY_DIFF_TOKENIZER_IDENTIFIER,
|
||||
DUMMY_UNKWOWN_IDENTIFIER,
|
||||
SMALL_MODEL_IDENTIFIER,
|
||||
)
|
||||
from transformers.tokenization_auto import TOKENIZER_MAPPING
|
||||
|
||||
|
||||
|
|
@ -56,6 +62,14 @@ class AutoTokenizerTest(unittest.TestCase):
|
|||
self.assertIsInstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast))
|
||||
self.assertEqual(tokenizer.vocab_size, 20)
|
||||
|
||||
def test_tokenizer_from_tokenizer_class(self):
|
||||
config = AutoConfig.from_pretrained(DUMMY_DIFF_TOKENIZER_IDENTIFIER)
|
||||
self.assertIsInstance(config, RobertaConfig)
|
||||
# Check that tokenizer_type ≠ model_type
|
||||
tokenizer = AutoTokenizer.from_pretrained(DUMMY_DIFF_TOKENIZER_IDENTIFIER, config=config)
|
||||
self.assertIsInstance(tokenizer, (BertTokenizer, BertTokenizerFast))
|
||||
self.assertEqual(tokenizer.vocab_size, 12)
|
||||
|
||||
def test_tokenizer_identifier_with_correct_config(self):
|
||||
for tokenizer_class in [BertTokenizer, BertTokenizerFast, AutoTokenizer]:
|
||||
tokenizer = tokenizer_class.from_pretrained("wietsedv/bert-base-dutch-cased")
|
||||
|
|
|
|||
Loading…
Reference in a new issue