clip_grad_norm can use fast foreach path for inf norm (#120623)

Now that foreach_norm supports inf, we should not special case it.

For a mere 256 parameters, we get a win of 30ms in CPU time and ~800us -> 300us decrease in CUDA time. This win only grows with more parameters.

New profile:
```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (bf1c0490|REBASE-i|detached HEAD)]$ python playground2.py
STAGE:2024-02-26 13:14:10 395517:395517 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-02-26 13:14:11 395517:395517 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-02-26 13:14:11 395517:395517 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                       cudaLaunchKernel        67.01%     102.262ms        67.01%     102.262ms       5.382ms       2.000us         0.66%       2.000us       0.105us            19
                               aten::linalg_vector_norm         0.20%     311.000us        23.44%      35.776ms      35.776ms       3.000us         0.99%       3.000us       3.000us             1
                                               aten::to         0.79%       1.208ms        14.62%      22.311ms      86.143us       0.000us         0.00%     263.000us       1.015us           259
                                            aten::clamp         0.12%     182.000us        13.96%      21.303ms      21.303ms       1.000us         0.33%       1.000us       1.000us             1
                                         aten::_to_copy         2.38%       3.628ms        13.83%      21.103ms     163.589us       0.000us         0.00%     263.000us       2.039us           129
                                    aten::_foreach_norm         4.71%       7.185ms        13.54%      20.659ms      10.329ms      19.000us         6.29%      23.000us      11.500us             2
                                              aten::add         0.14%     211.000us        10.86%      16.580ms      16.580ms       1.000us         0.33%       1.000us       1.000us             1
                                            aten::stack         3.11%       4.744ms         9.59%      14.642ms      14.642ms       0.000us         0.00%       6.000us       6.000us             1
                                            aten::copy_         5.71%       8.721ms         9.27%      14.152ms     109.705us     258.000us        85.43%     263.000us       2.039us           129
                                       aten::reciprocal         0.13%     193.000us         7.93%      12.100ms      12.100ms       1.000us         0.33%       1.000us       1.000us             1
                                              aten::cat         0.67%       1.017ms         4.67%       7.129ms       7.129ms       6.000us         1.99%       6.000us       6.000us             1
                                            aten::zeros         0.05%      79.000us         4.46%       6.800ms       3.400ms       0.000us         0.00%       2.000us       1.000us             2
                                            aten::zero_         0.05%      79.000us         4.28%       6.537ms       3.268ms       0.000us         0.00%       2.000us       1.000us             2
                                            aten::fill_         0.09%     131.000us         4.23%       6.458ms       3.229ms       2.000us         0.66%       2.000us       1.000us             2
                                    aten::_foreach_mul_         1.56%       2.377ms         3.86%       5.896ms       2.948ms      10.000us         3.31%      10.000us       5.000us             2
                                            aten::empty         3.55%       5.414ms         3.55%       5.414ms      20.984us       0.000us         0.00%       0.000us       0.000us           258
                                    aten::empty_strided         2.18%       3.323ms         2.18%       3.323ms      25.760us       0.000us         0.00%       0.000us       0.000us           129
                                           aten::detach         0.85%       1.302ms         2.10%       3.199ms      12.496us       0.000us         0.00%       0.000us       0.000us           256
                             cudaDeviceEnablePeerAccess         2.01%       3.069ms         2.01%       3.069ms       1.534ms       0.000us         0.00%       0.000us       0.000us             2
                                        aten::unsqueeze         1.24%       1.899ms         1.81%       2.769ms      10.816us       0.000us         0.00%       0.000us       0.000us           256
                                                 detach         1.24%       1.897ms         1.24%       1.897ms       7.410us       0.000us         0.00%       0.000us       0.000us           256
                                        cudaMemcpyAsync         1.01%       1.539ms         1.01%       1.539ms      11.930us       0.000us         0.00%       0.000us       0.000us           129
                                       aten::as_strided         0.58%     881.000us         0.58%     881.000us       3.428us       0.000us         0.00%       0.000us       0.000us           257
                                    cudaStreamWaitEvent         0.35%     540.000us         0.35%     540.000us       2.093us       0.000us         0.00%       0.000us       0.000us           258
                                        cudaEventRecord         0.18%     278.000us         0.18%     278.000us       1.078us       5.000us         1.66%       5.000us       0.019us           258
                                              aten::mul         0.08%     125.000us         0.09%     138.000us     138.000us       1.000us         0.33%       1.000us       1.000us             1
                                  cudaDeviceSynchronize         0.01%      13.000us         0.01%      13.000us       6.500us       0.000us         0.00%       0.000us       0.000us             2
                                cudaDeviceCanAccessPeer         0.00%       5.000us         0.00%       5.000us       2.500us       0.000us         0.00%       0.000us       0.000us             2
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       2.000us         0.66%       2.000us       1.000us             2
void at::native::(anonymous namespace)::multi_tensor...         0.00%       0.000us         0.00%       0.000us       0.000us      13.000us         4.30%      13.000us       3.250us             4
void at::native::lpnorm_cleanup<float, (at::native::...         0.00%       0.000us         0.00%       0.000us       0.000us       6.000us         1.99%       6.000us       3.000us             2
                         Memcpy PtoP (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us     258.000us        85.43%     258.000us       2.000us           129
void at::native::(anonymous namespace)::CatArrayBatc...         0.00%       0.000us         0.00%       0.000us       0.000us       6.000us         1.99%       6.000us       3.000us             2
void at::native::reduce_kernel<512, 1, at::native::R...         0.00%       0.000us         0.00%       0.000us       0.000us       3.000us         0.99%       3.000us       3.000us             1
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.000us         0.33%       1.000us       1.000us             1
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.000us         0.33%       1.000us       1.000us             1
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.000us         0.33%       1.000us       1.000us             1
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.000us         0.33%       1.000us       1.000us             1
void at::native::(anonymous namespace)::multi_tensor...         0.00%       0.000us         0.00%       0.000us       0.000us      10.000us         3.31%      10.000us       2.500us             4
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 152.613ms
Self CUDA time total: 302.000us
```

Compared to on main:
```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (5a0a9644)]$ python playground2.py
STAGE:2024-02-26 13:09:56 285045:285045 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-02-26 13:09:57 285045:285045 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-02-26 13:09:57 285045:285045 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                       cudaLaunchKernel        61.42%     113.375ms        61.42%     113.375ms     424.625us      45.000us         5.66%      45.000us       0.169us           267
                               aten::linalg_vector_norm        14.04%      25.909ms        37.67%      69.534ms     271.617us     514.000us        64.65%     559.000us       2.184us           256
                                               aten::to         0.78%       1.433ms        12.87%      23.751ms      91.703us       0.000us         0.00%     278.000us       1.073us           259
                                         aten::_to_copy         2.02%       3.730ms        12.09%      22.318ms     173.008us       0.000us         0.00%     278.000us       2.155us           129
                                            aten::clamp         0.09%     174.000us        11.43%      21.103ms      21.103ms       1.000us         0.13%       1.000us       1.000us             1
                                              aten::add         0.11%     205.000us         9.08%      16.768ms      16.768ms       1.000us         0.13%       1.000us       1.000us             1
                                            aten::copy_         4.94%       9.112ms         8.15%      15.043ms     116.612us     258.000us        32.45%     278.000us       2.155us           129
                                            aten::stack         2.76%       5.091ms         7.97%      14.719ms      14.719ms       0.000us         0.00%       6.000us       6.000us             1
                                       aten::reciprocal         0.11%     194.000us         7.01%      12.933ms      12.933ms       1.000us         0.13%       1.000us       1.000us             1
                                              aten::max         0.09%     165.000us         6.43%      11.868ms      11.868ms       3.000us         0.38%       3.000us       3.000us             1
                                           aten::detach         1.58%       2.911ms         4.12%       7.596ms      14.836us       0.000us         0.00%       0.000us       0.000us           512
                                              aten::cat         0.56%       1.042ms         3.73%       6.882ms       6.882ms       6.000us         0.75%       6.000us       6.000us             1
                                    aten::_foreach_mul_         1.36%       2.503ms         3.33%       6.145ms       3.072ms      10.000us         1.26%      10.000us       5.000us             2
                                                 detach         2.54%       4.685ms         2.54%       4.685ms       9.150us       0.000us         0.00%       0.000us       0.000us           512
                                    aten::empty_strided         1.92%       3.545ms         1.92%       3.545ms      27.481us       0.000us         0.00%       0.000us       0.000us           129
                             cudaDeviceEnablePeerAccess         1.64%       3.022ms         1.64%       3.022ms       1.511ms       0.000us         0.00%       0.000us       0.000us             2
                                        aten::unsqueeze         1.03%       1.892ms         1.49%       2.746ms      10.727us       0.000us         0.00%       0.000us       0.000us           256
                                       aten::as_strided         1.35%       2.494ms         1.35%       2.494ms       4.862us       0.000us         0.00%       0.000us       0.000us           513
                                        cudaMemcpyAsync         1.01%       1.868ms         1.01%       1.868ms      14.481us       4.000us         0.50%       4.000us       0.031us           129
                                    cudaStreamWaitEvent         0.41%     760.000us         0.41%     760.000us       2.946us       8.000us         1.01%       8.000us       0.031us           258
                                        cudaEventRecord         0.15%     276.000us         0.15%     276.000us       1.070us       8.000us         1.01%       8.000us       0.031us           258
                                              aten::mul         0.08%     139.000us         0.08%     153.000us     153.000us       1.000us         0.13%       1.000us       1.000us             1
                                            aten::empty         0.02%      35.000us         0.02%      35.000us      35.000us       0.000us         0.00%       0.000us       0.000us             1
                                  cudaDeviceSynchronize         0.01%      14.000us         0.01%      14.000us       7.000us       0.000us         0.00%       0.000us       0.000us             2
                                cudaDeviceCanAccessPeer         0.00%       5.000us         0.00%       5.000us       2.500us       0.000us         0.00%       0.000us       0.000us             2
void at::native::reduce_kernel<512, 1, at::native::R...         0.00%       0.000us         0.00%       0.000us       0.000us     514.000us        64.65%     514.000us       2.008us           256
                         Memcpy PtoP (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us     258.000us        32.45%     258.000us       2.000us           129
void at::native::(anonymous namespace)::CatArrayBatc...         0.00%       0.000us         0.00%       0.000us       0.000us       6.000us         0.75%       6.000us       3.000us             2
void at::native::reduce_kernel<512, 1, at::native::R...         0.00%       0.000us         0.00%       0.000us       0.000us       3.000us         0.38%       3.000us       3.000us             1
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.000us         0.13%       1.000us       1.000us             1
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.000us         0.13%       1.000us       1.000us             1
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.000us         0.13%       1.000us       1.000us             1
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.000us         0.13%       1.000us       1.000us             1
void at::native::(anonymous namespace)::multi_tensor...         0.00%       0.000us         0.00%       0.000us       0.000us      10.000us         1.26%      10.000us       2.500us             4
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 184.579ms
Self CUDA time total: 795.000us
```

For script:
```
import torch
from math import inf
from torch.nn.utils import clip_grad_norm_

params = [torch.rand(32, 16, device="cuda:3")*5 for _ in range(128)] + [torch.rand(32, 16, device="cuda:4")*-7 for _ in range(128)]
for p in params:
    p.grad = torch.rand_like(p)

with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ]
) as p:
    total_norm = clip_grad_norm_(params, 10.0, norm_type=inf)
    torch.cuda.synchronize()

print(p.key_averages().table(sort_by="cpu_time_total"))
print(total_norm)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120623
Approved by: https://github.com/Skylion007, https://github.com/mikaylagawarecki
This commit is contained in:
Jane Xu 2024-02-26 11:59:59 -08:00 committed by PyTorch MergeBot
parent b01bd1f7a1
commit df72819f91

View file

@@ -2,7 +2,7 @@ import warnings
from typing import Union, Iterable, List, Dict, Tuple, Optional, cast
import torch
from torch import Tensor, inf
from torch import Tensor
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype, _has_foreach_support
_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]
@@ -42,23 +42,19 @@ def clip_grad_norm_(
if len(grads) == 0:
return torch.tensor(0.)
first_device = grads[0].device
grouped_grads: Dict[Tuple[torch.device, torch.dtype], List[List[Tensor]]] \
grouped_grads: Dict[Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]] \
= _group_tensors_by_device_and_dtype([[g.detach() for g in grads]]) # type: ignore[assignment]
if norm_type == inf:
norms = [torch.linalg.vector_norm(g.detach(), inf).to(first_device) for g in grads]
total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
else:
norms = []
for ((device, _), ([grads], _)) in grouped_grads.items(): # type: ignore[assignment]
if (foreach is None or foreach) and _has_foreach_support(grads, device=device):
norms.extend(torch._foreach_norm(grads, norm_type))
elif foreach:
raise RuntimeError(f'foreach=True was passed, but can\'t use the foreach API on {device.type} tensors')
else:
norms.extend([torch.linalg.vector_norm(g, norm_type) for g in grads])
norms: List[Tensor] = []
for ((device, _), ([device_grads], _)) in grouped_grads.items(): # type: ignore[assignment]
if (foreach is None or foreach) and _has_foreach_support(device_grads, device=device):
norms.extend(torch._foreach_norm(device_grads, norm_type))
elif foreach:
raise RuntimeError(f'foreach=True was passed, but can\'t use the foreach API on {device.type} tensors')
else:
norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads])
total_norm = torch.linalg.vector_norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type)
total_norm = torch.linalg.vector_norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type)
if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
raise RuntimeError(
@@ -71,14 +67,14 @@ def clip_grad_norm_(
# avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
# when the gradients do not reside in CPU memory.
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
for ((device, _), ([grads], _)) in grouped_grads.items(): # type: ignore[assignment]
if (foreach is None or foreach) and _has_foreach_support(grads, device=device): # type: ignore[arg-type]
torch._foreach_mul_(grads, clip_coef_clamped.to(device)) # type: ignore[call-overload]
for ((device, _), ([device_grads], _)) in grouped_grads.items(): # type: ignore[assignment]
if (foreach is None or foreach) and _has_foreach_support(device_grads, device=device): # type: ignore[arg-type]
torch._foreach_mul_(device_grads, clip_coef_clamped.to(device)) # type: ignore[call-overload]
elif foreach:
raise RuntimeError(f'foreach=True was passed, but can\'t use the foreach API on {device.type} tensors')
else:
clip_coef_clamped_device = clip_coef_clamped.to(device)
for g in grads:
for g in device_grads:
g.detach().mul_(clip_coef_clamped_device)
return total_norm