diff --git a/torch/nn/utils/clip_grad.py b/torch/nn/utils/clip_grad.py
index 95a51fd2c68..522829cb18a 100644
--- a/torch/nn/utils/clip_grad.py
+++ b/torch/nn/utils/clip_grad.py
@@ -2,7 +2,7 @@ import warnings
 from typing import Union, Iterable, List, Dict, Tuple, Optional, cast
 
 import torch
-from torch import Tensor, inf
+from torch import Tensor
 from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype, _has_foreach_support
 
 _tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]
@@ -42,23 +42,19 @@ def clip_grad_norm_(
     if len(grads) == 0:
         return torch.tensor(0.)
     first_device = grads[0].device
-    grouped_grads: Dict[Tuple[torch.device, torch.dtype], List[List[Tensor]]] \
+    grouped_grads: Dict[Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]] \
         = _group_tensors_by_device_and_dtype([[g.detach() for g in grads]])  # type: ignore[assignment]
 
-    if norm_type == inf:
-        norms = [torch.linalg.vector_norm(g.detach(), inf).to(first_device) for g in grads]
-        total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
-    else:
-        norms = []
-        for ((device, _), ([grads], _)) in grouped_grads.items():  # type: ignore[assignment]
-            if (foreach is None or foreach) and _has_foreach_support(grads, device=device):
-                norms.extend(torch._foreach_norm(grads, norm_type))
-            elif foreach:
-                raise RuntimeError(f'foreach=True was passed, but can\'t use the foreach API on {device.type} tensors')
-            else:
-                norms.extend([torch.linalg.vector_norm(g, norm_type) for g in grads])
+    norms: List[Tensor] = []
+    for ((device, _), ([device_grads], _)) in grouped_grads.items():  # type: ignore[assignment]
+        if (foreach is None or foreach) and _has_foreach_support(device_grads, device=device):
+            norms.extend(torch._foreach_norm(device_grads, norm_type))
+        elif foreach:
+            raise RuntimeError(f'foreach=True was passed, but can\'t use the foreach API on {device.type} tensors')
+        else:
+            norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads])
 
-        total_norm = torch.linalg.vector_norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type)
+    total_norm = torch.linalg.vector_norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type)
 
     if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
         raise RuntimeError(
@@ -71,14 +67,14 @@ def clip_grad_norm_(
     # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
     # when the gradients do not reside in CPU memory.
     clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
-    for ((device, _), ([grads], _)) in grouped_grads.items():  # type: ignore[assignment]
-        if (foreach is None or foreach) and _has_foreach_support(grads, device=device):  # type: ignore[arg-type]
-            torch._foreach_mul_(grads, clip_coef_clamped.to(device))  # type: ignore[call-overload]
+    for ((device, _), ([device_grads], _)) in grouped_grads.items():  # type: ignore[assignment]
+        if (foreach is None or foreach) and _has_foreach_support(device_grads, device=device):  # type: ignore[arg-type]
+            torch._foreach_mul_(device_grads, clip_coef_clamped.to(device))  # type: ignore[call-overload]
         elif foreach:
             raise RuntimeError(f'foreach=True was passed, but can\'t use the foreach API on {device.type} tensors')
         else:
             clip_coef_clamped_device = clip_coef_clamped.to(device)
-            for g in grads:
+            for g in device_grads:
                 g.detach().mul_(clip_coef_clamped_device)
 
     return total_norm