From 6a5472a8e1d75e95e8fb4d0bdf8ddf9b237bac03 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Tue, 8 Feb 2022 16:20:53 +0100 Subject: [PATCH] Force use_cache to be False in PyTorch (#15385) * use_cache = False for PT models if labels is passed * Fix for BigBirdPegasusForConditionalGeneration * add warning if users specify use_cache=True * Use logger.warning instead of warnings.warn Co-authored-by: ydshieh --- src/transformers/models/bart/modeling_bart.py | 3 +++ .../models/bigbird_pegasus/modeling_bigbird_pegasus.py | 3 +++ src/transformers/models/blenderbot/modeling_blenderbot.py | 3 +++ .../models/blenderbot_small/modeling_blenderbot_small.py | 3 +++ src/transformers/models/led/modeling_led.py | 3 +++ src/transformers/models/marian/modeling_marian.py | 3 +++ src/transformers/models/mbart/modeling_mbart.py | 3 +++ src/transformers/models/pegasus/modeling_pegasus.py | 3 +++ .../modeling_{{cookiecutter.lowercase_modelname}}.py | 3 +++ 9 files changed, 27 insertions(+) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 381204cf2..70edf96cd 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -1318,6 +1318,9 @@ class BartForConditionalGeneration(BartPretrainedModel): return_dict = return_dict if return_dict is not None else self.config.use_return_dict if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False if decoder_input_ids is None and decoder_inputs_embeds is None: decoder_input_ids = shift_tokens_right( labels, self.config.pad_token_id, self.config.decoder_start_token_id diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index c4db1b8be..aa961e0f5 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -2513,6 +2513,9 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel): return_dict = return_dict if return_dict is not None else self.config.use_return_dict if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False if decoder_input_ids is None and decoder_inputs_embeds is None: decoder_input_ids = shift_tokens_right( labels, self.config.pad_token_id, self.config.decoder_start_token_id diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index df098dd6e..7751a74f9 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -1287,6 +1287,9 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel): return_dict = return_dict if return_dict is not None else self.config.use_return_dict if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False if decoder_input_ids is None: decoder_input_ids = shift_tokens_right( labels, self.config.pad_token_id, self.config.decoder_start_token_id diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 5875a827f..a22c4d0ce 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -1258,6 +1258,9 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel): return_dict = return_dict if return_dict is not None else self.config.use_return_dict if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False if decoder_input_ids is None: decoder_input_ids = shift_tokens_right( labels, self.config.pad_token_id, self.config.decoder_start_token_id diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index 8054b9ee6..e775fd35c 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -2366,6 +2366,9 @@ class LEDForConditionalGeneration(LEDPreTrainedModel): return_dict = return_dict if return_dict is not None else self.config.use_return_dict if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False if decoder_input_ids is None: decoder_input_ids = shift_tokens_right( labels, self.config.pad_token_id, self.config.decoder_start_token_id diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index f3bd96eeb..20cbd21f7 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -1291,6 +1291,9 @@ class MarianMTModel(MarianPreTrainedModel): return_dict = return_dict if return_dict is not None else self.config.use_return_dict if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False if decoder_input_ids is None: decoder_input_ids = shift_tokens_right( labels, self.config.pad_token_id, self.config.decoder_start_token_id diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index fc09f0a7e..3e747b4b1 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -1314,6 +1314,9 @@ class MBartForConditionalGeneration(MBartPreTrainedModel): return_dict = return_dict if return_dict is not None else self.config.use_return_dict if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False if decoder_input_ids is None: decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id) diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 32923ce44..5eed41254 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -1381,6 +1381,9 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel): return_dict = return_dict if return_dict is not None else self.config.use_return_dict if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False if decoder_input_ids is None: decoder_input_ids = shift_tokens_right( labels, self.config.pad_token_id, self.config.decoder_start_token_id diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py index 63233c4bf..69a0d8176 100755 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py @@ -2832,6 +2832,9 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutte return_dict = return_dict if return_dict is not None else self.config.use_return_dict if labels is not None: + if use_cache: + logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") + use_cache = False if decoder_input_ids is None: decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)