diff --git a/torch/nn/utils/clip_grad.py b/torch/nn/utils/clip_grad.py
index 522829cb18a..8d6909391b7 100644
--- a/torch/nn/utils/clip_grad.py
+++ b/torch/nn/utils/clip_grad.py
@@ -1,4 +1,5 @@
 import warnings
+import functools
 from typing import Union, Iterable, List, Dict, Tuple, Optional, cast
 
 import torch
@@ -9,6 +10,18 @@ _tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]
 
 __all__ = ['clip_grad_norm_', 'clip_grad_norm', 'clip_grad_value_']
 
+def _no_grad(func):
+    """
+    This wrapper is needed to avoid a circular import when using @torch.no_grad on the exposed functions
+    clip_grad_norm_ and clip_grad_value_ themselves.
+    """
+    def _no_grad_wrapper(*args, **kwargs):
+        with torch.no_grad():
+            return func(*args, **kwargs)
+    functools.update_wrapper(_no_grad_wrapper, func)
+    return _no_grad_wrapper
+
+@_no_grad
 def clip_grad_norm_(
         parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0,
         error_if_nonfinite: bool = False, foreach: Optional[bool] = None) -> torch.Tensor:
@@ -43,7 +56,7 @@ def clip_grad_norm_(
         return torch.tensor(0.)
     first_device = grads[0].device
     grouped_grads: Dict[Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]] \
-        = _group_tensors_by_device_and_dtype([[g.detach() for g in grads]])  # type: ignore[assignment]
+        = _group_tensors_by_device_and_dtype([grads])  # type: ignore[assignment]
 
     norms: List[Tensor] = []
     for ((device, _), ([device_grads], _)) in grouped_grads.items():  # type: ignore[assignment]
@@ -75,7 +88,7 @@ def clip_grad_norm_(
         else:
             clip_coef_clamped_device = clip_coef_clamped.to(device)
         for g in device_grads:
-            g.detach().mul_(clip_coef_clamped_device)
+            g.mul_(clip_coef_clamped_device)
 
     return total_norm
 
@@ -94,6 +107,7 @@ def clip_grad_norm(
     return clip_grad_norm_(parameters, max_norm, norm_type, error_if_nonfinite, foreach)
 
 
+@_no_grad
 def clip_grad_value_(parameters: _tensor_or_tensors, clip_value: float, foreach: Optional[bool] = None) -> None:
     r"""Clip the gradients of an iterable of parameters at specified value.
 
@@ -124,6 +138,5 @@ def clip_grad_value_(parameters: _tensor_or_tensors, clip_value: float, foreach:
         elif foreach:
             raise RuntimeError(f'foreach=True was passed, but can\'t use the foreach API on {device.type} tensors')
         else:
-            with torch.no_grad():
-                for grad in grads:
-                    cast(Tensor, grad).clamp_(min=-clip_value, max=clip_value)
+            for grad in grads:
+                cast(Tensor, grad).clamp_(min=-clip_value, max=clip_value)
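
Below is a minimal usage sketch (not part of the patch) assuming a toy `nn.Linear` model. It exercises the property the patch relies on: with the `_no_grad` decorator, the bodies of `clip_grad_norm_` and `clip_grad_value_` run entirely under `torch.no_grad()`, so the in-place `mul_`/`clamp_` calls on the `.grad` tensors record no autograd history and the per-tensor `.detach()` calls become unnecessary.

```python
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_, clip_grad_value_

# Toy model and a single backward pass to populate the .grad tensors.
model = nn.Linear(4, 2)
model(torch.randn(8, 4)).sum().backward()

# Both entry points now run under torch.no_grad() via the decorator.
total_norm = clip_grad_norm_(model.parameters(), max_norm=1.0)
clip_grad_value_(model.parameters(), clip_value=0.5)

# The clipped .grad tensors carry no autograd history.
assert all(p.grad.grad_fn is None for p in model.parameters())
print(f"grad norm before clipping: {total_norm:.4f}")
```

The `functools.update_wrapper` call in `_no_grad` copies `__name__`, `__doc__`, and related metadata from the wrapped function onto the wrapper, so the public signatures and docstrings of `clip_grad_norm_` and `clip_grad_value_` are unaffected by the decoration.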