Run mlm pad to multiple for fp16 (#11128)
* Add mlm collator pad to multiple option (#10627)
* Use padding to 8x in run mlm (#10627)
parent dfed4ec263
commit 6c40e49712
3 changed files with 67 additions and 9 deletions
@@ -422,7 +422,12 @@ def main():
 
     # Data collator
     # This one will take care of randomly masking the tokens.
-    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
+    pad_to_multiple_of_8 = data_args.line_by_line and training_args.fp16 and not data_args.pad_to_max_length
+    data_collator = DataCollatorForLanguageModeling(
+        tokenizer=tokenizer,
+        mlm_probability=data_args.mlm_probability,
+        pad_to_multiple_of=8 if pad_to_multiple_of_8 else None,
+    )
 
     # Initialize our Trainer
     trainer = Trainer(
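For context: padding sequence lengths to a multiple of 8 is the usual way to keep fp16 matrix multiplications on the Tensor Core fast path, and the script only needs it when batches are padded dynamically (line-by-line tokenization without pad_to_max_length). A minimal standalone sketch of the same wiring, assuming a generic pretrained tokenizer and treating the three flags as plain booleans rather than the script's parsed arguments:

from transformers import AutoTokenizer, DataCollatorForLanguageModeling

# Stand-ins for the script's data_args / training_args (illustrative values only).
line_by_line = True        # variable-length examples, padded per batch
fp16 = True                # mixed-precision training enabled
pad_to_max_length = False  # no fixed-length padding at preprocessing time

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Same condition as the diff above: only pad to 8 when padding is dynamic and fp16 is on.
pad_to_multiple_of_8 = line_by_line and fp16 and not pad_to_max_length
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm_probability=0.15,
    pad_to_multiple_of=8 if pad_to_multiple_of_8 else None,
)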
@@ -192,7 +192,7 @@ class DataCollatorForTokenClassification:
         return batch
 
 
-def _collate_batch(examples, tokenizer):
+def _collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None):
     """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
     # Tensorize if necessary.
     if isinstance(examples[0], (list, tuple)):
@@ -201,7 +201,7 @@ def _collate_batch(examples, tokenizer):
     # Check if padding is necessary.
     length_of_first = examples[0].size(0)
     are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
-    if are_tensors_same_length:
+    if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
         return torch.stack(examples, dim=0)
 
     # If yes, check if we have a `pad_token`.
@@ -213,6 +213,8 @@ def _collate_batch(examples, tokenizer):
 
     # Creating the full tensor and filling it with our data.
     max_length = max(x.size(0) for x in examples)
+    if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
+        max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
     result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)
     for i, example in enumerate(examples):
         if tokenizer.padding_side == "right":
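The added round-up is plain integer arithmetic: if the longest example in the batch is not already a multiple of pad_to_multiple_of, the target length is bumped to the next multiple (a batch whose longest example has 10 tokens gets padded to 16 when the multiple is 8). A standalone sketch of that step; the helper name is ours, not part of the library:

def round_up_to_multiple(max_length: int, pad_to_multiple_of: int) -> int:
    # Mirror of the arithmetic added to _collate_batch above.
    if max_length % pad_to_multiple_of != 0:
        max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
    return max_length

assert round_up_to_multiple(10, 8) == 16  # longest example of 10 tokens -> padded to 16
assert round_up_to_multiple(16, 8) == 16  # already a multiple -> unchanged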
@@ -311,6 +313,8 @@ class DataCollatorForLanguageModeling:
             non-masked tokens and the value to predict for the masked token.
         mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
             The probability with which to (randomly) mask tokens in the input, when :obj:`mlm` is set to :obj:`True`.
+        pad_to_multiple_of (:obj:`int`, `optional`):
+            If set will pad the sequence to a multiple of the provided value.
 
     .. note::
 
@@ -323,6 +327,7 @@ class DataCollatorForLanguageModeling:
     tokenizer: PreTrainedTokenizerBase
     mlm: bool = True
     mlm_probability: float = 0.15
+    pad_to_multiple_of: Optional[int] = None
 
     def __post_init__(self):
         if self.mlm and self.tokenizer.mask_token is None:
@@ -336,9 +341,9 @@ class DataCollatorForLanguageModeling:
     ) -> Dict[str, torch.Tensor]:
         # Handle dict or lists with proper padding and conversion to tensor.
         if isinstance(examples[0], (dict, BatchEncoding)):
-            batch = self.tokenizer.pad(examples, return_tensors="pt")
+            batch = self.tokenizer.pad(examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of)
         else:
-            batch = {"input_ids": _collate_batch(examples, self.tokenizer)}
+            batch = {"input_ids": _collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)}
 
         # If special token mask has been preprocessed, pop it from the dict.
         special_tokens_mask = batch.pop("special_tokens_mask", None)
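With both branches updated, the collator forwards pad_to_multiple_of whether it receives tokenizer output (dicts or BatchEncoding, padded via tokenizer.pad) or bare lists/tensors of token ids (padded via _collate_batch). A rough usage sketch, assuming any pretrained tokenizer that defines pad and mask tokens:

from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
collator = DataCollatorForLanguageModeling(tokenizer, mlm=True, pad_to_multiple_of=8)

# Dict features take the tokenizer.pad(..., pad_to_multiple_of=8) branch.
dict_batch = collator([{"input_ids": list(range(5))}, {"input_ids": list(range(10))}])

# Plain lists of token ids take the _collate_batch(..., pad_to_multiple_of=8) branch.
list_batch = collator([list(range(10)), list(range(10))])

# Both batches come out padded from a longest length of 10 up to 16.
print(dict_batch["input_ids"].shape, list_batch["input_ids"].shape)  # (2, 16) in both cases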
@@ -146,11 +146,8 @@ class DataCollatorIntegrationTest(unittest.TestCase):
         self.assertEqual(batch["labels"].shape, torch.Size([2, 6]))
         self.assertEqual(batch["labels"][0].tolist(), [0, 1, 2] + [-1] * 3)
 
-    def test_data_collator_for_language_modeling(self):
+    def _test_no_pad_and_pad(self, no_pad_features, pad_features):
         tokenizer = BertTokenizer(self.vocab_file)
-        no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
-        pad_features = [{"input_ids": list(range(5))}, {"input_ids": list(range(10))}]
-
         data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
         batch = data_collator(no_pad_features)
         self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
@@ -160,6 +157,15 @@ class DataCollatorIntegrationTest(unittest.TestCase):
         self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
         self.assertEqual(batch["labels"].shape, torch.Size((2, 10)))
 
+        data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False, pad_to_multiple_of=8)
+        batch = data_collator(no_pad_features)
+        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 16)))
+        self.assertEqual(batch["labels"].shape, torch.Size((2, 16)))
+
+        batch = data_collator(pad_features)
+        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 16)))
+        self.assertEqual(batch["labels"].shape, torch.Size((2, 16)))
+
         tokenizer._pad_token = None
         data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
         with self.assertRaises(ValueError):
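The expected shapes in these assertions follow directly from the rounding rule: both feature sets have a longest example of 10 tokens, and 10 rounded up to the next multiple of 8 is 16, so every padded batch comes out (2, 16). A tiny check of that shape arithmetic (purely illustrative, no tokenizer involved):

import math

def padded_length(lengths, pad_to_multiple_of=8):
    # Longest example in the batch, rounded up to the next multiple.
    return math.ceil(max(lengths) / pad_to_multiple_of) * pad_to_multiple_of

assert padded_length([10, 10]) == 16  # no_pad_features: both examples are 10 tokens long
assert padded_length([5, 10]) == 16   # pad_features: longest example is still 10 tokens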
@@ -185,6 +191,32 @@ class DataCollatorIntegrationTest(unittest.TestCase):
         self.assertTrue(torch.any(masked_tokens))
         self.assertTrue(all(x == -100 for x in batch["labels"][~masked_tokens].tolist()))
 
+        data_collator = DataCollatorForLanguageModeling(tokenizer, pad_to_multiple_of=8)
+        batch = data_collator(no_pad_features)
+        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 16)))
+        self.assertEqual(batch["labels"].shape, torch.Size((2, 16)))
+
+        masked_tokens = batch["input_ids"] == tokenizer.mask_token_id
+        self.assertTrue(torch.any(masked_tokens))
+        self.assertTrue(all(x == -100 for x in batch["labels"][~masked_tokens].tolist()))
+
+        batch = data_collator(pad_features)
+        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 16)))
+        self.assertEqual(batch["labels"].shape, torch.Size((2, 16)))
+
+        masked_tokens = batch["input_ids"] == tokenizer.mask_token_id
+        self.assertTrue(torch.any(masked_tokens))
+        self.assertTrue(all(x == -100 for x in batch["labels"][~masked_tokens].tolist()))
+
+    def test_data_collator_for_language_modeling(self):
+        no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
+        pad_features = [{"input_ids": list(range(5))}, {"input_ids": list(range(10))}]
+        self._test_no_pad_and_pad(no_pad_features, pad_features)
+
+        no_pad_features = [list(range(10)), list(range(10))]
+        pad_features = [list(range(5)), list(range(10))]
+        self._test_no_pad_and_pad(no_pad_features, pad_features)
+
     def test_plm(self):
         tokenizer = BertTokenizer(self.vocab_file)
         no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
@@ -225,6 +257,14 @@ class DataCollatorIntegrationTest(unittest.TestCase):
         self.assertEqual(batch["labels"].shape, torch.Size((2, 5)))
         self.assertEqual(batch["next_sentence_label"].shape, torch.Size((2,)))
 
+        data_collator = DataCollatorForLanguageModeling(tokenizer, pad_to_multiple_of=8)
+        batch = data_collator(features)
+
+        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 8)))
+        self.assertEqual(batch["token_type_ids"].shape, torch.Size((2, 8)))
+        self.assertEqual(batch["labels"].shape, torch.Size((2, 8)))
+        self.assertEqual(batch["next_sentence_label"].shape, torch.Size((2,)))
+
     def test_sop(self):
         tokenizer = BertTokenizer(self.vocab_file)
         features = [
@@ -242,3 +282,11 @@ class DataCollatorIntegrationTest(unittest.TestCase):
         self.assertEqual(batch["token_type_ids"].shape, torch.Size((2, 5)))
         self.assertEqual(batch["labels"].shape, torch.Size((2, 5)))
         self.assertEqual(batch["sentence_order_label"].shape, torch.Size((2,)))
+
+        data_collator = DataCollatorForLanguageModeling(tokenizer, pad_to_multiple_of=8)
+        batch = data_collator(features)
+
+        self.assertEqual(batch["input_ids"].shape, torch.Size((2, 8)))
+        self.assertEqual(batch["token_type_ids"].shape, torch.Size((2, 8)))
+        self.assertEqual(batch["labels"].shape, torch.Size((2, 8)))
+        self.assertEqual(batch["sentence_order_label"].shape, torch.Size((2,)))