diff --git a/examples/language-modeling/run_language_modeling.py b/examples/language-modeling/run_language_modeling.py
index d9465c376..740cb6364 100644
--- a/examples/language-modeling/run_language_modeling.py
+++ b/examples/language-modeling/run_language_modeling.py
@@ -14,9 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
-Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa).
-GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned
-using a masked language modeling (MLM) loss.
+Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, CTRL, BERT, RoBERTa, XLNet).
+GPT, GPT-2 and CTRL are fine-tuned using a causal language modeling (CLM) loss. BERT and RoBERTa are fine-tuned
+using a masked language modeling (MLM) loss. XLNet is fine-tuned using a permutation language modeling (PLM) loss.
 """


@@ -33,6 +33,7 @@ from transformers import (
     AutoModelWithLMHead,
     AutoTokenizer,
     DataCollatorForLanguageModeling,
+    DataCollatorForPermutationLanguageModeling,
     HfArgumentParser,
     LineByLineTextDataset,
     PreTrainedTokenizer,
@@ -101,6 +102,15 @@ class DataTrainingArguments:
     mlm_probability: float = field(
         default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
     )
+    plm_probability: float = field(
+        default=1 / 6,
+        metadata={
+            "help": "Ratio of length of a span of masked tokens to surrounding context length for permutation language modeling."
+        },
+    )
+    max_span_length: int = field(
+        default=5, metadata={"help": "Maximum length of a span of masked tokens for permutation language modeling."}
+    )

     block_size: int = field(
         default=-1,
@@ -207,8 +217,8 @@ def main():
     if config.model_type in ["bert", "roberta", "distilbert", "camembert"] and not data_args.mlm:
         raise ValueError(
-            "BERT and RoBERTa-like models do not have LM heads but masked LM heads. They must be run using the --mlm "
-            "flag (masked language modeling)."
+            "BERT and RoBERTa-like models do not have LM heads but masked LM heads. They must be run using the "
+            "--mlm flag (masked language modeling)."
         )

     if data_args.block_size <= 0:
@@ -221,9 +231,14 @@
     train_dataset = get_dataset(data_args, tokenizer=tokenizer) if training_args.do_train else None
     eval_dataset = get_dataset(data_args, tokenizer=tokenizer, evaluate=True) if training_args.do_eval else None
-    data_collator = DataCollatorForLanguageModeling(
-        tokenizer=tokenizer, mlm=data_args.mlm, mlm_probability=data_args.mlm_probability
-    )
+    if config.model_type == "xlnet":
+        data_collator = DataCollatorForPermutationLanguageModeling(
+            tokenizer=tokenizer, plm_probability=data_args.plm_probability, max_span_length=data_args.max_span_length,
+        )
+    else:
+        data_collator = DataCollatorForLanguageModeling(
+            tokenizer=tokenizer, mlm=data_args.mlm, mlm_probability=data_args.mlm_probability
+        )

     # Initialize our Trainer
     trainer = Trainer(
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index d8dd42d67..7e0e6dbd6 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -400,7 +400,12 @@ if is_torch_available():

     # Trainer
     from .trainer import Trainer, torch_distributed_zero_first
-    from .data.data_collator import default_data_collator, DataCollator, DataCollatorForLanguageModeling
+    from .data.data_collator import (
+        default_data_collator,
+        DataCollator,
+        DataCollatorForLanguageModeling,
+        DataCollatorForPermutationLanguageModeling,
+    )
     from .data.datasets import GlueDataset, TextDataset, LineByLineTextDataset, GlueDataTrainingArguments

     # Benchmarks
diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py
index 29331cc83..b4d9f205b 100644
--- a/src/transformers/data/data_collator.py
+++ b/src/transformers/data/data_collator.py
@@ -21,8 +21,8 @@ def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Ten
     Very simple data collator that:
     - simply collates batches of dict-like objects
     - Performs special handling for potential keys named:
-        - `label`: handles a single value (int or float) per object
-        - `label_ids`: handles a list of values per object
+        - ``label``: handles a single value (int or float) per object
+        - ``label_ids``: handles a list of values per object
     - does not do any additional preprocessing

     i.e., Property names of the input object will be used as corresponding inputs to the model.
@@ -134,3 +134,126 @@ class DataCollatorForLanguageModeling:
         # The rest of the time (10% of the time) we keep the masked input tokens unchanged
         return inputs, labels
+
+
+@dataclass
+class DataCollatorForPermutationLanguageModeling:
+    """
+    Data collator used for permutation language modeling.
+
+    - collates batches of tensors, honoring the tokenizer's pad_token
+    - preprocesses batches for permutation language modeling with procedures specific to XLNet
+    """
+
+    tokenizer: PreTrainedTokenizer
+    plm_probability: float = 1 / 6
+    max_span_length: int = 5  # maximum length of a span of masked tokens
+
+    def __call__(self, examples: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
+        batch = self._tensorize_batch(examples)
+        inputs, perm_mask, target_mapping, labels = self.mask_tokens(batch)
+        return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}
+
+    def _tensorize_batch(self, examples: List[torch.Tensor]) -> torch.Tensor:
+        length_of_first = examples[0].size(0)
+        are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
+        if are_tensors_same_length:
+            return torch.stack(examples, dim=0)
+        else:
+            if self.tokenizer._pad_token is None:
+                raise ValueError(
+                    "You are attempting to pad samples but the tokenizer you are using"
+                    f" ({self.tokenizer.__class__.__name__}) does not have a pad token."
+                )
+            return pad_sequence(examples, batch_first=True, padding_value=self.tokenizer.pad_token_id)
+
+    def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        The masked tokens to be predicted for a particular sequence are determined by the following algorithm:
+        0. Start from the beginning of the sequence by setting ``cur_len = 0`` (number of tokens processed so far).
+        1. Sample a ``span_length`` from the interval ``[1, max_span_length]`` (length of span of tokens to be masked)
+        2. Reserve a context of length ``context_length = span_length / plm_probability`` to surround the span to be masked
+        3. Sample a starting point ``start_index`` from the interval ``[cur_len, cur_len + context_length - span_length]`` and mask tokens ``start_index:start_index + span_length``
+        4. Set ``cur_len = cur_len + context_length``. If ``cur_len < max_len`` (i.e. there are tokens remaining in the sequence to be processed), repeat from Step 1.
+        """
+
+        if self.tokenizer.mask_token is None:
+            raise ValueError(
+                "This tokenizer does not have a mask token, which is necessary for permutation language modeling. Please add a mask token if you want to use this tokenizer."
+            )
+
+        if inputs.size(1) % 2 != 0:
+            raise ValueError(
+                "This collator requires that sequence lengths be even to create a leakage-free perm_mask. Please see relevant comments in source code for details."
+            )
+
+        labels = inputs.clone()
+        # Creating the mask and target_mapping tensors
+        masked_indices = torch.full(labels.shape, 0, dtype=torch.bool)
+        target_mapping = torch.zeros((labels.size(0), labels.size(1), labels.size(1)), dtype=torch.float32)
+
+        for i in range(labels.size(0)):
+            # Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
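+            # Illustrative walk-through (an assumed example, not taken from the PR): with the defaults
+            # plm_probability = 1/6 and max_span_length = 5, sampling span_length = 4 gives
+            # context_length = int(4 / (1 / 6)) = 24, so a 4-token span is masked at a random offset
+            # inside the next 24-token window, and cur_len then advances by 24.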
+            cur_len = 0
+            max_len = labels.size(1)
+
+            while cur_len < max_len:
+                # Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
+                span_length = torch.randint(1, self.max_span_length + 1, (1,)).item()
+                # Reserve a context of length `context_length = span_length / plm_probability` to surround the span to be masked
+                context_length = int(span_length / self.plm_probability)
+                # Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length - span_length]` and mask tokens `start_index:start_index + span_length`
+                start_index = cur_len + torch.randint(context_length - span_length + 1, (1,)).item()
+                masked_indices[i, start_index : start_index + span_length] = 1
+                # Set `cur_len = cur_len + context_length`
+                cur_len += context_length
+
+            # Since we're replacing non-masked tokens with -100 in the labels tensor instead of skipping them altogether,
+            # the i-th prediction corresponds to the i-th token.
+            target_mapping[i] = torch.eye(labels.size(1))
+
+        special_tokens_mask = torch.tensor(
+            [self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()],
+            dtype=torch.bool,
+        )
+        masked_indices.masked_fill_(special_tokens_mask, value=0.0)
+        if self.tokenizer._pad_token is not None:
+            padding_mask = labels.eq(self.tokenizer.pad_token_id)
+            masked_indices.masked_fill_(padding_mask, value=0.0)
+        else:
+            # Without a pad token nothing in the batch is padded, so use an all-False padding mask
+            padding_mask = torch.full(labels.shape, 0, dtype=torch.bool)
+
+        # Mask indicating non-functional tokens, where functional tokens are [SEP], [CLS], padding, etc.
+        non_func_mask = ~(padding_mask | special_tokens_mask)
+
+        inputs[masked_indices] = self.tokenizer.mask_token_id
+        labels[~masked_indices] = -100  # We only compute loss on masked tokens
+
+        perm_mask = torch.zeros((labels.size(0), labels.size(1), labels.size(1)), dtype=torch.float32)
+
+        for i in range(labels.size(0)):
+            # Generate permutation indices i.e. sample a random factorisation order for the sequence. This will
+            # determine which tokens a given token can attend to (encoded in `perm_mask`).
+            # Note: Length of token sequence being permuted has to be less than or equal to reused sequence length
+            # (see documentation for `mems`), otherwise information may leak through due to reuse. In this implementation,
+            # we assume that reused length is half of sequence length and permutation length is equal to reused length.
+            # This requires that the sequence length be even.
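+            # Worked example (an assumed illustration, not taken from the PR): for seq_len = 6,
+            # torch.arange gives [0, 1, 2, 3, 4, 5]; reshaping to (2, 3) and transposing yields rows
+            # [0, 3], [1, 4], [2, 5]; shuffling those rows to, say, [1, 4], [2, 5], [0, 3] and
+            # flattening the transpose back gives [1, 2, 0, 4, 5, 3] -- the same permutation applied
+            # to each half, so no factorisation index crosses the halfway boundary.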
+
+            # Create a linear factorisation order
+            perm_index = torch.arange(labels.size(1))
+            # Split this into two halves, assuming that half the sequence is reused each time
+            perm_index = perm_index.reshape((-1, labels.size(1) // 2)).transpose(0, 1)
+            # Permute the two halves such that they do not cross over
+            perm_index = perm_index[torch.randperm(labels.size(1) // 2)]
+            # Flatten this out into the desired permuted factorisation order
+            perm_index = torch.flatten(perm_index.transpose(0, 1))
+            # Set the permutation indices of non-masked (non-functional) tokens to the
+            # smallest index (-1) so that:
+            # (1) They can be seen by all other positions
+            # (2) They cannot see masked positions, so there won't be information leak
+            perm_index.masked_fill_(~masked_indices[i] & non_func_mask[i], -1)
+            # The logic for whether the i-th token can attend to the j-th token based on the factorisation order:
+            # 0 (can attend): If perm_index[i] > perm_index[j] or j is neither masked nor a functional token
+            # 1 (cannot attend): If perm_index[i] <= perm_index[j] and j is either masked or a functional token
+            perm_mask[i] = (
+                perm_index.reshape((labels.size(1), 1)) <= perm_index.reshape((1, labels.size(1)))
+            ) & masked_indices[i]
+
+        return inputs, perm_mask, target_mapping, labels
diff --git a/tests/test_trainer.py b/tests/test_trainer.py
index d68eee524..dd5a487be 100644
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -12,6 +12,7 @@ if is_torch_available():
         AutoModelForSequenceClassification,
         default_data_collator,
         DataCollatorForLanguageModeling,
+        DataCollatorForPermutationLanguageModeling,
         GlueDataset,
         GlueDataTrainingArguments,
         TextDataset,
@@ -123,6 +124,34 @@ class DataCollatorIntegrationTest(unittest.TestCase):
         self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
         self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))

+    def test_plm(self):
+        tokenizer = AutoTokenizer.from_pretrained("xlnet-base-cased")
+        data_collator = DataCollatorForPermutationLanguageModeling(tokenizer)
+
+        dataset = LineByLineTextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512)
+        examples = [dataset[i] for i in range(len(dataset))]
+        batch = data_collator(examples)
+        self.assertIsInstance(batch, dict)
+        self.assertEqual(batch["input_ids"].shape, torch.Size((31, 112)))
+        self.assertEqual(batch["perm_mask"].shape, torch.Size((31, 112, 112)))
+        self.assertEqual(batch["target_mapping"].shape, torch.Size((31, 112, 112)))
+        self.assertEqual(batch["labels"].shape, torch.Size((31, 112)))
+
+        dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True)
+        examples = [dataset[i] for i in range(len(dataset))]
+        batch = data_collator(examples)
+        self.assertIsInstance(batch, dict)
+        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
+        self.assertEqual(batch["perm_mask"].shape, torch.Size((2, 512, 512)))
+        self.assertEqual(batch["target_mapping"].shape, torch.Size((2, 512, 512)))
+        self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))
+
+        example = [torch.randint(5, [5])]
+        with self.assertRaises(ValueError):
+            # Expect error due to odd sequence length
+            data_collator(example)
+

 @require_torch
 class TrainerIntegrationTest(unittest.TestCase):
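For reviewers who want to exercise the new collator outside of Trainer, a minimal sketch is below. It assumes only the code added in this diff; the random token ids are placeholders for real encoded text (chosen above xlnet-base-cased's special-token id range), and the padded length is even, as the collator requires.

import torch

from transformers import AutoTokenizer, DataCollatorForPermutationLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("xlnet-base-cased")
collator = DataCollatorForPermutationLanguageModeling(
    tokenizer=tokenizer, plm_probability=1 / 6, max_span_length=5
)

# Two variable-length sequences; the shorter one is padded to the longer (even) length.
examples = [torch.randint(10, 1000, (16,)), torch.randint(10, 1000, (12,))]
batch = collator(examples)

print(batch["input_ids"].shape)         # torch.Size([2, 16])
print(batch["perm_mask"].shape)         # torch.Size([2, 16, 16])
print(batch["target_mapping"].shape)    # torch.Size([2, 16, 16])
print((batch["labels"] != -100).sum())  # number of masked positions the loss is computed on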