From c11160114a155de38c072bfa56eab10e938ca5b7 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Wed, 11 Mar 2020 14:30:07 +0100
Subject: [PATCH 1/6] small clean-up

---
 src/transformers/modeling_utils.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 6e26a9318..253844ad4 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -845,7 +845,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             encoder_inputs = input_ids
             input_ids = torch.full(
                 (effective_batch_size * num_beams, 1),
-                bos_token_id,
+                bos_token_id,  # TODO: wait for results of Bart CNN summarization
                 dtype=torch.long,
                 device=next(self.parameters()).device,
             )
@@ -1082,7 +1082,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)

             if self.config.is_encoder_decoder and do_sample is False:
-                # TODO(PVP) to be refactored later - do we need this boolean flag here? Also Only add for beam_search or also for no_beam_search? The prepare scores fn is ugly here
+                # TODO: maybe give better naming
                 scores = self.prepare_scores_for_generation(scores, cur_len, max_length)

             # set eos token prob to zero if min_length is not reached
@@ -1276,7 +1276,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)

         if self.config.is_encoder_decoder:
-            # do not return first token
             return decoded[:, 1:]
         return decoded


From 6047f46b199ba49f353b31d7bedad2b3e076f52e Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Thu, 12 Mar 2020 20:17:50 +0100
Subject: [PATCH 2/6] re-add eos token to get good bart results

---
 src/transformers/modeling_utils.py | 10 +++++++++-
 tests/test_modeling_bart.py        |  7 ++++++-
 2 files changed, 15 insertions(+), 2 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 253844ad4..57b4204a5 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -628,6 +628,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         no_repeat_ngram_size=None,
         num_return_sequences=None,
         attention_mask=None,
+        decoder_start_token_id=None,
     ):
         r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling and beam-search.

@@ -739,6 +740,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         num_return_sequences = (
             num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
         )
+        # TODO: think about how to make this cleaner
+        decoder_start_token_id = (
+            decoder_start_token_id if decoder_start_token_id is not None else self.config.bos_token_id
+        )

         if input_ids is not None:
             batch_size = input_ids.shape[0]  # overriden by the input batch_size
@@ -765,6 +770,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         assert (eos_token_ids is None) or (
             isinstance(eos_token_ids, (list, tuple)) and ((isinstance(e, int) and e >= 0) for e in eos_token_ids)
         ), "`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
+        assert (
+            decoder_start_token_id is not None or self.config.is_encoder_decoder is False
+        ), "`decoder_start_token_id` has to be defined if model is encoder-decoder model"
         assert length_penalty > 0, "`length_penalty` should be strictly positive."
         assert (
             isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0
@@ -845,7 +853,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             encoder_inputs = input_ids
             input_ids = torch.full(
                 (effective_batch_size * num_beams, 1),
-                bos_token_id,  # TODO: wait for results of Bart CNN summarization
+                decoder_start_token_id,  # TODO: see whether this is the best result
                 dtype=torch.long,
                 device=next(self.parameters()).device,
             )
diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py
index b885ccf1b..8ad02feb2 100644
--- a/tests/test_modeling_bart.py
+++ b/tests/test_modeling_bart.py
@@ -432,7 +432,11 @@ class BartModelIntegrationTest(unittest.TestCase):
         tokens = tok.encode(text, return_tensors="pt").to(torch_device)
         extra_len = 20
         gen_tokens = hf.generate(
-            tokens, num_beams=4, max_length=extra_len + 2, do_sample=False
+            tokens,
+            num_beams=4,
+            max_length=extra_len + 2,
+            do_sample=False,
+            decoder_start_token_id=hf.config.eos_token_id,
         )  # repetition_penalty=10.,
         expected_result = "The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday."
         generated = [tok.decode(g,) for g in gen_tokens]
@@ -477,6 +481,7 @@ class BartModelIntegrationTest(unittest.TestCase):
             no_repeat_ngram_size=3,
             do_sample=False,
             early_stopping=True,
+            decoder_start_token_id=hf.config.eos_token_id,
         )

         decoded = [


From f1c71da1154cb8bf58e07eccb4c1a3fcae83efb8 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Thu, 12 Mar 2020 21:00:54 +0100
Subject: [PATCH 3/6] fix eos_token_ids in test

---
 tests/test_modeling_bart.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py
index 8ad02feb2..a0ad29830 100644
--- a/tests/test_modeling_bart.py
+++ b/tests/test_modeling_bart.py
@@ -61,7 +61,7 @@ class ModelTester:
         self.hidden_dropout_prob = 0.1
         self.attention_probs_dropout_prob = 0.1
         self.max_position_embeddings = 20
-        self.eos_token_id = 2
+        self.eos_token_ids = [2]
         self.pad_token_id = 1
         self.bos_token_id = 0
         torch.manual_seed(0)
@@ -436,7 +436,7 @@ class BartModelIntegrationTest(unittest.TestCase):
             num_beams=4,
             max_length=extra_len + 2,
             do_sample=False,
-            decoder_start_token_id=hf.config.eos_token_id,
+            decoder_start_token_id=hf.config.eos_token_ids[0],
         )  # repetition_penalty=10.,
         expected_result = "The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday."
         generated = [tok.decode(g,) for g in gen_tokens]
@@ -481,7 +481,7 @@ class BartModelIntegrationTest(unittest.TestCase):
             no_repeat_ngram_size=3,
             do_sample=False,
             early_stopping=True,
-            decoder_start_token_id=hf.config.eos_token_id,
+            decoder_start_token_id=hf.config.eos_token_ids[0],
         )

         decoded = [


From 6a82f774f257f60aa7dc8b813e90e3ffb60e32d1 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Thu, 12 Mar 2020 21:10:51 +0100
Subject: [PATCH 4/6] fix typo

---
 tests/test_modeling_bart.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py
index a0ad29830..31a3a8a25 100644
--- a/tests/test_modeling_bart.py
+++ b/tests/test_modeling_bart.py
@@ -82,7 +82,7 @@ class ModelTester:
             dropout=self.hidden_dropout_prob,
             attention_dropout=self.attention_probs_dropout_prob,
             max_position_embeddings=self.max_position_embeddings,
-            eos_token_ids=[self.eos_token_id],
+            eos_token_ids=[2],
             bos_token_id=self.bos_token_id,
             pad_token_id=self.pad_token_id,
         )


From c2ee3840ae5bd4555ada6c9d12ccdc03e0ca8454 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Fri, 13 Mar 2020 16:34:44 +0100
Subject: [PATCH 5/6] update file to new starting token logic

---
 examples/summarization/bart/evaluate_cnn.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/examples/summarization/bart/evaluate_cnn.py b/examples/summarization/bart/evaluate_cnn.py
index 474fcd46c..93590614e 100644
--- a/examples/summarization/bart/evaluate_cnn.py
+++ b/examples/summarization/bart/evaluate_cnn.py
@@ -20,18 +20,23 @@ def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE):
     fout = Path(out_file).open("w")
     model = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True,).to(device)
     tokenizer = BartTokenizer.from_pretrained("bart-large")
+
+    max_length = 140
+    min_length = 55
+
     for batch in tqdm(list(chunks(lns, batch_size))):
         dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True)
         summaries = model.generate(
             input_ids=dct["input_ids"].to(device),
             attention_mask=dct["attention_mask"].to(device),
             num_beams=4,
             length_penalty=2.0,
-            max_length=142,  # +2 from original because we start at step=1 and stop before max_length
-            min_length=56,  # +1 from original because we start at step=1
+            max_length=max_length + 2,  # +2 from original because we start at step=1 and stop before max_length
+            min_length=min_length + 1,  # +1 from original because we start at step=1
             no_repeat_ngram_size=3,
             early_stopping=True,
             do_sample=False,
+            decoder_start_token_id=model.config.eos_token_ids[0]
         )
         dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries]
         for hypothesis in dec:


From 4f75d380a46f343b0c2244eff60484ed2a85ec16 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Fri, 13 Mar 2020 16:35:52 +0100
Subject: [PATCH 6/6] make style

---
 examples/summarization/bart/evaluate_cnn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/summarization/bart/evaluate_cnn.py b/examples/summarization/bart/evaluate_cnn.py
index 93590614e..fded7e51f 100644
--- a/examples/summarization/bart/evaluate_cnn.py
+++ b/examples/summarization/bart/evaluate_cnn.py
@@ -36,7 +36,7 @@ def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE):
             no_repeat_ngram_size=3,
             early_stopping=True,
             do_sample=False,
-            decoder_start_token_id=model.config.eos_token_ids[0]
+            decoder_start_token_id=model.config.eos_token_ids[0],
         )
         dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries]
         for hypothesis in dec:
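
A minimal usage sketch of the argument this series introduces, assuming a checkout with these six patches applied. The checkpoint names, `batch_encode_plus`/`generate` keyword arguments, and the plural `eos_token_ids` config attribute are all taken from the diffs above (later releases renamed some of these), so treat this as illustrative rather than canonical:

    from transformers import BartForConditionalGeneration, BartTokenizer

    # Same checkpoints as examples/summarization/bart/evaluate_cnn.py.
    model = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True)
    tokenizer = BartTokenizer.from_pretrained("bart-large")

    article = "The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday."
    dct = tokenizer.batch_encode_plus([article], max_length=1024, return_tensors="pt", pad_to_max_length=True)

    # Per the tests in this series, BART produces the expected summaries when
    # decoding starts from the EOS token; without the explicit argument,
    # generate() falls back to config.bos_token_id.
    summary_ids = model.generate(
        input_ids=dct["input_ids"],
        attention_mask=dct["attention_mask"],
        num_beams=4,
        length_penalty=2.0,
        max_length=142,
        min_length=56,
        no_repeat_ngram_size=3,
        early_stopping=True,
        do_sample=False,
        decoder_start_token_id=model.config.eos_token_ids[0],
    )
    print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))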