Mirror of https://github.com/saymrwulf/transformers.git, synced 2026-05-14 20:58:08 +00:00

Commit 4f36712da1 ("nit?"), parent 2016bc47d0
1 changed file with 12 additions and 1 deletion
@@ -1356,7 +1356,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
     _output_embedding = None
     _input_embedding = None
-    gradient_checkpointing = False
 
     @property
     def dummy_inputs(self) -> Dict[str, torch.Tensor]:
@@ -2576,6 +2575,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             if isinstance(layer, GradientCheckpointLayer):
                 layer.gradient_checkpointing_enable(gradient_checkpointing_kwargs)
 
+    @property
+    def gradient_checkpointing(self, gradient_checkpointing_kwargs=None):
+        for layer in list(self.modules()):
+            if isinstance(layer, GradientCheckpointLayer):
+                return layer.gradient_checkpointing
+        return False
+
     @property
     def is_gradient_checkpointing(self) -> bool:
         """
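For orientation, a minimal standalone sketch of the dispatch-and-poll pattern this hunk relies on: walk self.modules(), forward enable calls to every GradientCheckpointLayer, and report state by polling the first one found. ToyModel, its enable_checkpointing helper, and the trivial gradient_checkpointing_enable body are illustrative stand-ins assumed for the example, not the transformers API itself.

import torch


class GradientCheckpointLayer(torch.nn.Module):
    """Stand-in for the class added at the bottom of this diff."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.gradient_checkpointing = False

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        # Assumed minimal behavior: just flip the per-layer flag.
        self.gradient_checkpointing = True


class ToyModel(torch.nn.Module):
    """Hypothetical container used only to demonstrate the pattern."""

    def __init__(self):
        super().__init__()
        self.layers = torch.nn.ModuleList(GradientCheckpointLayer() for _ in range(3))

    def enable_checkpointing(self, gradient_checkpointing_kwargs=None):
        # Dispatch, as in the first two context lines of the hunk above.
        for layer in self.modules():
            if isinstance(layer, GradientCheckpointLayer):
                layer.gradient_checkpointing_enable(gradient_checkpointing_kwargs)

    @property
    def gradient_checkpointing(self):
        # Poll, as in the property the hunk adds: report the flag of the
        # first GradientCheckpointLayer found, False if there is none.
        for layer in self.modules():
            if isinstance(layer, GradientCheckpointLayer):
                return layer.gradient_checkpointing
        return False


model = ToyModel()
print(model.gradient_checkpointing)  # False
model.enable_checkpointing()
print(model.gradient_checkpointing)  # True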
@@ -5537,6 +5543,11 @@ ALL_ATTENTION_FUNCTIONS.update(
 
 
 class GradientCheckpointLayer(torch.nn.Module):
+
+    def __init__(self, *args, **kwargs):
+        self.gradient_checkpointing = False
+        super().__init__(*args, **kwargs)
+
     def __call__(self, *args, **kwargs):
         """
         Adjust the behavior of the inherited class by overriding `__call__`.
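The docstring is cut off here, so what __call__ actually does in this commit is not visible in the diff. A hedged sketch of the usual way such an override is wired up, assuming (not confirmed by the hunk) that it routes the call through torch.utils.checkpoint when the layer's flag is set; CheckpointedLinear is a hypothetical example class, not part of the commit:

import torch
from torch.utils.checkpoint import checkpoint


class CheckpointedLinear(torch.nn.Linear):
    """Illustrative only: GradientCheckpointLayer-style behavior on nn.Linear."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.gradient_checkpointing = False

    def __call__(self, *args, **kwargs):
        if self.gradient_checkpointing and torch.is_grad_enabled():
            # Recompute activations during backward instead of storing them.
            return checkpoint(super().__call__, *args, use_reentrant=False, **kwargs)
        return super().__call__(*args, **kwargs)


layer = CheckpointedLinear(4, 4)
layer.gradient_checkpointing = True
x = torch.randn(2, 4)
layer(x).sum().backward()  # runs the checkpointed path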