From 9a12b9696fca52c71601b59a73c8e18426519027 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Mon, 21 Dec 2020 15:41:34 +0100
Subject: [PATCH] [MPNet] Add slow to fast tokenizer converter (#9233)

* add converter

* delete unnecessary comments
---
 src/transformers/convert_slow_tokenizer.py | 64 ++++++++++++++--------
 tests/test_tokenization_mpnet.py           |  5 +-
 2 files changed, 44 insertions(+), 25 deletions(-)

diff --git a/src/transformers/convert_slow_tokenizer.py b/src/transformers/convert_slow_tokenizer.py
index b1b3408ac..2c0f9f7f2 100644
--- a/src/transformers/convert_slow_tokenizer.py
+++ b/src/transformers/convert_slow_tokenizer.py
@@ -74,18 +74,6 @@ class BertConverter(Converter):
         vocab = self.original_tokenizer.vocab
         tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
 
-        # # Let the tokenizer know about special tokens if they are part of the vocab
-        # if tokenizer.token_to_id(str(self.original_tokenizer.unk_token)) is not None:
-        #     tokenizer.add_special_tokens([str(self.original_tokenizer.unk_token)])
-        # if tokenizer.token_to_id(str(self.original_tokenizer.sep_token)) is not None:
-        #     tokenizer.add_special_tokens([str(self.original_tokenizer.sep_token)])
-        # if tokenizer.token_to_id(str(self.original_tokenizer.cls_token)) is not None:
-        #     tokenizer.add_special_tokens([str(self.original_tokenizer.cls_token)])
-        # if tokenizer.token_to_id(str(self.original_tokenizer.pad_token)) is not None:
-        #     tokenizer.add_special_tokens([str(self.original_tokenizer.pad_token)])
-        # if tokenizer.token_to_id(str(self.original_tokenizer.mask_token)) is not None:
-        #     tokenizer.add_special_tokens([str(self.original_tokenizer.mask_token)])
-
         tokenize_chinese_chars = False
         strip_accents = False
         do_lower_case = False
@@ -125,18 +113,6 @@ class FunnelConverter(Converter):
         vocab = self.original_tokenizer.vocab
         tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
 
-        # # Let the tokenizer know about special tokens if they are part of the vocab
-        # if tokenizer.token_to_id(str(self.original_tokenizer.unk_token)) is not None:
-        #     tokenizer.add_special_tokens([str(self.original_tokenizer.unk_token)])
-        # if tokenizer.token_to_id(str(self.original_tokenizer.sep_token)) is not None:
-        #     tokenizer.add_special_tokens([str(self.original_tokenizer.sep_token)])
-        # if tokenizer.token_to_id(str(self.original_tokenizer.cls_token)) is not None:
-        #     tokenizer.add_special_tokens([str(self.original_tokenizer.cls_token)])
-        # if tokenizer.token_to_id(str(self.original_tokenizer.pad_token)) is not None:
-        #     tokenizer.add_special_tokens([str(self.original_tokenizer.pad_token)])
-        # if tokenizer.token_to_id(str(self.original_tokenizer.mask_token)) is not None:
-        #     tokenizer.add_special_tokens([str(self.original_tokenizer.mask_token)])
-
         tokenize_chinese_chars = False
         strip_accents = False
         do_lower_case = False
@@ -171,6 +147,45 @@ class FunnelConverter(Converter):
         return tokenizer
 
 
+class MPNetConverter(Converter):
+    def converted(self) -> Tokenizer:
+        vocab = self.original_tokenizer.vocab
+        tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
+
+        tokenize_chinese_chars = False
+        strip_accents = False
+        do_lower_case = False
+        if hasattr(self.original_tokenizer, "basic_tokenizer"):
+            tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
+            strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
+            do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
+
+        tokenizer.normalizer = normalizers.BertNormalizer(
+            clean_text=True,
+            handle_chinese_chars=tokenize_chinese_chars,
+            strip_accents=strip_accents,
+            lowercase=do_lower_case,
+        )
+        tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
+
+        cls = str(self.original_tokenizer.cls_token)
+        sep = str(self.original_tokenizer.sep_token)
+        cls_token_id = self.original_tokenizer.cls_token_id
+        sep_token_id = self.original_tokenizer.sep_token_id
+
+        tokenizer.post_processor = processors.TemplateProcessing(
+            single=f"{cls}:0 $A:0 {sep}:0",
+            pair=f"{cls}:0 $A:0 {sep}:0 {sep}:0 $B:1 {sep}:1",  # MPNet uses two [SEP] tokens
+            special_tokens=[
+                (cls, cls_token_id),
+                (sep, sep_token_id),
+            ],
+        )
+        tokenizer.decoder = decoders.WordPiece(prefix="##")
+
+        return tokenizer
+
+
 class OpenAIGPTConverter(Converter):
     def converted(self) -> Tokenizer:
         vocab = self.original_tokenizer.encoder
@@ -602,6 +617,7 @@ SLOW_TO_FAST_CONVERTERS = {
     "LongformerTokenizer": RobertaConverter,
     "LxmertTokenizer": BertConverter,
     "MBartTokenizer": MBartConverter,
+    "MPNetTokenizer": MPNetConverter,
     "MobileBertTokenizer": BertConverter,
     "OpenAIGPTTokenizer": OpenAIGPTConverter,
     "PegasusTokenizer": PegasusConverter,
diff --git a/tests/test_tokenization_mpnet.py b/tests/test_tokenization_mpnet.py
index 2a4f26ff9..733b2891f 100644
--- a/tests/test_tokenization_mpnet.py
+++ b/tests/test_tokenization_mpnet.py
@@ -17,6 +17,7 @@
 import os
 import unittest
 
+from transformers import MPNetTokenizerFast
 from transformers.models.mpnet.tokenization_mpnet import VOCAB_FILES_NAMES, MPNetTokenizer
 from transformers.testing_utils import require_tokenizers, slow
 
@@ -27,7 +28,9 @@ from .test_tokenization_common import TokenizerTesterMixin
 class MPNetTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
 
     tokenizer_class = MPNetTokenizer
-    test_rust_tokenizer = False
+    rust_tokenizer_class = MPNetTokenizerFast
+    test_rust_tokenizer = True
+    space_between_special_tokens = True
 
     def setUp(self):
         super().setUp()
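
A quick way to sanity-check the converter end to end (a minimal sketch, not
part of the patch itself; it assumes the `microsoft/mpnet-base` checkpoint is
reachable and that this branch is installed):

    from transformers import MPNetTokenizer
    from transformers.convert_slow_tokenizer import convert_slow_tokenizer

    # Load the slow (pure-Python) MPNet tokenizer and run it through the
    # converter now registered under "MPNetTokenizer" in SLOW_TO_FAST_CONVERTERS.
    slow = MPNetTokenizer.from_pretrained("microsoft/mpnet-base")
    fast = convert_slow_tokenizer(slow)  # returns a tokenizers.Tokenizer

    # The pair template places two separator tokens between the segments,
    # i.e. <s> A </s> </s> B </s>, matching the slow tokenizer's
    # build_inputs_with_special_tokens (the sample sentences are arbitrary).
    encoding = fast.encode("hello world", "how are you")
    print(encoding.tokens)

The doubled separator follows the RoBERTa-style convention that MPNet
inherits, which is why the pair template uses `{sep}:0 {sep}:0` between `$A`
and `$B` rather than the single separator the BertConverter emits.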