This commit is contained in:
Arthur Zucker 2024-12-11 16:10:40 +01:00
parent 1baabd3207
commit 38dd294dd7

View file

@@ -5521,7 +5521,7 @@ class GradientCheckpointLayer(torch.nn.Module):
return self._apply_gradient_checkpointing(*args, **kwargs)
else:
# Default behavior: call the original `forward` method
-return super().__call__(*args, **kwargs)
+return self.forward(*args, **kwargs)
def _apply_gradient_checkpointing(self, *args, **kwargs):
"""
@@ -5530,7 +5530,9 @@ class GradientCheckpointLayer(torch.nn.Module):
By default, uses `torch.utils.checkpoint.checkpoint`.
"""
# Assume `self.forward` is compatible with checkpointing
-return checkpoint(self.__call__, *args, **kwargs)
+def wrapped_forward():
+return self.forward(*args, **kwargs)
+return self._gradient_checkpointing_func(wrapped_forward)
def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):