Replace individual detaches with overall torch.no_grad decorator (#120638)

Fixes https://github.com/pytorch/pytorch/issues/120611.

At first, I thought there were too many detaches, but @awgu and I concluded that both `clip_grad_norm_` and `clip_grad_value_` should run under torch.no_grad, similar to the optimizer step. One option would be to keep calling `detach`, but doing that on many tensors appears to be slower than entering a single no_grad context, and Andrew had noticed that "the 1st round of detaches takes 10 ms for FSDP2, whereas existing FSDP's clip_grad_norm_ only takes 3 ms total" since there are more tensors in FSDP2.
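To make that tradeoff concrete, here is a minimal sketch (not from the PR; tensor count and sizes are made up) comparing a per-tensor `detach` pass against using the gradients directly inside a single `torch.no_grad()` context:

```python
# Hypothetical micro-benchmark: per-tensor detach vs. one no_grad context.
# Counts/sizes are illustrative, not the FSDP2 workload from the issue.
import time
import torch

params = [torch.nn.Parameter(torch.randn(256)) for _ in range(10_000)]
for p in params:
    p.grad = torch.randn_like(p)

start = time.perf_counter()
detached = [p.grad.detach() for p in params]  # creates a new view per tensor
detach_ms = (time.perf_counter() - start) * 1e3

start = time.perf_counter()
with torch.no_grad():  # one context switch; grads are used as-is
    grads = [p.grad for p in params]
no_grad_ms = (time.perf_counter() - start) * 1e3

print(f"per-tensor detach: {detach_ms:.2f} ms, single no_grad context: {no_grad_ms:.2f} ms")
```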

This change also disables grad mode for the foreach path of `clip_grad_value_`; the first attempt missed that path, which was an oversight. I'm not sure how to add a test case for this, since grad mode is turned back on after the call.
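One external sanity check (a hypothetical snippet, not a test added in this PR) is to verify that grad mode is restored after the call and that the clipping still happened in place:

```python
# Hypothetical sanity check: clip_grad_value_ runs under no_grad internally,
# but grad mode is back on as soon as it returns.
import torch
from torch.nn.utils import clip_grad_value_

p = torch.nn.Parameter(torch.randn(16))
p.grad = torch.randn(16) * 5

assert torch.is_grad_enabled()
clip_grad_value_([p], clip_value=1.0)
assert torch.is_grad_enabled()               # grad mode restored after the call
assert p.grad.abs().max().item() <= 1.0      # gradients were clipped in place
```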

The new profile is not much different from the one at the bottom of this stack, but the number of detaches is now 0 :D
```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (c71bcceb)]$ python playground2.py
STAGE:2024-02-26 13:07:15 211224:211224 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-02-26 13:07:16 211224:211224 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-02-26 13:07:16 211224:211224 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                       cudaLaunchKernel        70.63%     110.415ms        70.63%     110.415ms       5.811ms       0.000us         0.00%       0.000us       0.000us            19
                               aten::linalg_vector_norm         0.18%     284.000us        26.00%      40.636ms      40.636ms       3.000us         0.99%       3.000us       3.000us             1
                                            aten::clamp         0.09%     148.000us        14.88%      23.261ms      23.261ms       1.000us         0.33%       1.000us       1.000us             1
                                               aten::to         0.75%       1.170ms        14.05%      21.970ms      84.826us       0.000us         0.00%     258.000us       0.996us           259
                                         aten::_to_copy         2.28%       3.562ms        13.31%      20.800ms     161.240us       0.000us         0.00%     258.000us       2.000us           129
                                    aten::_foreach_norm         4.44%       6.935ms        12.72%      19.878ms       9.939ms      19.000us         6.29%      21.000us      10.500us             2
                                              aten::add         0.11%     173.000us        10.97%      17.153ms      17.153ms       1.000us         0.33%       1.000us       1.000us             1
                                            aten::stack         2.99%       4.673ms         9.15%      14.300ms      14.300ms       0.000us         0.00%       6.000us       6.000us             1
                                            aten::copy_         5.49%       8.586ms         8.96%      14.001ms     108.535us     258.000us        85.43%     258.000us       2.000us           129
                                       aten::reciprocal         0.11%     179.000us         8.35%      13.051ms      13.051ms       1.000us         0.33%       1.000us       1.000us             1
                                              aten::cat         0.64%     993.000us         4.42%       6.902ms       6.902ms       6.000us         1.99%       6.000us       6.000us             1
                                            aten::zeros         0.04%      69.000us         4.28%       6.698ms       3.349ms       0.000us         0.00%       2.000us       1.000us             2
                                            aten::zero_         0.04%      66.000us         4.13%       6.462ms       3.231ms       0.000us         0.00%       2.000us       1.000us             2
                                            aten::fill_         0.06%      98.000us         4.09%       6.396ms       3.198ms       2.000us         0.66%       2.000us       1.000us             2
                                    aten::_foreach_mul_         1.50%       2.342ms         3.79%       5.924ms       2.962ms      10.000us         3.31%      10.000us       5.000us             2
                                            aten::empty         3.27%       5.115ms         3.27%       5.115ms      19.826us       0.000us         0.00%       0.000us       0.000us           258
                                    aten::empty_strided         2.07%       3.237ms         2.07%       3.237ms      25.093us       0.000us         0.00%       0.000us       0.000us           129
                             cudaDeviceEnablePeerAccess         1.93%       3.023ms         1.93%       3.023ms       1.512ms       0.000us         0.00%       0.000us       0.000us             2
                                        aten::unsqueeze         1.21%       1.896ms         1.74%       2.725ms      10.645us       0.000us         0.00%       0.000us       0.000us           256
                                        cudaMemcpyAsync         1.01%       1.572ms         1.01%       1.572ms      12.186us       0.000us         0.00%       0.000us       0.000us           129
                                       aten::as_strided         0.54%     839.000us         0.54%     839.000us       3.265us       0.000us         0.00%       0.000us       0.000us           257
                                    cudaStreamWaitEvent         0.34%     539.000us         0.34%     539.000us       2.089us       0.000us         0.00%       0.000us       0.000us           258
                                        cudaEventRecord         0.18%     274.000us         0.18%     274.000us       1.062us       0.000us         0.00%       0.000us       0.000us           258
                                              aten::mul         0.07%     107.000us         0.08%     132.000us     132.000us       1.000us         0.33%       1.000us       1.000us             1
                                  cudaDeviceSynchronize         0.01%      17.000us         0.01%      17.000us       8.500us       0.000us         0.00%       0.000us       0.000us             2
                                cudaDeviceCanAccessPeer         0.00%       7.000us         0.00%       7.000us       3.500us       0.000us         0.00%       0.000us       0.000us             2
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       2.000us         0.66%       2.000us       1.000us             2
void at::native::(anonymous namespace)::multi_tensor...         0.00%       0.000us         0.00%       0.000us       0.000us      13.000us         4.30%      13.000us       3.250us             4
void at::native::lpnorm_cleanup<float, (at::native::...         0.00%       0.000us         0.00%       0.000us       0.000us       6.000us         1.99%       6.000us       3.000us             2
                         Memcpy PtoP (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us     258.000us        85.43%     258.000us       2.000us           129
void at::native::(anonymous namespace)::CatArrayBatc...         0.00%       0.000us         0.00%       0.000us       0.000us       6.000us         1.99%       6.000us       3.000us             2
void at::native::reduce_kernel<512, 1, at::native::R...         0.00%       0.000us         0.00%       0.000us       0.000us       3.000us         0.99%       3.000us       3.000us             1
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.000us         0.33%       1.000us       1.000us             1
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.000us         0.33%       1.000us       1.000us             1
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.000us         0.33%       1.000us       1.000us             1
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.000us         0.33%       1.000us       1.000us             1
void at::native::(anonymous namespace)::multi_tensor...         0.00%       0.000us         0.00%       0.000us       0.000us      10.000us         3.31%      10.000us       2.500us             4
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 156.319ms
Self CUDA time total: 302.000us
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120638
Approved by: https://github.com/Skylion007, https://github.com/albanD
ghstack dependencies: #120623
Author: Jane Xu
Date: 2024-02-26 12:28:20 -08:00 (committed by PyTorch MergeBot)
Parent: df72819f91
Commit: ef9b6d6816

```diff
@@ -1,4 +1,5 @@
 import warnings
+import functools
 from typing import Union, Iterable, List, Dict, Tuple, Optional, cast

 import torch
@@ -9,6 +10,18 @@ _tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]

 __all__ = ['clip_grad_norm_', 'clip_grad_norm', 'clip_grad_value_']

+def _no_grad(func):
+    """
+    This wrapper is needed to avoid a circular import when using @torch.no_grad on the exposed functions
+    clip_grad_norm_ and clip_grad_value_ themselves.
+    """
+    def _no_grad_wrapper(*args, **kwargs):
+        with torch.no_grad():
+            return func(*args, **kwargs)
+    functools.update_wrapper(_no_grad_wrapper, func)
+    return _no_grad_wrapper
+
+@_no_grad
 def clip_grad_norm_(
         parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0,
         error_if_nonfinite: bool = False, foreach: Optional[bool] = None) -> torch.Tensor:
@@ -43,7 +56,7 @@ def clip_grad_norm_(
         return torch.tensor(0.)
     first_device = grads[0].device
     grouped_grads: Dict[Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]] \
-        = _group_tensors_by_device_and_dtype([[g.detach() for g in grads]])  # type: ignore[assignment]
+        = _group_tensors_by_device_and_dtype([grads])  # type: ignore[assignment]

     norms: List[Tensor] = []
     for ((device, _), ([device_grads], _)) in grouped_grads.items():  # type: ignore[assignment]
@@ -75,7 +88,7 @@ def clip_grad_norm_(
         else:
             clip_coef_clamped_device = clip_coef_clamped.to(device)
             for g in device_grads:
-                g.detach().mul_(clip_coef_clamped_device)
+                g.mul_(clip_coef_clamped_device)

     return total_norm
@@ -94,6 +107,7 @@ def clip_grad_norm(
     return clip_grad_norm_(parameters, max_norm, norm_type, error_if_nonfinite, foreach)


+@_no_grad
 def clip_grad_value_(parameters: _tensor_or_tensors, clip_value: float, foreach: Optional[bool] = None) -> None:
     r"""Clip the gradients of an iterable of parameters at specified value.
@@ -124,6 +138,5 @@ def clip_grad_value_(parameters: _tensor_or_tensors, clip_value: float, foreach:
         elif foreach:
             raise RuntimeError(f'foreach=True was passed, but can\'t use the foreach API on {device.type} tensors')
         else:
-            with torch.no_grad():
-                for grad in grads:
-                    cast(Tensor, grad).clamp_(min=-clip_value, max=clip_value)
+            for grad in grads:
+                cast(Tensor, grad).clamp_(min=-clip_value, max=clip_value)
```
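For reference, a small usage sketch (illustrative, not part of the diff) of how the decorated function behaves when the gradients themselves require grad, e.g. after `backward(create_graph=True)`:

```python
# Illustrative usage of the patched clip_grad_norm_: clipping happens entirely
# under no_grad, so no per-tensor detach() is needed even when grads require grad.
import torch
from torch.nn.utils import clip_grad_norm_

w = torch.nn.Parameter(torch.randn(8))
loss = (w ** 2).sum()
loss.backward(create_graph=True)   # w.grad itself requires grad here

total_norm = clip_grad_norm_([w], max_norm=1.0)
print(total_norm)                  # clipping ops were not recorded in any graph
```

Because `_no_grad` uses `functools.update_wrapper`, the exposed functions keep their original `__name__` and docstrings, so documentation and introspection are unaffected.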