From 38dd294dd72183713faa7515923a4c3ebb73ad7e Mon Sep 17 00:00:00 2001
From: Arthur Zucker
Date: Wed, 11 Dec 2024 16:10:40 +0100
Subject: [PATCH] fix gradient checkpointing dispatch in GradientCheckpointLayer

Dispatch the default path straight to `self.forward` instead of
`super().__call__`, and checkpoint a zero-argument closure over
`forward` via `self._gradient_checkpointing_func` rather than calling
`torch.utils.checkpoint.checkpoint` on `self.__call__`.

---
 src/transformers/modeling_utils.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 52770dfa2..44d0f18f7 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -5521,7 +5521,7 @@ class GradientCheckpointLayer(torch.nn.Module):
             return self._apply_gradient_checkpointing(*args, **kwargs)
         else:
             # Default behavior: call the original `forward` method
-            return super().__call__(*args, **kwargs)
+            return self.forward(*args, **kwargs)
 
     def _apply_gradient_checkpointing(self, *args, **kwargs):
         """
@@ -5530,7 +5530,9 @@ class GradientCheckpointLayer(torch.nn.Module):
         By default, uses `torch.utils.checkpoint.checkpoint`.
         """
         # Assume `self.forward` is compatible with checkpointing
-        return checkpoint(self.__call__, *args, **kwargs)
+        def wrapped_forward():
+            return self.forward(*args, **kwargs)
+        return self._gradient_checkpointing_func(wrapped_forward)
 
     def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
         """
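
Below is a minimal, self-contained sketch of the dispatch this patch sets
up. `CheckpointedLayer`, its `linear` submodule, and the
enable/training/grad-mode gate are illustrative assumptions rather than
the real `GradientCheckpointLayer`, and the sketch calls
`torch.utils.checkpoint.checkpoint` with the inputs passed through
directly instead of routing a zero-argument closure through
`self._gradient_checkpointing_func` as the patch does, since plain
`checkpoint` needs the tensor arguments in order to recompute the
forward during backward.

    import torch
    from torch.utils.checkpoint import checkpoint


    class CheckpointedLayer(torch.nn.Module):
        """Illustrative stand-in for GradientCheckpointLayer."""

        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)  # hypothetical submodule
            self.gradient_checkpointing = False

        def forward(self, x):
            return self.linear(x).relu()

        def __call__(self, *args, **kwargs):
            # Mirror the patched dispatch: checkpoint only when enabled and
            # grad mode is on; otherwise fall through to `forward` directly.
            if self.gradient_checkpointing and self.training and torch.is_grad_enabled():
                # Checkpointing `self.__call__` here (the pre-patch code)
                # would re-enter this branch and recurse.
                return checkpoint(self.forward, *args, use_reentrant=False, **kwargs)
            return self.forward(*args, **kwargs)


    layer = CheckpointedLayer()
    layer.gradient_checkpointing = True
    x = torch.randn(2, 4, requires_grad=True)
    layer(x).sum().backward()  # forward is recomputed during backward
    print(x.grad.shape)        # torch.Size([2, 4])

With `gradient_checkpointing` left False, the same call skips `checkpoint`
entirely and runs `forward` eagerly, which is the default branch the first
hunk repairs.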