diff --git a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py
index 4906c13a9..a963561c2 100644
--- a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py
+++ b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py
@@ -155,6 +155,7 @@ class VisionEncoderDecoderModel(PreTrainedModel):
     config_class = VisionEncoderDecoderConfig
     base_model_prefix = "vision_encoder_decoder"
     main_input_name = "pixel_values"
+    supports_gradient_checkpointing = True
 
     def __init__(
         self,
@@ -221,6 +222,11 @@ class VisionEncoderDecoderModel(PreTrainedModel):
                 f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head"
             )
 
+    def _set_gradient_checkpointing(self, module, value=False):
+        # call both encoder and decoder function on gradient checkpointing
+        self.encoder._set_gradient_checkpointing(module, value=value)
+        self.decoder._set_gradient_checkpointing(module, value=value)
+
     def get_encoder(self):
         return self.encoder
 
diff --git a/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py b/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py
index 2d934744f..f24150130 100644
--- a/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py
+++ b/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py
@@ -396,6 +396,28 @@ class EncoderDecoderMixin:
         input_ids_dict = self.prepare_config_and_inputs()
         self.check_encoder_decoder_model_generate(**input_ids_dict)
 
+    def test_training_gradient_checkpointing(self):
+        inputs_dict = self.prepare_config_and_inputs()
+        encoder_model, decoder_model = self.get_encoder_decoder_model(
+            inputs_dict["config"], inputs_dict["decoder_config"]
+        )
+
+        model = SpeechEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
+        model.train()
+        model.gradient_checkpointing_enable()
+        model.config.decoder_start_token_id = 0
+        model.config.pad_token_id = 0
+
+        model_inputs = {
+            "attention_mask": inputs_dict["attention_mask"],
+            "labels": inputs_dict["labels"],
+            "decoder_input_ids": inputs_dict["decoder_input_ids"],
+        }
+        inputs = inputs_dict["input_features"] if "input_features" in inputs_dict else inputs_dict["input_values"]
+
+        loss = model(inputs, **model_inputs).loss
+        loss.backward()
+
     @slow
     def test_real_model_save_load_from_pretrained(self):
         model_2, inputs = self.get_pretrained_model_and_inputs()
@@ -590,6 +612,7 @@ class Wav2Vec2Speech2Text2(EncoderDecoderMixin, unittest.TestCase):
             "decoder_config": decoder_config,
             "decoder_input_ids": decoder_input_ids,
             "decoder_attention_mask": decoder_attention_mask,
+            "labels": decoder_input_ids,
         }
 
     # there are no published pretrained Speech2Text2ForCausalLM for now
diff --git a/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py b/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py
index 757088809..fbac8b898 100644
--- a/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py
+++ b/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py
@@ -324,6 +324,27 @@ class EncoderDecoderMixin:
         input_ids_dict = self.prepare_config_and_inputs()
         self.check_encoder_decoder_model_generate(**input_ids_dict)
 
+    def test_training_gradient_checkpointing(self):
+        inputs_dict = self.prepare_config_and_inputs()
+        encoder_model, decoder_model = self.get_encoder_decoder_model(
+            inputs_dict["config"], inputs_dict["decoder_config"]
+        )
+
+        model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
+        model.train()
+        model.gradient_checkpointing_enable()
+        model.config.decoder_start_token_id = 0
+        model.config.pad_token_id = 0
+
+        model_inputs = {
+            "pixel_values": inputs_dict["pixel_values"],
+            "labels": inputs_dict["labels"],
+            "decoder_input_ids": inputs_dict["decoder_input_ids"],
+        }
+
+        loss = model(**model_inputs).loss
+        loss.backward()
+
     @slow
     def test_real_model_save_load_from_pretrained(self):
         model_2, inputs = self.get_pretrained_model_and_inputs()
@@ -547,6 +568,7 @@ class Swin2BartModelTest(EncoderDecoderMixin, unittest.TestCase):
         decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs()
         config, pixel_values, _ = encoder_config_and_inputs
         decoder_config, decoder_inputs_dict = decoder_config_and_inputs
+        decoder_inputs_dict["labels"] = decoder_inputs_dict["decoder_input_ids"]
 
         # make sure that cross attention layers are added
         decoder_config.add_cross_attention = True
@@ -644,6 +666,7 @@ class ViT2TrOCR(EncoderDecoderMixin, unittest.TestCase):
             "decoder_config": decoder_config,
             "decoder_input_ids": decoder_input_ids,
             "decoder_attention_mask": decoder_attention_mask,
+            "labels": decoder_input_ids,
         }
 
     # there are no published pretrained TrOCR checkpoints for now
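For context (not part of the patch): a minimal sketch of how the new flag is exercised end to end by a user, assuming an illustrative ViT encoder / GPT-2 decoder pair; the checkpoint names, token-id choices, and dummy tensor shapes below are placeholder assumptions for the sketch, not anything this diff introduces.

import torch
from transformers import VisionEncoderDecoderModel

# Illustrative checkpoint pair; any compatible encoder/decoder combination works the same way.
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "google/vit-base-patch16-224-in21k", "gpt2"
)
# Required for computing the loss from `labels` (values chosen here are just for the sketch).
model.config.decoder_start_token_id = model.config.decoder.bos_token_id
model.config.pad_token_id = model.config.decoder.eos_token_id

model.train()
# With supports_gradient_checkpointing = True, this call now reaches both the encoder and
# the decoder through the _set_gradient_checkpointing override added in this diff.
model.gradient_checkpointing_enable()

pixel_values = torch.randn(1, 3, 224, 224)  # dummy image batch matching ViT's expected input size
labels = torch.randint(0, model.config.decoder.vocab_size, (1, 16))  # dummy target token ids

loss = model(pixel_values=pixel_values, labels=labels).loss
loss.backward()  # activations are recomputed during backward, trading extra compute for lower memory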