From 3c7fbf35a6c9237e8bbceb5b4f315980ed10d8a0 Mon Sep 17 00:00:00 2001
From: Sam Shleifer
Date: Tue, 28 Jul 2020 08:18:11 -0400
Subject: [PATCH] MBART: support summarization tasks where max_src_len > max_tgt_len (#6003)

* MBART: support summarization tasks

* fix test

* Style

* add tokenizer test
---
 examples/seq2seq/README.md                |  2 ++
 examples/seq2seq/finetune.py              | 13 +++++++------
 examples/seq2seq/finetune_t5.sh           |  1 +
 examples/seq2seq/test_seq2seq_examples.py | 15 ++++++++-------
 examples/seq2seq/utils.py                 |  4 +++-
 src/transformers/tokenization_bart.py     |  6 +++++-
 tests/test_tokenization_mbart.py          | 12 ++++++++++++
 7 files changed, 38 insertions(+), 15 deletions(-)

diff --git a/examples/seq2seq/README.md b/examples/seq2seq/README.md
index a579d728b..1e12242cb 100644
--- a/examples/seq2seq/README.md
+++ b/examples/seq2seq/README.md
@@ -180,6 +180,8 @@ python run_eval.py sshleifer/distilbart-cnn-12-6 $DATA_DIR/val.source dbart_val_
     --task summarization \
     --n_obs 100 \
     --device cuda \
+    --max_source_length 1024 \
+    --max_target_length 56 \
     --fp16 \
     --bs 32
 ```
diff --git a/examples/seq2seq/finetune.py b/examples/seq2seq/finetune.py
index 1866042cb..e2e9ecffa 100644
--- a/examples/seq2seq/finetune.py
+++ b/examples/seq2seq/finetune.py
@@ -105,7 +105,13 @@ class SummarizationModule(BaseTransformer):
         self.hparams.git_sha = get_git_info()["repo_sha"]
         self.num_workers = hparams.num_workers
         self.decoder_start_token_id = None
-        self.dataset_class = Seq2SeqDataset
+        if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
+            self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
+            self.model.config.decoder_start_token_id = self.decoder_start_token_id
+        if isinstance(self.tokenizer, MBartTokenizer):
+            self.dataset_class = MBartDataset
+        else:
+            self.dataset_class = Seq2SeqDataset

     def freeze_embeds(self):
         """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
@@ -331,11 +337,6 @@ class TranslationModule(SummarizationModule):
         super().__init__(hparams, **kwargs)
         self.dataset_kwargs["src_lang"] = hparams.src_lang
         self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
-        if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
-            self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
-            self.model.config.decoder_start_token_id = self.decoder_start_token_id
-        if isinstance(self.tokenizer, MBartTokenizer):
-            self.dataset_class = MBartDataset

     def calc_generative_metrics(self, preds, target) -> dict:
         return calculate_bleu_score(preds, target)
diff --git a/examples/seq2seq/finetune_t5.sh b/examples/seq2seq/finetune_t5.sh
index ed8d26634..0021107bb 100755
--- a/examples/seq2seq/finetune_t5.sh
+++ b/examples/seq2seq/finetune_t5.sh
@@ -8,6 +8,7 @@ python finetune.py \
 --eval_batch_size=$BS \
 --output_dir=$OUTPUT_DIR \
 --max_source_length=512 \
+--max_target_length=56 \
 --val_check_interval=0.1 --n_val=200 \
 --do_train --do_predict \
 $@
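Note on the finetune.py hunk above: the mBART-specific setup moves from TranslationModule.__init__ into SummarizationModule.__init__, so summarization runs now also pick the right dataset class and decoder start token. A minimal sketch of that dispatch in standalone form; the function name is hypothetical, and the utils import assumes running from examples/seq2seq:

    from transformers import MBartTokenizer
    from utils import MBartDataset, Seq2SeqDataset  # examples/seq2seq helpers

    def pick_dataset_and_start_token(model, tokenizer, tgt_lang: str):
        """Mirrors the dispatch SummarizationModule.__init__ performs after this patch."""
        # mBART decodes starting from a target-language code (e.g. "ro_RO"), so
        # fill in decoder_start_token_id once if the config does not carry it.
        if model.config.decoder_start_token_id is None and isinstance(tokenizer, MBartTokenizer):
            model.config.decoder_start_token_id = tokenizer.lang_code_to_id[tgt_lang]
        # Only mBART needs MBartDataset; every other model keeps Seq2SeqDataset.
        return MBartDataset if isinstance(tokenizer, MBartTokenizer) else Seq2SeqDataset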
diff --git a/examples/seq2seq/test_seq2seq_examples.py b/examples/seq2seq/test_seq2seq_examples.py
index 191bbfac7..44e3d6c70 100644
--- a/examples/seq2seq/test_seq2seq_examples.py
+++ b/examples/seq2seq/test_seq2seq_examples.py
@@ -300,14 +300,17 @@ def test_mbart_dataset_truncation():
     tmp_dir = make_test_data_dir()
     max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
     max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
-    trunc = 4
+    max_src_len = 4
+    max_tgt_len = 8
+    assert max_len_target > max_src_len  # Truncated
+    assert max_len_source > max_src_len
     src_lang, tgt_lang = "ro_RO", "de_DE"  # NOT WHAT IT WAS TRAINED ON
     train_dataset = MBartDataset(
         tokenizer,
         data_dir=tmp_dir,
         type_path="train",
-        max_source_length=trunc,
-        max_target_length=1000,  # ignored
+        max_source_length=max_src_len,
+        max_target_length=max_tgt_len,  # no longer ignored
         src_lang=src_lang,
         tgt_lang=tgt_lang,
     )
@@ -316,17 +319,15 @@
         assert isinstance(batch, dict)
         assert batch["attention_mask"].shape == batch["input_ids"].shape
         # show that articles were trimmed.
-        assert batch["input_ids"].shape[1] == trunc
+        assert batch["input_ids"].shape[1] == max_src_len
         # show that targets were truncated to max_tgt_len
-        assert batch["decoder_input_ids"].shape[1] == trunc
+        assert batch["decoder_input_ids"].shape[1] == max_tgt_len
         # check language codes in correct place
         assert batch["decoder_input_ids"][0, 0].item() == tokenizer.lang_code_to_id[tgt_lang]
         assert batch["decoder_input_ids"][0, -1].item() == tokenizer.eos_token_id
         assert batch["input_ids"][0, -2].item() == tokenizer.eos_token_id
         assert batch["input_ids"][0, -1].item() == tokenizer.lang_code_to_id[src_lang]
-        assert max_len_target > trunc  # Truncated
-        assert max_len_source > trunc
         break  # No need to test every batch
diff --git a/examples/seq2seq/utils.py b/examples/seq2seq/utils.py
index c2c484735..49910ab62 100644
--- a/examples/seq2seq/utils.py
+++ b/examples/seq2seq/utils.py
@@ -157,7 +157,8 @@ class MBartDataset(Seq2SeqDataset):
         super().__init__(*args, **kwargs)
         if self.max_source_length != self.max_target_length:
             warnings.warn(
-                f"Mbart will ignore max_target_length = {self.max_target_length} and use {self.max_source_length} for both sides."
+                f"Mbart is using sequence lengths {self.max_source_length}, {self.max_target_length}. "
+                f"Imbalanced sequence lengths may be undesirable for translation tasks."
             )

     def __getitem__(self, index) -> Dict[str, str]:
@@ -178,6 +179,7 @@ class MBartDataset(Seq2SeqDataset):
             tgt_texts=[x["tgt_texts"] for x in batch],
             tgt_lang=self.tgt_lang,
             max_length=self.max_source_length,
+            max_target_length=self.max_target_length,
         )
         return batch_encoding.data
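With the utils.py hunk above, MBartDataset forwards max_target_length to the tokenizer instead of silently clamping both sides to max_source_length, so asymmetric lengths become usable for summarization; the constructor only warns about the imbalance. A hypothetical setup, with checkpoint name, data directory, and lengths purely illustrative:

    from transformers import MBartTokenizer
    from utils import MBartDataset  # examples/seq2seq helper

    tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-cc25")
    train_dataset = MBartDataset(
        tokenizer,
        data_dir="cnn_dm",           # illustrative data directory
        type_path="train",
        max_source_length=1024,      # long articles
        max_target_length=56,        # short summaries: warns, then is respected
        src_lang="en_XX",
        tgt_lang="en_XX",
    )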
" + f"Imbalanced sequence lengths may be undesired for translation tasks" ) def __getitem__(self, index) -> Dict[str, str]: @@ -178,6 +179,7 @@ class MBartDataset(Seq2SeqDataset): tgt_texts=[x["tgt_texts"] for x in batch], tgt_lang=self.tgt_lang, max_length=self.max_source_length, + max_target_length=self.max_target_length, ) return batch_encoding.data diff --git a/src/transformers/tokenization_bart.py b/src/transformers/tokenization_bart.py index 90353ddd3..c83ad0d33 100644 --- a/src/transformers/tokenization_bart.py +++ b/src/transformers/tokenization_bart.py @@ -193,6 +193,7 @@ class MBartTokenizer(XLMRobertaTokenizer): tgt_texts: Optional[List[str]] = None, tgt_lang: str = "ro_RO", max_length: Optional[int] = None, + max_target_length: Optional[int] = None, padding: str = "longest", return_tensors: str = "pt", **kwargs, @@ -224,13 +225,16 @@ class MBartTokenizer(XLMRobertaTokenizer): ) if tgt_texts is None: return model_inputs + # Process tgt_texts + if max_target_length is None: + max_target_length = max_length self.set_tgt_lang_special_tokens(tgt_lang) decoder_inputs: BatchEncoding = self( tgt_texts, add_special_tokens=True, return_tensors=return_tensors, padding=padding, - max_length=max_length, + max_length=max_target_length, truncation=True, **kwargs, ) diff --git a/tests/test_tokenization_mbart.py b/tests/test_tokenization_mbart.py index d45a5ee60..14566ac97 100644 --- a/tests/test_tokenization_mbart.py +++ b/tests/test_tokenization_mbart.py @@ -137,6 +137,18 @@ class MBartEnroIntegrationTest(unittest.TestCase): self.assertEqual(self.tokenizer.prefix_tokens, []) self.assertEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id, EN_CODE]) + def test_max_target_length(self): + + batch = self.tokenizer.prepare_translation_batch( + self.src_text, tgt_texts=self.tgt_text, max_length=3, max_target_length=10 + ) + self.assertEqual(batch.input_ids.shape[1], 3) + self.assertEqual(batch.decoder_input_ids.shape[1], 10) + # max_target_length will default to max_length if not specified + batch = self.tokenizer.prepare_translation_batch(self.src_text, tgt_texts=self.tgt_text, max_length=3) + self.assertEqual(batch.input_ids.shape[1], 3) + self.assertEqual(batch.decoder_input_ids.shape[1], 3) + def test_enro_tokenizer_batch_encode_plus(self): ids = self.tokenizer.batch_encode_plus(self.src_text).input_ids[0] self.assertListEqual(self.expected_src_tokens, ids)