Mirror of
https://github.com/saymrwulf/transformers.git
Last synced: 2026-05-14 20:58:08 +00:00
fix backend tokenizer args override: key mismatch (#10686)
* fix backend tokenizer args override: key mismatch
* no touching the docs
* fix mpnet
* add mpnet to test
* fix test

Co-authored-by: theo <theo@matussie.re>
This commit is contained in:
parent
427ea3fecb
commit
117dba9948
3 changed files with 15 additions and 4 deletions
|
|
@ -190,11 +190,11 @@ class BertTokenizerFast(PreTrainedTokenizerFast):
|
|||
|
||||
pre_tok_state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
|
||||
if (
|
||||
pre_tok_state.get("do_lower_case", do_lower_case) != do_lower_case
|
||||
pre_tok_state.get("lowercase", do_lower_case) != do_lower_case
|
||||
or pre_tok_state.get("strip_accents", strip_accents) != strip_accents
|
||||
):
|
||||
pre_tok_class = getattr(normalizers, pre_tok_state.pop("type"))
|
||||
pre_tok_state["do_lower_case"] = do_lower_case
|
||||
pre_tok_state["lowercase"] = do_lower_case
|
||||
pre_tok_state["strip_accents"] = strip_accents
|
||||
self.backend_tokenizer.normalizer = pre_tok_class(**pre_tok_state)
|
||||
|
||||
|
|
|
|||
|
|
@ -138,11 +138,11 @@ class MPNetTokenizerFast(PreTrainedTokenizerFast):
|
|||
|
||||
pre_tok_state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
|
||||
if (
|
||||
pre_tok_state.get("do_lower_case", do_lower_case) != do_lower_case
|
||||
pre_tok_state.get("lowercase", do_lower_case) != do_lower_case
|
||||
or pre_tok_state.get("strip_accents", strip_accents) != strip_accents
|
||||
):
|
||||
pre_tok_class = getattr(normalizers, pre_tok_state.pop("type"))
|
||||
pre_tok_state["do_lower_case"] = do_lower_case
|
||||
pre_tok_state["lowercase"] = do_lower_case
|
||||
pre_tok_state["strip_accents"] = strip_accents
|
||||
self.backend_tokenizer.normalizer = pre_tok_class(**pre_tok_state)
|
||||
|
||||
|
|
|
|||
|
|
@ -110,3 +110,14 @@ class AutoTokenizerTest(unittest.TestCase):
|
|||
def test_from_pretrained_use_fast_toggle(self):
|
||||
self.assertIsInstance(AutoTokenizer.from_pretrained("bert-base-cased", use_fast=False), BertTokenizer)
|
||||
self.assertIsInstance(AutoTokenizer.from_pretrained("bert-base-cased"), BertTokenizerFast)
|
||||
|
||||
@require_tokenizers
|
||||
def test_do_lower_case(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased", do_lower_case=False)
|
||||
sample = "Hello, world. How are you?"
|
||||
tokens = tokenizer.tokenize(sample)
|
||||
self.assertEqual("[UNK]", tokens[0])
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("microsoft/mpnet-base", do_lower_case=False)
|
||||
tokens = tokenizer.tokenize(sample)
|
||||
self.assertEqual("[UNK]", tokens[0])
|
||||
|
|
|
|||
Loading…
Reference in a new issue