From d5bc32ce92ace9aaec7752e0b89d51ba18903a1b Mon Sep 17 00:00:00 2001
From: Philip May
Date: Thu, 6 Aug 2020 12:52:28 +0200
Subject: [PATCH] Add strip_accents to basic BertTokenizer. (#6280)

* Add strip_accents to basic tokenizer
* Add tests for strip_accents.
* fix style with black
* Fix strip_accents test
* empty commit to trigger CI
* Improved strip_accents check
* Add code quality with is not False
---
 src/transformers/tokenization_bert.py | 25 ++++++++++++++----
 tests/test_tokenization_bert.py       | 38 +++++++++++++++++++++++++++
 2 files changed, 58 insertions(+), 5 deletions(-)

diff --git a/src/transformers/tokenization_bert.py b/src/transformers/tokenization_bert.py
index 96c71d0d9..254398b34 100644
--- a/src/transformers/tokenization_bert.py
+++ b/src/transformers/tokenization_bert.py
@@ -154,6 +154,9 @@ class BertTokenizer(PreTrainedTokenizer):
             Whether to tokenize Chinese characters.
             This should likely be deactivated for Japanese:
             see: https://github.com/huggingface/transformers/issues/328
+        strip_accents: (:obj:`bool`, `optional`, defaults to :obj:`None`):
+            Whether to strip all accents. If this option is not specified (ie == None),
+            then it will be determined by the value for `lowercase` (as in the original Bert).
     """

     vocab_files_names = VOCAB_FILES_NAMES
@@ -173,6 +176,7 @@ class BertTokenizer(PreTrainedTokenizer):
         cls_token="[CLS]",
         mask_token="[MASK]",
         tokenize_chinese_chars=True,
+        strip_accents=None,
         **kwargs
     ):
         super().__init__(
@@ -194,7 +198,10 @@ class BertTokenizer(PreTrainedTokenizer):
         self.do_basic_tokenize = do_basic_tokenize
         if do_basic_tokenize:
             self.basic_tokenizer = BasicTokenizer(
-                do_lower_case=do_lower_case, never_split=never_split, tokenize_chinese_chars=tokenize_chinese_chars
+                do_lower_case=do_lower_case,
+                never_split=never_split,
+                tokenize_chinese_chars=tokenize_chinese_chars,
+                strip_accents=strip_accents,
             )
         self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)

@@ -351,7 +358,7 @@ class BertTokenizer(PreTrainedTokenizer):
 class BasicTokenizer(object):
     """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""

-    def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True):
+    def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True, strip_accents=None):
         """ Constructs a BasicTokenizer.

         Args:
@@ -364,12 +371,16 @@ class BasicTokenizer(object):
                 Whether to tokenize Chinese characters.
                 This should likely be deactivated for Japanese:
                 see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328
+            **strip_accents**: (`optional`) boolean (default None)
+                Whether to strip all accents. If this option is not specified (ie == None),
+                then it will be determined by the value for `lowercase` (as in the original Bert).
         """
         if never_split is None:
             never_split = []
         self.do_lower_case = do_lower_case
         self.never_split = set(never_split)
         self.tokenize_chinese_chars = tokenize_chinese_chars
+        self.strip_accents = strip_accents

     def tokenize(self, text, never_split=None):
         """ Basic Tokenization of a piece of text.
@@ -395,9 +406,13 @@ class BasicTokenizer(object):
         orig_tokens = whitespace_tokenize(text)
         split_tokens = []
         for token in orig_tokens:
-            if self.do_lower_case and token not in never_split:
-                token = token.lower()
-                token = self._run_strip_accents(token)
+            if token not in never_split:
+                if self.do_lower_case:
+                    token = token.lower()
+                    if self.strip_accents is not False:
+                        token = self._run_strip_accents(token)
+                elif self.strip_accents:
+                    token = self._run_strip_accents(token)
             split_tokens.extend(self._run_split_on_punc(token, never_split))

         output_tokens = whitespace_tokenize(" ".join(split_tokens))
diff --git a/tests/test_tokenization_bert.py b/tests/test_tokenization_bert.py
index 7aa3fbe1d..4421d30de 100644
--- a/tests/test_tokenization_bert.py
+++ b/tests/test_tokenization_bert.py
@@ -130,6 +130,30 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         )
         self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"])

+    def test_basic_tokenizer_lower_strip_accents_false(self):
+        tokenizer = BasicTokenizer(do_lower_case=True, strip_accents=False)
+
+        self.assertListEqual(
+            tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "), ["hällo", "!", "how", "are", "you", "?"]
+        )
+        self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["h\u00E9llo"])
+
+    def test_basic_tokenizer_lower_strip_accents_true(self):
+        tokenizer = BasicTokenizer(do_lower_case=True, strip_accents=True)
+
+        self.assertListEqual(
+            tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "), ["hallo", "!", "how", "are", "you", "?"]
+        )
+        self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"])
+
+    def test_basic_tokenizer_lower_strip_accents_default(self):
+        tokenizer = BasicTokenizer(do_lower_case=True)
+
+        self.assertListEqual(
+            tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "), ["hallo", "!", "how", "are", "you", "?"]
+        )
+        self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"])
+
     def test_basic_tokenizer_no_lower(self):
         tokenizer = BasicTokenizer(do_lower_case=False)

@@ -137,6 +161,20 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
             tokenizer.tokenize(" \tHeLLo!how \n Are yoU? "), ["HeLLo", "!", "how", "Are", "yoU", "?"]
         )

+    def test_basic_tokenizer_no_lower_strip_accents_false(self):
+        tokenizer = BasicTokenizer(do_lower_case=False, strip_accents=False)
+
+        self.assertListEqual(
+            tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "), ["HäLLo", "!", "how", "Are", "yoU", "?"]
+        )
+
+    def test_basic_tokenizer_no_lower_strip_accents_true(self):
+        tokenizer = BasicTokenizer(do_lower_case=False, strip_accents=True)
+
+        self.assertListEqual(
+            tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "), ["HaLLo", "!", "how", "Are", "yoU", "?"]
+        )
+
     def test_basic_tokenizer_respects_never_split_tokens(self):
         tokenizer = BasicTokenizer(do_lower_case=False, never_split=["[UNK]"])
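
The sketch below is not part of the patch; it is a minimal usage example of the new strip_accents option, assuming a transformers installation that already contains this change. The expected outputs mirror the tests added above: with the default strip_accents=None, accent stripping follows do_lower_case (the original BERT behaviour), while passing True or False overrides it independently of casing.

    from transformers.tokenization_bert import BasicTokenizer

    # Default: strip_accents=None, so accents are stripped only because do_lower_case=True.
    default_tok = BasicTokenizer(do_lower_case=True)
    print(default_tok.tokenize("HäLLo!how Are yoU?"))   # ['hallo', '!', 'how', 'are', 'you', '?']

    # Lower-case but keep accents.
    keep_accents = BasicTokenizer(do_lower_case=True, strip_accents=False)
    print(keep_accents.tokenize("HäLLo!how Are yoU?"))  # ['hällo', '!', 'how', 'are', 'you', '?']

    # Strip accents without lower-casing.
    strip_only = BasicTokenizer(do_lower_case=False, strip_accents=True)
    print(strip_only.tokenize("HäLLo!how Are yoU?"))    # ['HaLLo', '!', 'how', 'Are', 'yoU', '?']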