diff --git a/src/transformers/models/bert/tokenization_bert_fast.py b/src/transformers/models/bert/tokenization_bert_fast.py
index f93446c35..e477cf7af 100644
--- a/src/transformers/models/bert/tokenization_bert_fast.py
+++ b/src/transformers/models/bert/tokenization_bert_fast.py
@@ -190,11 +190,11 @@ class BertTokenizerFast(PreTrainedTokenizerFast):
 
         pre_tok_state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
         if (
-            pre_tok_state.get("do_lower_case", do_lower_case) != do_lower_case
+            pre_tok_state.get("lowercase", do_lower_case) != do_lower_case
             or pre_tok_state.get("strip_accents", strip_accents) != strip_accents
         ):
             pre_tok_class = getattr(normalizers, pre_tok_state.pop("type"))
-            pre_tok_state["do_lower_case"] = do_lower_case
+            pre_tok_state["lowercase"] = do_lower_case
             pre_tok_state["strip_accents"] = strip_accents
             self.backend_tokenizer.normalizer = pre_tok_class(**pre_tok_state)
 
diff --git a/src/transformers/models/mpnet/tokenization_mpnet_fast.py b/src/transformers/models/mpnet/tokenization_mpnet_fast.py
index 8f35528b9..07547fce5 100644
--- a/src/transformers/models/mpnet/tokenization_mpnet_fast.py
+++ b/src/transformers/models/mpnet/tokenization_mpnet_fast.py
@@ -138,11 +138,11 @@ class MPNetTokenizerFast(PreTrainedTokenizerFast):
 
         pre_tok_state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
         if (
-            pre_tok_state.get("do_lower_case", do_lower_case) != do_lower_case
+            pre_tok_state.get("lowercase", do_lower_case) != do_lower_case
             or pre_tok_state.get("strip_accents", strip_accents) != strip_accents
         ):
             pre_tok_class = getattr(normalizers, pre_tok_state.pop("type"))
-            pre_tok_state["do_lower_case"] = do_lower_case
+            pre_tok_state["lowercase"] = do_lower_case
             pre_tok_state["strip_accents"] = strip_accents
             self.backend_tokenizer.normalizer = pre_tok_class(**pre_tok_state)
 
diff --git a/tests/test_tokenization_auto.py b/tests/test_tokenization_auto.py
index 71c5f29f4..d632cbc55 100644
--- a/tests/test_tokenization_auto.py
+++ b/tests/test_tokenization_auto.py
@@ -110,3 +110,14 @@ class AutoTokenizerTest(unittest.TestCase):
     def test_from_pretrained_use_fast_toggle(self):
         self.assertIsInstance(AutoTokenizer.from_pretrained("bert-base-cased", use_fast=False), BertTokenizer)
         self.assertIsInstance(AutoTokenizer.from_pretrained("bert-base-cased"), BertTokenizerFast)
+
+    @require_tokenizers
+    def test_do_lower_case(self):
+        tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased", do_lower_case=False)
+        sample = "Hello, world. How are you?"
+        tokens = tokenizer.tokenize(sample)
+        self.assertEqual("[UNK]", tokens[0])
+
+        tokenizer = AutoTokenizer.from_pretrained("microsoft/mpnet-base", do_lower_case=False)
+        tokens = tokenizer.tokenize(sample)
+        self.assertEqual("[UNK]", tokens[0])