From a16e568f22a4d07813ba76343309ec20096115a5 Mon Sep 17 00:00:00 2001
From: wlhgtc
Date: Thu, 22 Oct 2020 21:19:00 +0800
Subject: [PATCH] # Add whole word mask support for lm fine-tune (#7925)

* ADD: add whole word mask proxy for both eng and chinese
* MOD: adjust format
* MOD: reformat code
* MOD: update import
* MOD: fix bug
* MOD: add import
* MOD: fix bug
* MOD: decouple code and update readme
* MOD: reformat code
* Update examples/language-modeling/README.md
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Update examples/language-modeling/README.md
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Update examples/language-modeling/run_language_modeling.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Update examples/language-modeling/run_language_modeling.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Update examples/language-modeling/run_language_modeling.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Update examples/language-modeling/run_language_modeling.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* change wwm to whole_word_mask
* reformat code
* reformat
* format
* Code quality
* ADD: update chinese ref readme
* MOD: small changes
* MOD: small changes2
* update readme

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Sylvain Gugger
---
 examples/language-modeling/README.md        |  52 ++++++-
 examples/language-modeling/chinese_ref.py   | 147 ++++++++++++++++++
 .../run_language_modeling.py                |  29 +++-
 src/transformers/__init__.py                |   2 +
 src/transformers/data/data_collator.py      | 119 ++++++++++++++
 src/transformers/data/datasets/__init__.py  |   1 +
 .../data/datasets/language_modeling.py      |  41 ++++-
 src/transformers/utils/dummy_pt_objects.py  |  10 ++
 8 files changed, 394 insertions(+), 7 deletions(-)
 create mode 100644 examples/language-modeling/chinese_ref.py

diff --git a/examples/language-modeling/README.md b/examples/language-modeling/README.md
index a66215351..800349592 100644
--- a/examples/language-modeling/README.md
+++ b/examples/language-modeling/README.md
@@ -45,6 +45,8 @@ slightly slower (over-fitting takes more epochs).
 We use the `--mlm` flag so that the script may change its loss function.
 
+If using whole-word masking, pass both the `--mlm` and `--wwm` flags.
+
 ```bash
 export TRAIN_FILE=/path/to/dataset/wiki.train.raw
 export TEST_FILE=/path/to/dataset/wiki.test.raw
@@ -57,7 +59,55 @@ python run_language_modeling.py \
     --train_data_file=$TRAIN_FILE \
     --do_eval \
     --eval_data_file=$TEST_FILE \
-    --mlm
+    --mlm \
+    --wwm
 ```
+
+For Chinese models, plain masked language modeling works the same as for English models, using only `--mlm`. For whole-word masking, however, we first need to generate a reference file, because the Chinese BERT tokenizer works at the character level.
+
+**Q:** Why do we need a ref file?
+
+**A:** Suppose we have a Chinese sentence like `我喜欢你`. The original Chinese-BERT tokenizes it at the character level as `['我','喜','欢','你']`,
+but `喜欢` is actually a single word. For the whole word mask proxy we need a result like `['我','喜','##欢','你']`,
+so we need a ref file that tells the model which positions of the original BERT tokens should receive a `##` prefix.
+
+**Q:** Why LTP?
+
+**A:** Because the best-known Chinese whole-word-masking BERT is [Chinese-BERT-wwm](https://github.com/ymcui/Chinese-BERT-wwm) by HIT, which works well on many Chinese tasks such as CLUE (the Chinese GLUE).
+They use LTP for word segmentation, so if we want to fine-tune their model, we need LTP as well.
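+
+The snippet below is only an illustration (it is not part of the example scripts) of what one line of the ref file encodes, assuming the sentence `我喜欢你` is tokenized by BERT into `['[CLS]', '我', '喜', '欢', '你', '[SEP]']`:
+
+```python
+# Illustration only: ref.txt stores, for each training line, the positions whose
+# character continues the previous one inside a word and should be read as "##token".
+tokens = ["[CLS]", "我", "喜", "欢", "你", "[SEP]"]
+ref = [3]  # "欢" continues the word "喜欢", so position 3 is recorded for this line
+
+wwm_tokens = [("##" + t) if i in ref else t for i, t in enumerate(tokens)]
+print(wwm_tokens)  # ['[CLS]', '我', '喜', '##欢', '你', '[SEP]']
+```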
+
+```bash
+export TRAIN_FILE=/path/to/dataset/wiki.train.raw
+export LTP_RESOURCE=/path/to/ltp/tokenizer
+export BERT_RESOURCE=/path/to/bert/tokenizer
+export SAVE_PATH=/path/to/data/ref.txt
+
+python chinese_ref.py \
+    --file_name=$TRAIN_FILE \
+    --ltp=$LTP_RESOURCE \
+    --bert=$BERT_RESOURCE \
+    --save_path=$SAVE_PATH
+```
+
+The Chinese ref file is currently only supported by the `LineByLineWithRefDataset` class, so we also need to add the `--line_by_line` flag:
+
+```bash
+export TRAIN_FILE=/path/to/dataset/wiki.train.raw
+export TEST_FILE=/path/to/dataset/wiki.test.raw
+export REF_FILE=/path/to/ref.txt
+
+python run_language_modeling.py \
+    --output_dir=output \
+    --model_type=roberta \
+    --model_name_or_path=roberta-base \
+    --do_train \
+    --train_data_file=$TRAIN_FILE \
+    --chinese_ref_file=$REF_FILE \
+    --do_eval \
+    --eval_data_file=$TEST_FILE \
+    --mlm \
+    --line_by_line \
+    --wwm
+```
 
 ### XLNet and permutation language modeling
diff --git a/examples/language-modeling/chinese_ref.py b/examples/language-modeling/chinese_ref.py
new file mode 100644
index 000000000..02a1038f1
--- /dev/null
+++ b/examples/language-modeling/chinese_ref.py
@@ -0,0 +1,147 @@
+import argparse
+import json
+from typing import List
+
+from ltp import LTP
+from transformers.tokenization_bert import BertTokenizer
+
+
+def _is_chinese_char(cp):
+    """Checks whether CP is the codepoint of a CJK character."""
+    # This defines a "chinese character" as anything in the CJK Unicode block:
+    #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
+    #
+    # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
+    # despite its name. The modern Korean Hangul alphabet is a different block,
+    # as is Japanese Hiragana and Katakana. Those alphabets are used to write
+    # space-separated words, so they are not treated specially and handled
+    # like all of the other languages.
+ if ( + (cp >= 0x4E00 and cp <= 0x9FFF) + or (cp >= 0x3400 and cp <= 0x4DBF) # + or (cp >= 0x20000 and cp <= 0x2A6DF) # + or (cp >= 0x2A700 and cp <= 0x2B73F) # + or (cp >= 0x2B740 and cp <= 0x2B81F) # + or (cp >= 0x2B820 and cp <= 0x2CEAF) # + or (cp >= 0xF900 and cp <= 0xFAFF) + or (cp >= 0x2F800 and cp <= 0x2FA1F) # + ): # + return True + + return False + + +def is_chinese(word: str): + # word like '180' or '身高' or '神' + for char in word: + char = ord(char) + if not _is_chinese_char(char): + return 0 + return 1 + + +def get_chinese_word(tokens: List[str]): + word_set = set() + + for token in tokens: + chinese_word = len(token) > 1 and is_chinese(token) + if chinese_word: + word_set.add(token) + word_list = list(word_set) + return word_list + + +def add_sub_symbol(bert_tokens: List[str], chinese_word_set: set()): + if not chinese_word_set: + return bert_tokens + max_word_len = max([len(w) for w in chinese_word_set]) + + bert_word = bert_tokens + start, end = 0, len(bert_word) + while start < end: + single_word = True + if is_chinese(bert_word[start]): + l = min(end - start, max_word_len) + for i in range(l, 1, -1): + whole_word = "".join(bert_word[start : start + i]) + if whole_word in chinese_word_set: + for j in range(start + 1, start + i): + bert_word[j] = "##" + bert_word[j] + start = start + i + single_word = False + break + if single_word: + start += 1 + return bert_word + + +def prepare_ref(lines: List[str], ltp_tokenizer: LTP, bert_tokenizer: BertTokenizer): + ltp_res = [] + + for i in range(0, len(lines), 100): + res = ltp_tokenizer.seg(lines[i : i + 100])[0] + res = [get_chinese_word(r) for r in res] + ltp_res.extend(res) + assert len(ltp_res) == len(lines) + + bert_res = [] + for i in range(0, len(lines), 100): + res = bert_tokenizer(lines[i : i + 100], add_special_tokens=True, truncation=True, max_length=512) + bert_res.extend(res["input_ids"]) + assert len(bert_res) == len(lines) + + ref_ids = [] + for input_ids, chinese_word in zip(bert_res, ltp_res): + + input_tokens = [] + for id in input_ids: + token = bert_tokenizer._convert_id_to_token(id) + input_tokens.append(token) + input_tokens = add_sub_symbol(input_tokens, chinese_word) + ref_id = [] + # We only save pos of chinese subwords start with ##, which mean is part of a whole word. 
+        for i, token in enumerate(input_tokens):
+            if token[:2] == "##":
+                clean_token = token[2:]
+                # save the position of this Chinese sub-word token
+                if len(clean_token) == 1 and _is_chinese_char(ord(clean_token)):
+                    ref_id.append(i)
+        ref_ids.append(ref_id)
+
+    assert len(ref_ids) == len(bert_res)
+
+    return ref_ids
+
+
+def main(args):
+    # For Chinese (Ro)BERT, the best result comes from RoBERTa-wwm-ext (https://github.com/ymcui/Chinese-BERT-wwm)
+    # If we want to fine-tune these models, we have to use the same word segmenter: LTP (https://github.com/HIT-SCIR/ltp)
+    with open(args.file_name, "r", encoding="utf-8") as f:
+        data = f.readlines()
+
+    ltp_tokenizer = LTP(args.ltp)  # faster on a GPU device
+    bert_tokenizer = BertTokenizer.from_pretrained(args.bert)
+
+    ref_ids = prepare_ref(data, ltp_tokenizer, bert_tokenizer)
+
+    with open(args.save_path, "w", encoding="utf-8") as f:
+        data = [json.dumps(ref) + "\n" for ref in ref_ids]
+        f.writelines(data)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="prepare_chinese_ref")
+    parser.add_argument(
+        "--file_name",
+        type=str,
+        default="./resources/chinese-demo.txt",
+        help="file to process, same as the training data used for LM fine-tuning",
+    )
+    parser.add_argument(
+        "--ltp", type=str, default="./resources/ltp", help="resources for the LTP tokenizer, usually a path"
+    )
+    parser.add_argument("--bert", type=str, default="./resources/robert", help="resources for the BERT tokenizer")
+    parser.add_argument("--save_path", type=str, default="./resources/ref.txt", help="path to save the ref file")
+
+    args = parser.parse_args()
+    main(args)
diff --git a/examples/language-modeling/run_language_modeling.py b/examples/language-modeling/run_language_modeling.py
index f3ce40e32..6a3f3f63f 100644
--- a/examples/language-modeling/run_language_modeling.py
+++ b/examples/language-modeling/run_language_modeling.py
@@ -37,8 +37,10 @@ from transformers import (
     AutoTokenizer,
     DataCollatorForLanguageModeling,
     DataCollatorForPermutationLanguageModeling,
+    DataCollatorForWholeWordMask,
     HfArgumentParser,
     LineByLineTextDataset,
+    LineByLineWithRefDataset,
     PreTrainedTokenizer,
     TextDataset,
     Trainer,
@@ -101,6 +103,10 @@ class DataTrainingArguments:
         default=None,
         metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
     )
+    chinese_ref_file: Optional[str] = field(
+        default=None,
+        metadata={"help": "An optional input ref data file for whole word masking in Chinese."},
+    )
     line_by_line: bool = field(
         default=False,
         metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
     )
@@ -109,6 +115,7 @@ class DataTrainingArguments:
     mlm: bool = field(
         default=False, metadata={"help": "Train with masked-language modeling loss instead of language modeling."}
     )
+    whole_word_mask: bool = field(default=False, metadata={"help": "Whether or not to use whole word masking."})
     mlm_probability: float = field(
         default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
     )
@@ -143,6 +150,16 @@ def get_dataset(
 ):
     def _dataset(file_path):
         if args.line_by_line:
+            if args.chinese_ref_file is not None:
+                if not args.whole_word_mask or not args.mlm:
+                    raise ValueError("You need to set both whole_word_mask and mlm to True for Chinese Whole Word Mask")
+                return LineByLineWithRefDataset(
+                    tokenizer=tokenizer,
+                    file_path=file_path,
+                    block_size=args.block_size,
+                    ref_path=args.chinese_ref_file,
+                )
+
             return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)
         else:
             return
TextDataset( @@ -174,7 +191,6 @@ def main(): "Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file " "or remove the --do_eval argument." ) - if ( os.path.exists(training_args.output_dir) and os.listdir(training_args.output_dir) @@ -270,9 +286,14 @@ def main(): max_span_length=data_args.max_span_length, ) else: - data_collator = DataCollatorForLanguageModeling( - tokenizer=tokenizer, mlm=data_args.mlm, mlm_probability=data_args.mlm_probability - ) + if data_args.mlm and data_args.whole_word_mask: + data_collator = DataCollatorForWholeWordMask( + tokenizer=tokenizer, mlm_probability=data_args.mlm_probability + ) + else: + data_collator = DataCollatorForLanguageModeling( + tokenizer=tokenizer, mlm=data_args.mlm, mlm_probability=data_args.mlm_probability + ) # Initialize our Trainer trainer = Trainer( diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 26a1bfd0b..3ed23ff30 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -284,6 +284,7 @@ if is_torch_available(): DataCollatorForNextSentencePrediction, DataCollatorForPermutationLanguageModeling, DataCollatorForSOP, + DataCollatorForWholeWordMask, DataCollatorWithPadding, default_data_collator, ) @@ -291,6 +292,7 @@ if is_torch_available(): GlueDataset, GlueDataTrainingArguments, LineByLineTextDataset, + LineByLineWithRefDataset, LineByLineWithSOPTextDataset, SquadDataset, SquadDataTrainingArguments, diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py index f2fec6bb2..d05061a7c 100644 --- a/src/transformers/data/data_collator.py +++ b/src/transformers/data/data_collator.py @@ -1,3 +1,4 @@ +import random from dataclasses import dataclass from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union @@ -195,6 +196,124 @@ class DataCollatorForLanguageModeling: return inputs, labels +@dataclass +class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling): + """ + Data collator used for language modeling. 
+ - collates batches of tensors, honoring their tokenizer's pad_token + - preprocesses batches for masked language modeling + """ + + def __call__( + self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]] + ) -> Dict[str, torch.Tensor]: + if isinstance(examples[0], (dict, BatchEncoding)): + input_ids = [e["input_ids"] for e in examples] + else: + input_ids = examples + examples = [{"input_ids": e} for e in examples] + + batch_input = self._tensorize_batch(input_ids) + + mask_labels = [] + for e in examples: + ref_tokens = [] + for id in e["input_ids"].tolist(): + token = self.tokenizer._convert_id_to_token(id) + ref_tokens.append(token) + + # For Chinese tokens, we need extra inf to mark sub-word, e.g [喜,欢]-> [喜,##欢] + if "chinese_ref" in e: + ref_pos = e["chinese_ref"].tolist() + len_seq = e["input_ids"].size(0) + for i in range(len_seq): + if i in ref_pos: + ref_tokens[i] = "##" + ref_tokens[i] + mask_labels.append(self._whole_word_mask(ref_tokens)) + batch_mask = self._tensorize_batch(mask_labels) + inputs, labels = self.mask_tokens(batch_input, batch_mask) + return {"input_ids": inputs, "labels": labels} + + def _whole_word_mask(self, input_tokens: List[str], max_predictions=512): + """ + Get 0/1 labels for masked tokens with whole word mask proxy + """ + + cand_indexes = [] + for (i, token) in enumerate(input_tokens): + if token == "[CLS]" or token == "[SEP]": + continue + + if len(cand_indexes) >= 1 and token.startswith("##"): + cand_indexes[-1].append(i) + else: + cand_indexes.append([i]) + + random.shuffle(cand_indexes) + num_to_predict = min(max_predictions, max(1, int(round(len(input_tokens) * self.mlm_probability)))) + masked_lms = [] + covered_indexes = set() + for index_set in cand_indexes: + if len(masked_lms) >= num_to_predict: + break + # If adding a whole-word mask would exceed the maximum number of + # predictions, then just skip this candidate. + if len(masked_lms) + len(index_set) > num_to_predict: + continue + is_any_index_covered = False + for index in index_set: + if index in covered_indexes: + is_any_index_covered = True + break + if is_any_index_covered: + continue + for index in index_set: + covered_indexes.add(index) + masked_lms.append(index) + + assert len(covered_indexes) == len(masked_lms) + mask_labels = [1 if i in covered_indexes else 0 for i in range(len(input_tokens))] + return mask_labels + + def mask_tokens(self, inputs: torch.Tensor, mask_labels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. + Set 'mask_labels' means we use whole word mask (wwm), we directly mask idxs according to it's ref. + """ + + if self.tokenizer.mask_token is None: + raise ValueError( + "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer." 
+ ) + labels = inputs.clone() + # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa) + + probability_matrix = mask_labels + + special_tokens_mask = [ + self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist() + ] + probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0) + if self.tokenizer._pad_token is not None: + padding_mask = labels.eq(self.tokenizer.pad_token_id) + probability_matrix.masked_fill_(padding_mask, value=0.0) + + masked_indices = probability_matrix.bool() + labels[~masked_indices] = -100 # We only compute loss on masked tokens + + # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) + indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices + inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token) + + # 10% of the time, we replace masked input tokens with random word + indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced + random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long) + inputs[indices_random] = random_words[indices_random] + + # The rest of the time (10% of the time) we keep the masked input tokens unchanged + return inputs, labels + + @dataclass class DataCollatorForSOP(DataCollatorForLanguageModeling): """ diff --git a/src/transformers/data/datasets/__init__.py b/src/transformers/data/datasets/__init__.py index c482be987..0cb518a71 100644 --- a/src/transformers/data/datasets/__init__.py +++ b/src/transformers/data/datasets/__init__.py @@ -5,6 +5,7 @@ from .glue import GlueDataset, GlueDataTrainingArguments from .language_modeling import ( LineByLineTextDataset, + LineByLineWithRefDataset, LineByLineWithSOPTextDataset, TextDataset, TextDatasetForNextSentencePrediction, diff --git a/src/transformers/data/datasets/language_modeling.py b/src/transformers/data/datasets/language_modeling.py index 17f4ae0a5..9cd337f1e 100644 --- a/src/transformers/data/datasets/language_modeling.py +++ b/src/transformers/data/datasets/language_modeling.py @@ -1,3 +1,4 @@ +import json import os import pickle import random @@ -106,12 +107,48 @@ class LineByLineTextDataset(Dataset): batch_encoding = tokenizer(lines, add_special_tokens=True, truncation=True, max_length=block_size) self.examples = batch_encoding["input_ids"] + self.examples = [{"input_ids": torch.tensor(e, dtype=torch.long)} for e in self.examples] def __len__(self): return len(self.examples) - def __getitem__(self, i) -> torch.Tensor: - return torch.tensor(self.examples[i], dtype=torch.long) + def __getitem__(self, i) -> Dict[str, torch.tensor]: + return self.examples[i] + + +class LineByLineWithRefDataset(Dataset): + """ + This will be superseded by a framework-agnostic approach + soon. 
+    """
+
+    def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int, ref_path: str):
+        assert os.path.isfile(file_path), f"Input file path {file_path} not found"
+        assert os.path.isfile(ref_path), f"Ref file path {ref_path} not found"
+        # Here, we do not cache the features, operating under the assumption
+        # that we will soon use fast multithreaded tokenizers from the
+        # `tokenizers` repo everywhere =)
+        logger.info("Creating features from dataset file at %s", file_path)
+        logger.info("Use ref segment results at %s", ref_path)
+        with open(file_path, encoding="utf-8") as f:
+            data = [line for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]
+        batch_encoding = tokenizer(data, add_special_tokens=True, truncation=True, max_length=block_size)
+        self.examples = batch_encoding["input_ids"]
+        self.examples = [{"input_ids": torch.tensor(e, dtype=torch.long)} for e in self.examples]
+
+        # Get ref info (word segmentation results) from file
+        with open(ref_path, encoding="utf-8") as f:
+            ref = [json.loads(line) for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]
+        assert len(data) == len(ref)
+        n = len(self.examples)
+        for i in range(n):
+            self.examples[i]["chinese_ref"] = torch.tensor(ref[i], dtype=torch.long)
+
+    def __len__(self):
+        return len(self.examples)
+
+    def __getitem__(self, i) -> Dict[str, torch.tensor]:
+        return self.examples[i]
 
 
 class LineByLineWithSOPTextDataset(Dataset):
diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py
index 3e152be24..9e4a8ad6f 100644
--- a/src/transformers/utils/dummy_pt_objects.py
+++ b/src/transformers/utils/dummy_pt_objects.py
@@ -45,6 +45,11 @@ class DataCollatorForSOP:
         requires_pytorch(self)
 
 
+class DataCollatorForWholeWordMask:
+    def __init__(self, *args, **kwargs):
+        requires_pytorch(self)
+
+
 class DataCollatorWithPadding:
     def __init__(self, *args, **kwargs):
         requires_pytorch(self)
@@ -69,6 +74,11 @@ class LineByLineTextDataset:
         requires_pytorch(self)
 
 
+class LineByLineWithRefDataset:
+    def __init__(self, *args, **kwargs):
+        requires_pytorch(self)
+
+
 class LineByLineWithSOPTextDataset:
     def __init__(self, *args, **kwargs):
         requires_pytorch(self)
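
Below is a minimal usage sketch (not part of the patch) showing how the pieces added here fit together, assuming a `transformers` version that includes this change, a Chinese training file, and a `ref.txt` produced by `chinese_ref.py`; the file paths, block size, and batch size are placeholders:

```python
# Minimal sketch: whole-word masking over a Chinese line-by-line dataset.
from transformers import BertTokenizer, DataCollatorForWholeWordMask, LineByLineWithRefDataset

tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")

# "wiki.train.raw" and "ref.txt" are placeholder paths; ref.txt comes from chinese_ref.py.
dataset = LineByLineWithRefDataset(
    tokenizer=tokenizer,
    file_path="wiki.train.raw",
    block_size=512,
    ref_path="ref.txt",
)

collator = DataCollatorForWholeWordMask(tokenizer=tokenizer, mlm_probability=0.15)

# Each example is a dict with "input_ids" and "chinese_ref"; the collator pads the
# batch, masks whole words using the ref positions, and returns input_ids/labels
# where labels are -100 everywhere except at the masked positions.
batch = collator([dataset[i] for i in range(min(8, len(dataset)))])
print(batch["input_ids"].shape, batch["labels"].shape)
```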