Mirror of https://github.com/saymrwulf/transformers.git
Commit 38dd294dd7 (parent 1baabd3207): fix

1 changed file with 4 additions and 2 deletions
@@ -5521,7 +5521,7 @@ class GradientCheckpointLayer(torch.nn.Module):
             return self._apply_gradient_checkpointing(*args, **kwargs)
         else:
             # Default behavior: call the original `forward` method
-            return super().__call__(*args, **kwargs)
+            return self.forward(*args, **kwargs)
 
     def _apply_gradient_checkpointing(self, *args, **kwargs):
         """
@@ -5530,7 +5530,9 @@ class GradientCheckpointLayer(torch.nn.Module):
         By default, uses `torch.utils.checkpoint.checkpoint`.
         """
         # Assume `self.forward` is compatible with checkpointing
-        return checkpoint(self.__call__, *args, **kwargs)
+        def wrapped_forward():
+            return self.forward(*args, **kwargs)
+        return self._gradient_checkpointing_func(wrapped_forward)
 
 
     def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
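Taken together, the two hunks stop the layer from re-entering its own `__call__`: the non-checkpointing path now calls `self.forward` directly, and the checkpointing path hands a closure over `self.forward` to `self._gradient_checkpointing_func` instead of passing `self.__call__` to `checkpoint`, where recomputation during backward could recurse back into the checkpointing branch. The sketch below shows the same control flow in a minimal, self-contained form; the class name `CheckpointedLayer`, the flag handling, and wiring `torch.utils.checkpoint.checkpoint(use_reentrant=False)` in as the checkpointing function are illustrative assumptions, not the actual transformers implementation.

import torch
from torch.utils.checkpoint import checkpoint

class CheckpointedLayer(torch.nn.Module):
    """Hypothetical layer mirroring the post-fix control flow."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)
        self.gradient_checkpointing = False
        # In transformers this callable is installed by
        # gradient_checkpointing_enable(); hard-wired here for the sketch.
        self._gradient_checkpointing_func = lambda fn: checkpoint(fn, use_reentrant=False)

    def forward(self, x):
        return torch.relu(self.linear(x))

    def __call__(self, *args, **kwargs):
        if self.gradient_checkpointing and self.training:
            return self._apply_gradient_checkpointing(*args, **kwargs)
        else:
            # Default behavior: call the original `forward` method
            return self.forward(*args, **kwargs)

    def _apply_gradient_checkpointing(self, *args, **kwargs):
        # Close over the arguments so recomputation during backward goes
        # through `forward` directly, never back through `__call__`.
        def wrapped_forward():
            return self.forward(*args, **kwargs)
        return self._gradient_checkpointing_func(wrapped_forward)

A quick usage check of the sketch:

layer = CheckpointedLayer().train()
layer.gradient_checkpointing = True
x = torch.randn(4, 16, requires_grad=True)
layer(x).sum().backward()  # activations are recomputed during backward

The zero-argument closure matters because the non-reentrant checkpoint recomputes `wrapped_forward` in the backward pass; since the closure captures the original arguments and calls `forward` directly, the recomputation cannot trip the `gradient_checkpointing` branch a second time.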