This commit is contained in:
Arthur Zucker 2024-12-12 10:24:11 +01:00
parent 2016bc47d0
commit 4f36712da1

View file

@@ -1356,7 +1356,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
_output_embedding = None
_input_embedding = None
gradient_checkpointing = False
@property
def dummy_inputs(self) -> Dict[str, torch.Tensor]:
@@ -2576,6 +2575,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if isinstance(layer, GradientCheckpointLayer):
layer.gradient_checkpointing_enable(gradient_checkpointing_kwargs)
@property
def gradient_checkpointing(self) -> bool:
    """Whether gradient checkpointing is active for this model.

    Returns the `gradient_checkpointing` flag of the first
    `GradientCheckpointLayer` found among this model's submodules, or
    `False` if the model contains no such layer.

    Note: this property shadows the `gradient_checkpointing = False`
    class attribute declared on `PreTrainedModel`.
    """
    # BUG FIX: a property getter is always invoked with zero arguments, so
    # the original `gradient_checkpointing_kwargs=None` parameter could
    # never be passed by any caller — it was dead code (likely copied from
    # `gradient_checkpointing_enable`) and has been removed.
    # Also dropped the needless `list(...)` around the generator.
    for module in self.modules():
        if isinstance(module, GradientCheckpointLayer):
            # Report the flag of the first checkpoint-capable layer; the
            # enable path toggles all of them together, so one is enough.
            return module.gradient_checkpointing
    # No checkpoint-capable layers: checkpointing is effectively off.
    return False
@property
def is_gradient_checkpointing(self) -> bool:
"""
@@ -5537,6 +5543,11 @@ ALL_ATTENTION_FUNCTIONS.update(
class GradientCheckpointLayer(torch.nn.Module):
def __init__(self, *args, **kwargs):
self.gradient_checkpointing = False
super().__init__( *args, **kwargs)
def __call__(self, *args, **kwargs):
"""
Adjust the behavior of the inherited class by overriding `__call__`.