mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
[VisionEncoderDecoder] Add gradient checkpointing (#18697)
* add first generation tutorial * VisionEncoderDecoder gradient checkpointing * remove generation * add tests
This commit is contained in:
parent
06a6a4bd51
commit
8869bf41fe
3 changed files with 52 additions and 0 deletions
|
|
@@ -155,6 +155,7 @@ class VisionEncoderDecoderModel(PreTrainedModel):
|
|||
config_class = VisionEncoderDecoderConfig
|
||||
base_model_prefix = "vision_encoder_decoder"
|
||||
main_input_name = "pixel_values"
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@@ -221,6 +222,11 @@ class VisionEncoderDecoderModel(PreTrainedModel):
|
|||
f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head"
|
||||
)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
# call both encoder and decoder function on gradient checkpointing
|
||||
self.encoder._set_gradient_checkpointing(module, value=value)
|
||||
self.decoder._set_gradient_checkpointing(module, value=value)
|
||||
|
||||
def get_encoder(self):
    """Return the wrapped encoder sub-model."""
    encoder = self.encoder
    return encoder
|
||||
|
||||
|
|
|
|||
|
|
@@ -396,6 +396,28 @@ class EncoderDecoderMixin:
|
|||
input_ids_dict = self.prepare_config_and_inputs()
|
||||
self.check_encoder_decoder_model_generate(**input_ids_dict)
|
||||
|
||||
def test_training_gradient_checkpointing(self):
    """Smoke-test a backward pass with gradient checkpointing enabled."""
    prepared = self.prepare_config_and_inputs()
    encoder, decoder = self.get_encoder_decoder_model(
        prepared["config"], prepared["decoder_config"]
    )

    model = SpeechEncoderDecoderModel(encoder=encoder, decoder=decoder)
    model.train()
    model.gradient_checkpointing_enable()
    model.config.decoder_start_token_id = 0
    model.config.pad_token_id = 0

    forward_kwargs = {
        "attention_mask": prepared["attention_mask"],
        "labels": prepared["labels"],
        "decoder_input_ids": prepared["decoder_input_ids"],
    }
    # Speech models expose their main input as either `input_features`
    # or `input_values`, depending on the architecture under test.
    if "input_features" in prepared:
        main_input = prepared["input_features"]
    else:
        main_input = prepared["input_values"]

    # A successful backward pass is the assertion: checkpointed
    # re-computation must not break autograd.
    loss = model(main_input, **forward_kwargs).loss
    loss.backward()
|
||||
|
||||
@slow
|
||||
def test_real_model_save_load_from_pretrained(self):
|
||||
model_2, inputs = self.get_pretrained_model_and_inputs()
|
||||
|
|
@@ -590,6 +612,7 @@ class Wav2Vec2Speech2Text2(EncoderDecoderMixin, unittest.TestCase):
|
|||
"decoder_config": decoder_config,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"labels": decoder_input_ids,
|
||||
}
|
||||
|
||||
# there are no published pretrained Speech2Text2ForCausalLM for now
|
||||
|
|
|
|||
|
|
@@ -324,6 +324,27 @@ class EncoderDecoderMixin:
|
|||
input_ids_dict = self.prepare_config_and_inputs()
|
||||
self.check_encoder_decoder_model_generate(**input_ids_dict)
|
||||
|
||||
def test_training_gradient_checkpointing(self):
    """Smoke-test a backward pass with gradient checkpointing enabled."""
    prepared = self.prepare_config_and_inputs()
    encoder, decoder = self.get_encoder_decoder_model(
        prepared["config"], prepared["decoder_config"]
    )

    model = VisionEncoderDecoderModel(encoder=encoder, decoder=decoder)
    model.train()
    model.gradient_checkpointing_enable()
    model.config.decoder_start_token_id = 0
    model.config.pad_token_id = 0

    # A successful backward pass is the assertion: checkpointed
    # re-computation must not break autograd.
    loss = model(
        pixel_values=prepared["pixel_values"],
        labels=prepared["labels"],
        decoder_input_ids=prepared["decoder_input_ids"],
    ).loss
    loss.backward()
|
||||
|
||||
@slow
|
||||
def test_real_model_save_load_from_pretrained(self):
|
||||
model_2, inputs = self.get_pretrained_model_and_inputs()
|
||||
|
|
@@ -547,6 +568,7 @@ class Swin2BartModelTest(EncoderDecoderMixin, unittest.TestCase):
|
|||
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs()
|
||||
config, pixel_values, _ = encoder_config_and_inputs
|
||||
decoder_config, decoder_inputs_dict = decoder_config_and_inputs
|
||||
decoder_inputs_dict["labels"] = decoder_inputs_dict["decoder_input_ids"]
|
||||
|
||||
# make sure that cross attention layers are added
|
||||
decoder_config.add_cross_attention = True
|
||||
|
|
@@ -644,6 +666,7 @@ class ViT2TrOCR(EncoderDecoderMixin, unittest.TestCase):
|
|||
"decoder_config": decoder_config,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"labels": decoder_input_ids,
|
||||
}
|
||||
|
||||
# there are no published pretrained TrOCR checkpoints for now
|
||||
|
|
|
|||
Loading…
Reference in a new issue