Replace individual detaches with overall torch.no_grad decorator (#120638)
Fixes https://github.com/pytorch/pytorch/issues/120611. At first, I thought there were too many detaches, but @awgu and I concluded that both `clip_grad_norm_` and `clip_grad_value_` should run under torch.no_grad, similar to the optimizer step. One option is to keep calling `detach`, but doing that on many tensors is slower than entering a no_grad context (I think?), and Andrew had noticed: "the 1st round of detaches takes 10 ms for FSDP2, whereas existing FSDP's clip_grad_norm_ only takes 3 ms total" since there are more tensors in FSDP2.

This change also disables grad mode for the foreach path of `clip_grad_value_`; the first attempt didn't do this, which was an oversight. Not sure how to add a test case for this since grad mode will be turned back on after the call.

The new profile is not much different from the one at the bottom of this stack, but the number of detaches is 0 :D:

```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (c71bcceb)]$ python playground2.py
STAGE:2024-02-26 13:07:15 211224:211224 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-02-26 13:07:16 211224:211224 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-02-26 13:07:16 211224:211224 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
------------------------------------------------------------------------------------------------------------------------------------------------------
Name  Self CPU %  Self CPU  CPU total %  CPU total  CPU time avg  Self CUDA  Self CUDA %  CUDA total  CUDA time avg  # of Calls
------------------------------------------------------------------------------------------------------------------------------------------------------
cudaLaunchKernel  70.63%  110.415ms  70.63%  110.415ms  5.811ms  0.000us  0.00%  0.000us  0.000us  19
aten::linalg_vector_norm  0.18%  284.000us  26.00%  40.636ms  40.636ms  3.000us  0.99%  3.000us  3.000us  1
aten::clamp  0.09%  148.000us  14.88%  23.261ms  23.261ms  1.000us  0.33%  1.000us  1.000us  1
aten::to  0.75%  1.170ms  14.05%  21.970ms  84.826us  0.000us  0.00%  258.000us  0.996us  259
aten::_to_copy  2.28%  3.562ms  13.31%  20.800ms  161.240us  0.000us  0.00%  258.000us  2.000us  129
aten::_foreach_norm  4.44%  6.935ms  12.72%  19.878ms  9.939ms  19.000us  6.29%  21.000us  10.500us  2
aten::add  0.11%  173.000us  10.97%  17.153ms  17.153ms  1.000us  0.33%  1.000us  1.000us  1
aten::stack  2.99%  4.673ms  9.15%  14.300ms  14.300ms  0.000us  0.00%  6.000us  6.000us  1
aten::copy_  5.49%  8.586ms  8.96%  14.001ms  108.535us  258.000us  85.43%  258.000us  2.000us  129
aten::reciprocal  0.11%  179.000us  8.35%  13.051ms  13.051ms  1.000us  0.33%  1.000us  1.000us  1
aten::cat  0.64%  993.000us  4.42%  6.902ms  6.902ms  6.000us  1.99%  6.000us  6.000us  1
aten::zeros  0.04%  69.000us  4.28%  6.698ms  3.349ms  0.000us  0.00%  2.000us  1.000us  2
aten::zero_  0.04%  66.000us  4.13%  6.462ms  3.231ms  0.000us  0.00%  2.000us  1.000us  2
aten::fill_  0.06%  98.000us  4.09%  6.396ms  3.198ms  2.000us  0.66%  2.000us  1.000us  2
aten::_foreach_mul_  1.50%  2.342ms  3.79%  5.924ms  2.962ms  10.000us  3.31%  10.000us  5.000us  2
aten::empty  3.27%  5.115ms  3.27%  5.115ms  19.826us  0.000us  0.00%  0.000us  0.000us  258
aten::empty_strided  2.07%  3.237ms  2.07%  3.237ms  25.093us  0.000us  0.00%  0.000us  0.000us  129
cudaDeviceEnablePeerAccess  1.93%  3.023ms  1.93%  3.023ms  1.512ms  0.000us  0.00%  0.000us  0.000us  2
aten::unsqueeze  1.21%  1.896ms  1.74%  2.725ms  10.645us  0.000us  0.00%  0.000us  0.000us  256
cudaMemcpyAsync  1.01%  1.572ms  1.01%  1.572ms  12.186us  0.000us  0.00%  0.000us  0.000us  129
aten::as_strided  0.54%  839.000us  0.54%  839.000us  3.265us  0.000us  0.00%  0.000us  0.000us  257
cudaStreamWaitEvent  0.34%  539.000us  0.34%  539.000us  2.089us  0.000us  0.00%  0.000us  0.000us  258
cudaEventRecord  0.18%  274.000us  0.18%  274.000us  1.062us  0.000us  0.00%  0.000us  0.000us  258
aten::mul  0.07%  107.000us  0.08%  132.000us  132.000us  1.000us  0.33%  1.000us  1.000us  1
cudaDeviceSynchronize  0.01%  17.000us  0.01%  17.000us  8.500us  0.000us  0.00%  0.000us  0.000us  2
cudaDeviceCanAccessPeer  0.00%  7.000us  0.00%  7.000us  3.500us  0.000us  0.00%  0.000us  0.000us  2
void at::native::vectorized_elementwise_kernel<4, at...  0.00%  0.000us  0.00%  0.000us  0.000us  2.000us  0.66%  2.000us  1.000us  2
void at::native::(anonymous namespace)::multi_tensor...  0.00%  0.000us  0.00%  0.000us  0.000us  13.000us  4.30%  13.000us  3.250us  4
void at::native::lpnorm_cleanup<float, (at::native::...  0.00%  0.000us  0.00%  0.000us  0.000us  6.000us  1.99%  6.000us  3.000us  2
Memcpy PtoP (Device -> Device)  0.00%  0.000us  0.00%  0.000us  0.000us  258.000us  85.43%  258.000us  2.000us  129
void at::native::(anonymous namespace)::CatArrayBatc...  0.00%  0.000us  0.00%  0.000us  0.000us  6.000us  1.99%  6.000us  3.000us  2
void at::native::reduce_kernel<512, 1, at::native::R...  0.00%  0.000us  0.00%  0.000us  0.000us  3.000us  0.99%  3.000us  3.000us  1
void at::native::vectorized_elementwise_kernel<4, at...  0.00%  0.000us  0.00%  0.000us  0.000us  1.000us  0.33%  1.000us  1.000us  1
void at::native::vectorized_elementwise_kernel<4, at...  0.00%  0.000us  0.00%  0.000us  0.000us  1.000us  0.33%  1.000us  1.000us  1
void at::native::vectorized_elementwise_kernel<4, at...  0.00%  0.000us  0.00%  0.000us  0.000us  1.000us  0.33%  1.000us  1.000us  1
void at::native::vectorized_elementwise_kernel<4, at...  0.00%  0.000us  0.00%  0.000us  0.000us  1.000us  0.33%  1.000us  1.000us  1
void at::native::(anonymous namespace)::multi_tensor...  0.00%  0.000us  0.00%  0.000us  0.000us  10.000us  3.31%  10.000us  2.500us  4
------------------------------------------------------------------------------------------------------------------------------------------------------
Self CPU time total: 156.319ms
Self CUDA time total: 302.000us
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120638
Approved by: https://github.com/Skylion007, https://github.com/albanD
ghstack dependencies: #120623
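For anyone wanting to reproduce a profile like the one above: `playground2.py` is not included in this commit, but a sketch along these lines would exercise `clip_grad_norm_` under the profiler (the parameter count and shapes here are made up, and a single-GPU setup is assumed for simplicity):

```python
import torch
from torch.profiler import profile, ProfilerActivity

# Illustrative stand-in for a model with many parameters (sizes are arbitrary).
params = [torch.randn(1024, 1024, device="cuda", requires_grad=True) for _ in range(128)]
for p in params:
    p.grad = torch.randn_like(p)  # attach fake gradients so there is something to clip

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)

print(prof.key_averages().table(sort_by="self_cpu_time_total"))
```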
This commit is contained in:
parent df72819f91
commit ef9b6d6816

1 changed file with 18 additions and 5 deletions
```diff
@@ -1,4 +1,5 @@
 import warnings
+import functools
 from typing import Union, Iterable, List, Dict, Tuple, Optional, cast
 
 import torch
```
```diff
@@ -9,6 +10,18 @@ _tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]
 
 __all__ = ['clip_grad_norm_', 'clip_grad_norm', 'clip_grad_value_']
 
+def _no_grad(func):
+    """
+    This wrapper is needed to avoid a circular import when using @torch.no_grad on the exposed functions
+    clip_grad_norm_ and clip_grad_value_ themselves.
+    """
+    def _no_grad_wrapper(*args, **kwargs):
+        with torch.no_grad():
+            return func(*args, **kwargs)
+    functools.update_wrapper(_no_grad_wrapper, func)
+    return _no_grad_wrapper
+
+@_no_grad
 def clip_grad_norm_(
         parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0,
         error_if_nonfinite: bool = False, foreach: Optional[bool] = None) -> torch.Tensor:
```
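For context (not part of the diff): `functools.update_wrapper` copies `__name__`, `__doc__`, and other metadata from the wrapped function onto the wrapper, so the decorated `clip_grad_norm_` still introspects like the original. A minimal sketch of the same pattern; `scale_grads` is just an illustrative function, not anything from the PR:

```python
import functools
import torch

def _no_grad(func):
    # Run the wrapped function with autograd disabled, mirroring the decorator above.
    def wrapper(*args, **kwargs):
        with torch.no_grad():
            return func(*args, **kwargs)
    functools.update_wrapper(wrapper, func)  # preserve __name__, __doc__, etc.
    return wrapper

@_no_grad
def scale_grads(grads, factor):
    """Scale a list of gradient tensors in place."""
    for g in grads:
        g.mul_(factor)

print(scale_grads.__name__, "-", scale_grads.__doc__)  # metadata survives the wrapping
```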
```diff
@@ -43,7 +56,7 @@ def clip_grad_norm_(
         return torch.tensor(0.)
     first_device = grads[0].device
     grouped_grads: Dict[Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]] \
-        = _group_tensors_by_device_and_dtype([[g.detach() for g in grads]])  # type: ignore[assignment]
+        = _group_tensors_by_device_and_dtype([grads])  # type: ignore[assignment]
 
     norms: List[Tensor] = []
     for ((device, _), ([device_grads], _)) in grouped_grads.items():  # type: ignore[assignment]
```
```diff
@@ -75,7 +88,7 @@ def clip_grad_norm_(
         else:
             clip_coef_clamped_device = clip_coef_clamped.to(device)
             for g in device_grads:
-                g.detach().mul_(clip_coef_clamped_device)
+                g.mul_(clip_coef_clamped_device)
 
     return total_norm
 
```
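The reason the explicit `detach()` can go away in the two hunks above: with the whole function under `no_grad`, in-place updates on gradients are not recorded by autograd, even if a gradient itself requires grad (as it can after `backward(create_graph=True)`). A small illustration of that behavior, not part of the PR:

```python
import torch

g = torch.ones(3, requires_grad=True)  # stand-in for a grad that itself requires grad

# Outside no_grad, an in-place op on a leaf that requires grad raises a RuntimeError.
with torch.no_grad():
    g.mul_(0.5)  # fine: nothing is tracked here, so no detach() is needed

print(g)  # tensor([0.5000, 0.5000, 0.5000], requires_grad=True)
```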
```diff
@@ -94,6 +107,7 @@ def clip_grad_norm(
     return clip_grad_norm_(parameters, max_norm, norm_type, error_if_nonfinite, foreach)
 
 
+@_no_grad
 def clip_grad_value_(parameters: _tensor_or_tensors, clip_value: float, foreach: Optional[bool] = None) -> None:
     r"""Clip the gradients of an iterable of parameters at specified value.
 
```
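For orientation, gradient clipping is typically called between `backward()` and `optimizer.step()`, which is why running it under `no_grad` (like the step itself) is the natural choice. A usage sketch with an arbitrary toy model, not taken from the PR:

```python
import torch
import torch.nn as nn

model = nn.Linear(16, 4)  # illustrative model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x, y = torch.randn(8, 16), torch.randn(8, 4)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()

# Either norm- or value-based clipping now runs entirely under no_grad.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)

optimizer.step()
optimizer.zero_grad()
```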
```diff
@@ -124,6 +138,5 @@ def clip_grad_value_(parameters: _tensor_or_tensors, clip_value: float, foreach:
         elif foreach:
             raise RuntimeError(f'foreach=True was passed, but can\'t use the foreach API on {device.type} tensors')
         else:
-            with torch.no_grad():
-                for grad in grads:
-                    cast(Tensor, grad).clamp_(min=-clip_value, max=clip_value)
+            for grad in grads:
+                cast(Tensor, grad).clamp_(min=-clip_value, max=clip_value)
```
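Regarding the commit message's note that a test is hard to write because grad mode is turned back on after the call: `torch.no_grad` only disables grad mode for the dynamic extent of the decorated call, as this small check (an illustration, not a test from the PR) shows:

```python
import torch

@torch.no_grad()
def clip_like_op():
    # Inside the decorated call, autograd is off...
    assert not torch.is_grad_enabled()

clip_like_op()
# ...and it is restored as soon as the call returns.
assert torch.is_grad_enabled()
```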