From 31565ff0fd827c7ffa707d8b99cb2abf7fa7f6bb Mon Sep 17 00:00:00 2001
From: Hao Wang <50416856+conan1024hao@users.noreply.github.com>
Date: Fri, 21 Oct 2022 23:04:49 +0900
Subject: [PATCH] Add sentencepiece to BertJapaneseTokenizer (#19769)

* support sentencepiece for bertjapanesetokenizer

* add test vocab file for sentencepiece, bertjapanesetokenizer

* make BasicTokenizer be identical to transformers.models.bert.tokenization_bert.BasicTokenizer

* fix missing of \n in comment

* fix init argument missing in tests

* make spm_file be optional, exclude spiece.model from tests/fixtures, and add description comments

* make comment length less than 119

* apply doc style check
---
 .../tokenization_bert_japanese.py             | 166 ++++++++++++++----
 .../test_tokenization_bert_japanese.py        |  10 ++
 2 files changed, 146 insertions(+), 30 deletions(-)

diff --git a/src/transformers/models/bert_japanese/tokenization_bert_japanese.py b/src/transformers/models/bert_japanese/tokenization_bert_japanese.py
index edb462f03..cbe700b67 100644
--- a/src/transformers/models/bert_japanese/tokenization_bert_japanese.py
+++ b/src/transformers/models/bert_japanese/tokenization_bert_japanese.py
@@ -19,7 +19,9 @@ import collections
 import copy
 import os
 import unicodedata
-from typing import List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
+
+import sentencepiece as spm
 
 from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
 from ...utils import logging
@@ -27,7 +29,9 @@ from ...utils import logging
 
 logger = logging.get_logger(__name__)
 
-VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "spm_file": "spiece.model"}
+
+SPIECE_UNDERLINE = "▁"
 
 PRETRAINED_VOCAB_FILES_MAP = {
     "vocab_file": {
@@ -107,6 +111,9 @@ class BertJapaneseTokenizer(PreTrainedTokenizer):
     Args:
         vocab_file (`str`):
             Path to a one-wordpiece-per-line vocabulary file.
+        spm_file (`str`, *optional*):
+            Path to [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .spm or .model
+            extension) that contains the vocabulary.
         do_lower_case (`bool`, *optional*, defaults to `True`):
             Whether to lower case the input. Only has an effect when do_basic_tokenize=True.
         do_word_tokenize (`bool`, *optional*, defaults to `True`):
@@ -116,7 +123,7 @@ class BertJapaneseTokenizer(PreTrainedTokenizer):
         word_tokenizer_type (`str`, *optional*, defaults to `"basic"`):
             Type of word tokenizer. Choose from ["basic", "mecab", "sudachi", "jumanpp"].
         subword_tokenizer_type (`str`, *optional*, defaults to `"wordpiece"`):
-            Type of subword tokenizer. Choose from ["wordpiece", "character"].
+            Type of subword tokenizer. Choose from ["wordpiece", "character", "sentencepiece"].
         mecab_kwargs (`dict`, *optional*):
             Dictionary passed to the `MecabTokenizer` constructor.
         sudachi_kwargs (`dict`, *optional*):
@@ -133,6 +140,7 @@ class BertJapaneseTokenizer(PreTrainedTokenizer):
     def __init__(
         self,
         vocab_file,
+        spm_file=None,
         do_lower_case=False,
         do_word_tokenize=True,
         do_subword_tokenize=True,
@@ -150,6 +158,7 @@ class BertJapaneseTokenizer(PreTrainedTokenizer):
         **kwargs
     ):
         super().__init__(
+            spm_file=spm_file,
             unk_token=unk_token,
             sep_token=sep_token,
             pad_token=pad_token,
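For context, this is how the new arguments are meant to combine once the patch is applied. A minimal sketch, not code taken from the PR; the local `spiece.model` path is a placeholder for any trained SentencePiece model:

    from transformers import BertJapaneseTokenizer

    # Hypothetical local file; any trained SentencePiece model can stand in here.
    tokenizer = BertJapaneseTokenizer(
        vocab_file=None,  # unused on the sentencepiece path, see the next hunk
        spm_file="spiece.model",
        subword_tokenizer_type="sentencepiece",
    )

With `subword_tokenizer_type="sentencepiece"` the constructor validates `spm_file` rather than `vocab_file`, which is why `vocab_file` can be `None` here.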
@@ -167,13 +176,21 @@ class BertJapaneseTokenizer(PreTrainedTokenizer):
             **kwargs,
         )
 
-        if not os.path.isfile(vocab_file):
-            raise ValueError(
-                f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
-                " model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
-            )
-        self.vocab = load_vocab(vocab_file)
-        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
+        if subword_tokenizer_type == "sentencepiece":
+            if not os.path.isfile(spm_file):
+                raise ValueError(
+                    f"Can't find a vocabulary file at path '{spm_file}'. To load the vocabulary from a Google"
+                    " pretrained model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
+                )
+            self.spm_file = spm_file
+        else:
+            if not os.path.isfile(vocab_file):
+                raise ValueError(
+                    f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google"
+                    " pretrained model use `tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
+                )
+            self.vocab = load_vocab(vocab_file)
+            self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
 
         self.do_word_tokenize = do_word_tokenize
         self.word_tokenizer_type = word_tokenizer_type
@@ -209,6 +226,8 @@ class BertJapaneseTokenizer(PreTrainedTokenizer):
             self.subword_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
         elif subword_tokenizer_type == "character":
             self.subword_tokenizer = CharacterTokenizer(vocab=self.vocab, unk_token=self.unk_token)
+        elif subword_tokenizer_type == "sentencepiece":
+            self.subword_tokenizer = SentencepieceTokenizer(vocab=self.spm_file, unk_token=self.unk_token)
         else:
             raise ValueError(f"Invalid subword_tokenizer_type '{subword_tokenizer_type}' is specified.")
 
@@ -251,27 +270,34 @@ class BertJapaneseTokenizer(PreTrainedTokenizer):
         return split_tokens
 
     @property
-    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.vocab_size
     def vocab_size(self):
+        if self.subword_tokenizer_type == "sentencepiece":
+            return len(self.subword_tokenizer.sp_model)
         return len(self.vocab)
 
-    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.get_vocab
     def get_vocab(self):
+        if self.subword_tokenizer_type == "sentencepiece":
+            vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+            vocab.update(self.added_tokens_encoder)
+            return vocab
         return dict(self.vocab, **self.added_tokens_encoder)
 
-    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._convert_token_to_id
     def _convert_token_to_id(self, token):
         """Converts a token (str) in an id using the vocab."""
+        if self.subword_tokenizer_type == "sentencepiece":
+            return self.subword_tokenizer.sp_model.PieceToId(token)
         return self.vocab.get(token, self.vocab.get(self.unk_token))
 
-    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._convert_id_to_token
     def _convert_id_to_token(self, index):
         """Converts an index (integer) in a token (str) using the vocab."""
+        if self.subword_tokenizer_type == "sentencepiece":
+            return self.subword_tokenizer.sp_model.IdToPiece(index)
         return self.ids_to_tokens.get(index, self.unk_token)
 
-    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.convert_tokens_to_string
     def convert_tokens_to_string(self, tokens):
         """Converts a sequence of tokens (string) in a single string."""
+        if self.subword_tokenizer_type == "sentencepiece":
+            return self.subword_tokenizer.sp_model.decode(tokens)
         out_string = " ".join(tokens).replace(" ##", "").strip()
         return out_string
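The sentencepiece branches above delegate directly to the loaded `sentencepiece` processor. A rough sketch of the round trip, assuming `tokenizer` was built with `subword_tokenizer_type="sentencepiece"` (the pieces shown are illustrative, not guaranteed for any particular model):

    sp = tokenizer.subword_tokenizer.sp_model  # a sentencepiece.SentencePieceProcessor
    pieces = ["▁こん", "ばん", "は"]
    ids = [sp.PieceToId(p) for p in pieces]    # what _convert_token_to_id returns per token
    text = sp.decode(pieces)                   # what convert_tokens_to_string returns: "こんばんは"

Because `sp_model.decode` already strips the `SPIECE_UNDERLINE` word-start markers, the `" ##"`-joining logic of the wordpiece path is bypassed entirely.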
@@ -360,25 +386,36 @@ class BertJapaneseTokenizer(PreTrainedTokenizer):
             return len(cls + token_ids_0 + sep) * [0]
         return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
 
-    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.save_vocabulary
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
-        index = 0
         if os.path.isdir(save_directory):
-            vocab_file = os.path.join(
-                save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
-            )
+            if self.subword_tokenizer_type == "sentencepiece":
+                vocab_file = os.path.join(
+                    save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["spm_file"]
+                )
+            else:
+                vocab_file = os.path.join(
+                    save_directory,
+                    (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"],
+                )
         else:
             vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
-        with open(vocab_file, "w", encoding="utf-8") as writer:
-            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
-                if index != token_index:
-                    logger.warning(
-                        f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
-                        " Please check that the vocabulary is not corrupted!"
-                    )
-                    index = token_index
-                writer.write(token + "\n")
-                index += 1
+
+        if self.subword_tokenizer_type == "sentencepiece":
+            with open(vocab_file, "wb") as writer:
+                content_spiece_model = self.subword_tokenizer.sp_model.serialized_model_proto()
+                writer.write(content_spiece_model)
+        else:
+            with open(vocab_file, "w", encoding="utf-8") as writer:
+                index = 0
+                for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
+                    if index != token_index:
+                        logger.warning(
+                            f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
+                            " Please check that the vocabulary is not corrupted!"
+                        )
+                        index = token_index
+                    writer.write(token + "\n")
+                    index += 1
         return (vocab_file,)
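Since the sentencepiece branch writes `serialized_model_proto()` verbatim, the saved file is itself a valid SentencePiece model, so a save/reload round trip needs no extra conversion. A sketch under that assumption, where `out_dir` is a placeholder for an existing directory:

    import sentencepiece as spm

    (saved_path,) = tokenizer.save_vocabulary("out_dir")  # -> "out_dir/spiece.model"
    sp = spm.SentencePieceProcessor()
    sp.Load(saved_path)  # loads the serialized proto directly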
+ """ + text = self.preprocess_text(text) + pieces = self.sp_model.encode(text, out_type=str) + new_pieces = [] + for piece in pieces: + if len(piece) > 1 and piece[-1] == str(",") and piece[-2].isdigit(): + cur_pieces = self.sp_model.EncodeAsPieces(piece[:-1].replace(SPIECE_UNDERLINE, "")) + if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE: + if len(cur_pieces[0]) == 1: + cur_pieces = cur_pieces[1:] + else: + cur_pieces[0] = cur_pieces[0][1:] + cur_pieces.append(piece[-1]) + new_pieces.extend(cur_pieces) + else: + new_pieces.append(piece) + + return new_pieces diff --git a/tests/models/bert_japanese/test_tokenization_bert_japanese.py b/tests/models/bert_japanese/test_tokenization_bert_japanese.py index 7141f7551..7e89c36b7 100644 --- a/tests/models/bert_japanese/test_tokenization_bert_japanese.py +++ b/tests/models/bert_japanese/test_tokenization_bert_japanese.py @@ -334,6 +334,16 @@ class BertJapaneseTokenizationTest(TokenizerTesterMixin, unittest.TestCase): self.assertListEqual(tokenizer.tokenize("こんばんは こんばんにちは こんにちは"), ["こん", "##ばんは", "[UNK]", "こんにちは"]) + def test_sentencepiece_tokenizer(self): + tokenizer = BertJapaneseTokenizer.from_pretrained("nlp-waseda/roberta-base-japanese-with-auto-jumanpp") + subword_tokenizer = tokenizer.subword_tokenizer + + tokens = subword_tokenizer.tokenize("国境 の 長い トンネル を 抜ける と 雪国 であった 。") + self.assertListEqual(tokens, ["▁国境", "▁の", "▁長い", "▁トンネル", "▁を", "▁抜ける", "▁と", "▁雪", "国", "▁であった", "▁。"]) + + tokens = subword_tokenizer.tokenize("こんばんは こんばん にち は こんにちは") + self.assertListEqual(tokens, ["▁こん", "ばん", "は", "▁こん", "ばん", "▁に", "ち", "▁は", "▁こんにちは"]) + def test_sequence_builders(self): tokenizer = self.tokenizer_class.from_pretrained("cl-tohoku/bert-base-japanese")