From 4f36712da1e805e02f03da17e47ac93ca2eecd60 Mon Sep 17 00:00:00 2001
From: Arthur Zucker
Date: Thu, 12 Dec 2024 10:24:11 +0100
Subject: [PATCH] nit?

---
 src/transformers/modeling_utils.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index a4f1bb25f..3dcd54303 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1356,7 +1356,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
 
     _output_embedding = None
     _input_embedding = None
-    gradient_checkpointing = False
 
     @property
     def dummy_inputs(self) -> Dict[str, torch.Tensor]:
@@ -2576,6 +2575,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             if isinstance(layer, GradientCheckpointLayer):
                 layer.gradient_checkpointing_enable(gradient_checkpointing_kwargs)
 
+    @property
+    def gradient_checkpointing(self) -> bool:
+        for layer in self.modules():
+            if isinstance(layer, GradientCheckpointLayer):
+                return layer.gradient_checkpointing
+        return False
+
     @property
     def is_gradient_checkpointing(self) -> bool:
         """
@@ -5537,6 +5543,11 @@ ALL_ATTENTION_FUNCTIONS.update(
 
 
 class GradientCheckpointLayer(torch.nn.Module):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.gradient_checkpointing = False
+
     def __call__(self, *args, **kwargs):
         """
         Adjust the behavior of the inherited class by overriding `__call__`.
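
Reviewer note (not part of the patch): a minimal, self-contained sketch of the
semantics the new property aims for. `GradientCheckpointLayer` and the
`gradient_checkpointing` property mirror the hunks above; `ToyBlock` and
`ToyModel` are hypothetical stand-ins for a real checkpointable layer and a
`PreTrainedModel`, and `gradient_checkpointing_enable` here is a stub for the
real per-layer toggle.

    import torch

    class GradientCheckpointLayer(torch.nn.Module):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Per-layer flag, replacing the old class-level attribute on the model.
            self.gradient_checkpointing = False

        def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
            # Stub: the real method also wires up the checkpointed forward.
            self.gradient_checkpointing = True

    class ToyBlock(GradientCheckpointLayer):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x):
            return self.linear(x)

    class ToyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.blocks = torch.nn.ModuleList([ToyBlock(), ToyBlock()])

        @property
        def gradient_checkpointing(self) -> bool:
            # Derived from the first checkpointable submodule, False if none exist.
            for layer in self.modules():
                if isinstance(layer, GradientCheckpointLayer):
                    return layer.gradient_checkpointing
            return False

    model = ToyModel()
    print(model.gradient_checkpointing)  # False: nothing enabled yet

    for layer in model.modules():
        if isinstance(layer, GradientCheckpointLayer):
            layer.gradient_checkpointing_enable()

    print(model.gradient_checkpointing)  # True: read back from the per-layer flags

One design point worth flagging: the property returns the flag of the first
matching submodule, so a model with mixed per-layer states reports whatever its
first checkpointable layer says; if "any layer enabled" is the intended meaning,
an `any(...)` over the matching layers would be stricter.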