From 30b3c46ff5a7c9761a800a9ab4bcf8cdb206727e Mon Sep 17 00:00:00 2001
From: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Date: Fri, 18 Aug 2023 13:26:27 +0200
Subject: [PATCH] [`split_special_tokens`] Add support for `split_special_tokens` argument to encode (#25081)

* draft changes
* update and add tests
* styling for no
* move test
* path to usable model
* update test
* small update
* update bertbased tokenizers
* don't use kwargs for _tokenize
* don't use kwargs for _tokenize
* fix copies
* update
* update test for special tokenizers
* fixup
* skip two tests
* remove pdb breakpoint()
* wowo
* rewrite custom tests
* nits
* revert change in target keys
* fix markup lm
* update documentation of the argument
---
 .../models/bert/tokenization_bert.py          |  6 +++--
 .../models/convbert/tokenization_convbert.py  |  6 +++--
 .../retribert/tokenization_retribert.py       |  6 +++--
 .../distilbert/tokenization_distilbert.py     |  6 +++--
 .../models/electra/tokenization_electra.py    |  6 +++--
 .../models/funnel/tokenization_funnel.py      |  6 +++--
 .../models/layoutlm/tokenization_layoutlm.py  |  6 +++--
 .../models/lxmert/tokenization_lxmert.py      |  6 +++--
 .../mobilebert/tokenization_mobilebert.py     |  6 +++--
 .../models/roc_bert/tokenization_roc_bert.py  |  6 +++--
 .../squeezebert/tokenization_squeezebert.py   |  6 +++--
 src/transformers/tokenization_utils.py        | 11 ++++++--
 src/transformers/tokenization_utils_base.py   |  8 ++++++
 .../test_tokenization_layoutlmv2.py           |  4 +++
 .../test_tokenization_layoutlmv3.py           |  4 +++
 .../layoutxlm/test_tokenization_layoutxlm.py  | 13 +++++++++
 .../markuplm/test_tokenization_markuplm.py    | 13 +++++++++
 tests/test_tokenization_common.py             | 27 +++++++++++++++++++
 18 files changed, 122 insertions(+), 24 deletions(-)

diff --git a/src/transformers/models/bert/tokenization_bert.py b/src/transformers/models/bert/tokenization_bert.py
index 536eb0864..a24f39564 100644
--- a/src/transformers/models/bert/tokenization_bert.py
+++ b/src/transformers/models/bert/tokenization_bert.py
@@ -238,10 +238,12 @@ class BertTokenizer(PreTrainedTokenizer):
     def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)

-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
diff --git a/src/transformers/models/convbert/tokenization_convbert.py b/src/transformers/models/convbert/tokenization_convbert.py
index 4fbed8fe1..800848caa 100644
--- a/src/transformers/models/convbert/tokenization_convbert.py
+++ b/src/transformers/models/convbert/tokenization_convbert.py
@@ -177,10 +177,12 @@ class ConvBertTokenizer(PreTrainedTokenizer):
     def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)

-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
diff --git a/src/transformers/models/deprecated/retribert/tokenization_retribert.py b/src/transformers/models/deprecated/retribert/tokenization_retribert.py
index 4529e8e90..de50c74b7 100644
--- a/src/transformers/models/deprecated/retribert/tokenization_retribert.py
+++ b/src/transformers/models/deprecated/retribert/tokenization_retribert.py
@@ -178,10 +178,12 @@ class RetriBertTokenizer(PreTrainedTokenizer):
         return dict(self.vocab, **self.added_tokens_encoder)

     # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
diff --git a/src/transformers/models/distilbert/tokenization_distilbert.py b/src/transformers/models/distilbert/tokenization_distilbert.py
index 025968258..5e96e4972 100644
--- a/src/transformers/models/distilbert/tokenization_distilbert.py
+++ b/src/transformers/models/distilbert/tokenization_distilbert.py
@@ -195,10 +195,12 @@ class DistilBertTokenizer(PreTrainedTokenizer):
         return dict(self.vocab, **self.added_tokens_encoder)

     # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
diff --git a/src/transformers/models/electra/tokenization_electra.py b/src/transformers/models/electra/tokenization_electra.py
index e202f773e..aabeccba7 100644
--- a/src/transformers/models/electra/tokenization_electra.py
+++ b/src/transformers/models/electra/tokenization_electra.py
@@ -194,10 +194,12 @@ class ElectraTokenizer(PreTrainedTokenizer):
     def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)

-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
diff --git a/src/transformers/models/funnel/tokenization_funnel.py b/src/transformers/models/funnel/tokenization_funnel.py
index f085fd7c4..37a913d0a 100644
--- a/src/transformers/models/funnel/tokenization_funnel.py
+++ b/src/transformers/models/funnel/tokenization_funnel.py
@@ -205,10 +205,12 @@ class FunnelTokenizer(PreTrainedTokenizer):
         return dict(self.vocab, **self.added_tokens_encoder)

     # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
diff --git a/src/transformers/models/layoutlm/tokenization_layoutlm.py b/src/transformers/models/layoutlm/tokenization_layoutlm.py
index 57c29d587..b51887422 100644
--- a/src/transformers/models/layoutlm/tokenization_layoutlm.py
+++ b/src/transformers/models/layoutlm/tokenization_layoutlm.py
@@ -176,10 +176,12 @@ class LayoutLMTokenizer(PreTrainedTokenizer):
     def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)

-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
diff --git a/src/transformers/models/lxmert/tokenization_lxmert.py b/src/transformers/models/lxmert/tokenization_lxmert.py
index daa761878..e651b8f44 100644
--- a/src/transformers/models/lxmert/tokenization_lxmert.py
+++ b/src/transformers/models/lxmert/tokenization_lxmert.py
@@ -168,10 +168,12 @@ class LxmertTokenizer(PreTrainedTokenizer):
     def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)

-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
diff --git a/src/transformers/models/mobilebert/tokenization_mobilebert.py b/src/transformers/models/mobilebert/tokenization_mobilebert.py
index 63c0ab28a..389e38bce 100644
--- a/src/transformers/models/mobilebert/tokenization_mobilebert.py
+++ b/src/transformers/models/mobilebert/tokenization_mobilebert.py
@@ -166,10 +166,12 @@ class MobileBertTokenizer(PreTrainedTokenizer):
     def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)

-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
diff --git a/src/transformers/models/roc_bert/tokenization_roc_bert.py b/src/transformers/models/roc_bert/tokenization_roc_bert.py
index cee778dc8..d665b91a0 100644
--- a/src/transformers/models/roc_bert/tokenization_roc_bert.py
+++ b/src/transformers/models/roc_bert/tokenization_roc_bert.py
@@ -210,10 +210,12 @@ class RoCBertTokenizer(PreTrainedTokenizer):
         return dict(self.vocab, **self.added_tokens_encoder)

     # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
diff --git a/src/transformers/models/squeezebert/tokenization_squeezebert.py b/src/transformers/models/squeezebert/tokenization_squeezebert.py
index ccce92809..f061a1a53 100644
--- a/src/transformers/models/squeezebert/tokenization_squeezebert.py
+++ b/src/transformers/models/squeezebert/tokenization_squeezebert.py
@@ -180,10 +180,12 @@ class SqueezeBertTokenizer(PreTrainedTokenizer):
     def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)

-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
diff --git a/src/transformers/tokenization_utils.py b/src/transformers/tokenization_utils.py
index c1dd9c329..e26c0c6d5 100644
--- a/src/transformers/tokenization_utils.py
+++ b/src/transformers/tokenization_utils.py
@@ -498,6 +498,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
         all_special_tokens_extended = {
             str(t): t for t in self.all_special_tokens_extended if isinstance(t, AddedToken)
         }
+        split_special_tokens = kwargs.pop("split_special_tokens", self.split_special_tokens)

         text, kwargs = self.prepare_for_tokenization(text, **kwargs)

@@ -513,8 +514,14 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
             pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
             text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)

-        no_split_token = set(self.unique_no_split_tokens)
-        tokens = self.tokens_trie.split(text)
+        # split_special_tokens: empty `no_split_token`
+        if split_special_tokens:
+            no_split_token = []
+            tokens = [text]
+        else:
+            no_split_token = set(self.unique_no_split_tokens)
+            tokens = self.tokens_trie.split(text)
+
         # ["This is something", "<special_token_1>", "  else"]
         for i, token in enumerate(tokens):
             if token in no_split_token:
diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
index c3d2c4eb8..0490bec39 100644
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ -1492,6 +1492,11 @@ INIT_TOKENIZER_DOCSTRING = r"""
     clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
         Whether or not the model should cleanup the spaces that were added when splitting the input text during the
         tokenization process.
+    split_special_tokens (`bool`, *optional*, defaults to `False`):
+        Whether or not the special tokens should be split during the tokenization process. The default behavior is
+        to not split special tokens. This means that if `<s>` is the `bos_token`, then
+        `tokenizer.tokenize("<s>") = ['<s>']`. Otherwise, if `split_special_tokens=True`, then
+        `tokenizer.tokenize("<s>")` will give `['<', 's', '>']`. This argument is only supported for `slow`
+        tokenizers for the moment.
""" @@ -1546,6 +1551,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): # By default, cleaning tokenization spaces for both fast and slow tokenizers self.clean_up_tokenization_spaces = kwargs.pop("clean_up_tokenization_spaces", True) + # By default, do not split special tokens for both fast and slow tokenizers + self.split_special_tokens = kwargs.pop("split_special_tokens", False) + self.deprecation_warnings = ( {} ) # Use to store when we have already noticed a deprecation warning (avoid overlogging). diff --git a/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py b/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py index 9224fbd87..942cceaf7 100644 --- a/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py +++ b/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py @@ -384,6 +384,10 @@ class LayoutLMv2TokenizationTest(TokenizerTesterMixin, unittest.TestCase): def test_right_and_left_truncation(self): pass + @unittest.skip("Not implemented") + def test_split_special_tokens(self): + pass + def test_encode_plus_with_padding(self): tokenizers = self.get_tokenizers(do_lower_case=False) for tokenizer in tokenizers: diff --git a/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py b/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py index 63d86f280..58092834e 100644 --- a/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py +++ b/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py @@ -264,6 +264,10 @@ class LayoutLMv3TokenizationTest(TokenizerTesterMixin, unittest.TestCase): def test_right_and_left_truncation(self): pass + @unittest.skip("Not implemented") + def test_split_special_tokens(self): + pass + def test_encode_plus_with_padding(self): tokenizers = self.get_tokenizers(do_lower_case=False) for tokenizer in tokenizers: diff --git a/tests/models/layoutxlm/test_tokenization_layoutxlm.py b/tests/models/layoutxlm/test_tokenization_layoutxlm.py index bf295c9c9..f7f832970 100644 --- a/tests/models/layoutxlm/test_tokenization_layoutxlm.py +++ b/tests/models/layoutxlm/test_tokenization_layoutxlm.py @@ -144,6 +144,19 @@ class LayoutXLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase): self.assertEqual(encoding_tokenizer_slow_1, encoding_tokenizer_slow_2) self.assertEqual(encoding_tokenizer_slow_1, encoding_tokenizer_slow_3) + def test_split_special_tokens(self): + tokenizer = self.tokenizer_class.from_pretrained("microsoft/layoutxlm-base") + _, _, boxes = self.get_question_words_and_boxes() + special_token = "[SPECIAL_TOKEN]" + tokenizer.add_special_tokens({"additional_special_tokens": [special_token]}) + encoded_special_token = tokenizer.tokenize(special_token, boxes=boxes, add_special_tokens=False) + self.assertEqual(len(encoded_special_token), 1) + + encoded_split_special_token = tokenizer.tokenize( + special_token, add_special_tokens=False, split_special_tokens=True, boxes=boxes + ) + self.assertTrue(len(encoded_split_special_token) > 1) + @slow def test_sequence_builders(self): tokenizer = self.tokenizer_class.from_pretrained("microsoft/layoutxlm-base") diff --git a/tests/models/markuplm/test_tokenization_markuplm.py b/tests/models/markuplm/test_tokenization_markuplm.py index 533a3429a..73979b255 100644 --- a/tests/models/markuplm/test_tokenization_markuplm.py +++ b/tests/models/markuplm/test_tokenization_markuplm.py @@ -1344,6 +1344,19 @@ class MarkupLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase): self.assertTrue(special_token_id in p_output) self.assertTrue(special_token_id in cr_output) + def 
+    def test_split_special_tokens(self):
+        # TODO this is only possible for slow tokenizers currently
+        tokenizer = self.get_tokenizer()
+        special_token = "[SPECIAL_TOKEN]"
+        tokenizer.add_special_tokens({"additional_special_tokens": [special_token]})
+        encoded_special_token = tokenizer.tokenize(special_token, add_special_tokens=False)
+        self.assertEqual(len(encoded_special_token), 1)
+
+        encoded_split_special_token = tokenizer.tokenize(
+            special_token, add_special_tokens=False, split_special_tokens=True
+        )
+        self.assertTrue(len(encoded_split_special_token) > 1)
+
     def test_training_new_tokenizer(self):
         # This feature only exists for fast tokenizers
         if not self.test_rust_tokenizer:
diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py
index 3b17c6ea4..aec5e493c 100644
--- a/tests/test_tokenization_common.py
+++ b/tests/test_tokenization_common.py
@@ -3909,6 +3909,7 @@ class TokenizerTesterMixin:
         # Should not raise an error
         self.rust_tokenizer_class.from_pretrained(tmp_dir_2)

+    # TODO This is run for all models but only tests bert...
     def test_clean_up_tokenization_spaces(self):
         tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
         assert tokenizer.clean_up_tokenization_spaces is True
@@ -3953,3 +3954,29 @@ class TokenizerTesterMixin:
         tokenizer.clean_up_tokenization_spaces = True
         decoded = tokenizer.decode(tokens)
         assert decoded == "[CLS] this shouldn't be! he'll go. [SEP]"
+
+    def test_split_special_tokens(self):
+        if not self.test_slow_tokenizer:
+            return
+
+        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+            special_token = "[SPECIAL_TOKEN]"
+            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+                tokenizer = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+
+                if not tokenizer.is_fast:
+                    # bloom, gptneox etc. only have a fast tokenizer
+                    tokenizer.add_special_tokens({"additional_special_tokens": [special_token]})
+                    encoded_special_token = tokenizer.encode(special_token, add_special_tokens=False)
+                    self.assertEqual(len(encoded_special_token), 1)
+
+                    encoded_split_special_token = tokenizer.encode(
+                        special_token, add_special_tokens=False, split_special_tokens=True
+                    )
+                    if len(encoded_split_special_token) == 1:
+                        # if we have subword tokenization or special vocab
+                        self.assertTrue(
+                            encoded_split_special_token[0] != tokenizer.convert_tokens_to_ids(special_token)
+                        )
+                    else:
+                        self.assertTrue(len(encoded_split_special_token) > 1)
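
Usage sketch (not part of the diff above): the calls below mirror what the new tests exercise, assuming the `bert-base-uncased` checkpoint is reachable. Per the docstring added in this patch, the argument is only supported for slow tokenizers, so the slow `BertTokenizer` class is used.

    from transformers import BertTokenizer

    # Slow (Python) tokenizer; the new docstring notes the flag is not yet
    # supported by the fast (Rust) tokenizers.
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    tokenizer.add_special_tokens({"additional_special_tokens": ["[SPECIAL_TOKEN]"]})

    # Default behavior: the added special token is protected and comes back whole.
    tokenizer.tokenize("[SPECIAL_TOKEN]")                          # ['[SPECIAL_TOKEN]']
    tokenizer.encode("[SPECIAL_TOKEN]", add_special_tokens=False)  # a single id

    # With the new argument the same string is treated as ordinary text, so (as the
    # new common test asserts) the result is either several pieces or a different id.
    tokenizer.tokenize("[SPECIAL_TOKEN]", split_special_tokens=True)
    tokenizer.encode("[SPECIAL_TOKEN]", add_special_tokens=False, split_special_tokens=True)

    # The flag can also be set once at load time, via the new __init__ kwarg, and then
    # acts as the default for every later call to `tokenize`/`encode`.
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", split_special_tokens=True)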