diff --git a/test/test_optim.py b/test/test_optim.py index ec8ba03a065..f44a7e87777 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -665,6 +665,17 @@ class TestOptimRenewed(TestCase): self.assertTrue(a1_grad_imags.all_popped()) self.assertTrue(losses.all_popped()) + def test_adamw_serialization(self, device): + model = torch.nn.Linear(5, 5).to(device) + optim = torch.optim.AdamW(model.parameters()) + + loaded_dict = optim.state_dict() + + new_optim = torch.optim.Adam(model.parameters()) + new_optim.load_state_dict(loaded_dict) + + self.assertTrue(new_optim.param_groups[0]["decoupled_weight_decay"]) + def _compare_between( self, inputs, models, optimizers, assert_eq_kwargs=None, assert_step_dtype=None ): @@ -2150,6 +2161,8 @@ class TestOptimRenewed(TestCase): def test_defaults_changed_to_foreach(self, device, dtype, optim_info): # Test that the default implementations for optimizers are changed to foreach # except Adafactor, which defaults to the single tensor impl for memory efficiency. + from torch.optim import Adam, AdamW + optim_cls = optim_info.optim_cls model = torch.nn.Linear(5, 5) model.to(dtype=dtype, device=device) @@ -2157,7 +2170,13 @@ class TestOptimRenewed(TestCase): import inspect - module = inspect.getmodule(optim_cls) + # AdamW dispatches to superclass' adam + if optim_cls is AdamW: + module = inspect.getmodule(Adam) + module_name = "_multi_tensor_adam" + else: + module = inspect.getmodule(optim_cls) + module_name = f"_multi_tensor_{optim_cls.__name__.lower()}" for optim_input in optim_info.optim_inputs_func(device=device): optim = optim_cls(model.parameters(), **optim_input.kwargs) @@ -2165,9 +2184,7 @@ class TestOptimRenewed(TestCase): output = model(inpt) loss = output.sum() loss.backward() - with patch.object( - module, f"_multi_tensor_{optim_cls.__name__.lower()}" - ) as mocked_foreach_impl: + with patch.object(module, module_name) as mocked_foreach_impl: optim.step() self.assertTrue(mocked_foreach_impl.called) diff --git a/torch/optim/adam.py b/torch/optim/adam.py index d621a19c5b9..e3a628cb576 100644 --- a/torch/optim/adam.py +++ b/torch/optim/adam.py @@ -46,6 +46,7 @@ class Adam(Optimizer): capturable: bool = False, differentiable: bool = False, fused: Optional[bool] = None, + decoupled_weight_decay: bool = False, ): if isinstance(lr, Tensor): if foreach and not capturable: @@ -95,6 +96,7 @@ class Adam(Optimizer): capturable=capturable, differentiable=differentiable, fused=fused, + decoupled_weight_decay=decoupled_weight_decay, ) super().__init__(params, defaults) @@ -117,6 +119,7 @@ class Adam(Optimizer): group.setdefault("foreach", None) group.setdefault("capturable", False) group.setdefault("differentiable", False) + group.setdefault("decoupled_weight_decay", False) fused = group.setdefault("fused", None) for p in group["params"]: p_state = self.state.get(p, []) @@ -262,6 +265,7 @@ class Adam(Optimizer): fused=group["fused"], grad_scale=getattr(self, "grad_scale", None), found_inf=getattr(self, "found_inf", None), + decoupled_weight_decay=group["decoupled_weight_decay"], ) return loss @@ -355,6 +359,7 @@ def _single_tensor_adam( maximize: bool, capturable: bool, differentiable: bool, + decoupled_weight_decay: bool, ): assert grad_scale is None and found_inf is None @@ -393,7 +398,11 @@ def _single_tensor_adam( step_t += 1 if weight_decay != 0: - grad = grad.add(param, alpha=weight_decay) + if decoupled_weight_decay: + # Perform stepweight decay + param.mul_(1 - lr * weight_decay) + else: + grad = grad.add(param, alpha=weight_decay) if torch.is_complex(param): grad = torch.view_as_real(grad) @@ -500,6 +509,7 @@ def _multi_tensor_adam( maximize: bool, capturable: bool, differentiable: bool, + decoupled_weight_decay: bool, ): if len(params) == 0: return @@ -603,13 +613,17 @@ def _multi_tensor_adam( torch._foreach_add_(device_state_steps, 1) if weight_decay != 0: - # Re-use the intermediate memory (device_grads) already allocated for maximize - if maximize: - torch._foreach_add_(device_grads, device_params, alpha=weight_decay) + if decoupled_weight_decay: + # Perform stepweight decay + torch._foreach_mul_(device_params, 1 - lr * weight_decay) else: - device_grads = torch._foreach_add( # type: ignore[assignment] - device_grads, device_params, alpha=weight_decay - ) + # Re-use the intermediate memory (device_grads) already allocated for maximize + if maximize: + torch._foreach_add_(device_grads, device_params, alpha=weight_decay) + else: + device_grads = torch._foreach_add( # type: ignore[assignment] + device_grads, device_params, alpha=weight_decay + ) # Decay the first and second moment running average coefficient # Use device beta1 if beta1 is a tensor to ensure all @@ -727,6 +741,7 @@ def _fused_adam( maximize: bool, capturable: bool, # Needed for consistency. differentiable: bool, + decoupled_weight_decay: bool, ) -> None: if not params: return @@ -781,7 +796,8 @@ def _fused_adam( lr_dict[device] = lr.to(device=device, non_blocking=True) # type: ignore[union-attr] lr = lr_dict[device] torch._foreach_add_(device_state_steps, 1) - torch._fused_adam_( + func = torch._fused_adam_ if not decoupled_weight_decay else torch._fused_adamw_ + func( device_params, device_grads, device_exp_avgs, @@ -821,6 +837,7 @@ def adam( grad_scale: Optional[Tensor] = None, found_inf: Optional[Tensor] = None, has_complex: bool = False, + decoupled_weight_decay: bool = False, *, amsgrad: bool, beta1: float, @@ -890,4 +907,5 @@ def adam( differentiable=differentiable, grad_scale=grad_scale, found_inf=found_inf, + decoupled_weight_decay=decoupled_weight_decay, ) diff --git a/torch/optim/adamw.py b/torch/optim/adamw.py index 88118b30957..01453ca0a3c 100644 --- a/torch/optim/adamw.py +++ b/torch/optim/adamw.py @@ -1,29 +1,17 @@ # mypy: allow-untyped-decorators # mypy: allow-untyped-defs -from typing import cast, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union -import torch from torch import Tensor +from .adam import Adam, adam from .optimizer import ( _capturable_doc, - _default_to_fused_or_foreach, - _device_dtype_check_for_fused, _differentiable_doc, - _disable_dynamo_if_unsupported, _foreach_doc, _fused_doc, - _get_capturable_supported_devices, - _get_scalar_dtype, - _get_value, _maximize_doc, _params_doc, - _stack_if_compiling, - _use_grad_for_differentiable, - _view_as_real, - DeviceDict, - DeviceDtypeDict, - Optimizer, ParamsT, ) @@ -31,7 +19,7 @@ from .optimizer import ( __all__ = ["AdamW", "adamw"] -class AdamW(Optimizer): +class AdamW(Adam): def __init__( self, params: ParamsT, @@ -47,223 +35,20 @@ class AdamW(Optimizer): differentiable: bool = False, fused: Optional[bool] = None, ): - if isinstance(lr, Tensor): - if foreach and not capturable: - raise ValueError( - "lr as a Tensor is not supported for capturable=False and foreach=True" - ) - if lr.numel() != 1: - raise ValueError("Tensor lr must be 1-element") - - if not 0.0 <= lr: - raise ValueError(f"Invalid learning rate: {lr}") - if not 0.0 <= eps: - raise ValueError(f"Invalid epsilon value: {eps}") - if not 0.0 <= betas[0] < 1.0: - raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") - if not 0.0 <= betas[1] < 1.0: - raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") - if not 0.0 <= weight_decay: - raise ValueError(f"Invalid weight_decay value: {weight_decay}") - if not ( - (isinstance(betas[0], float) and isinstance(betas[1], float)) - or (isinstance(betas[0], Tensor) and isinstance(betas[1], Tensor)) - ): - raise ValueError("betas must be either both floats or both Tensors") - if isinstance(betas[0], Tensor): - if not capturable and foreach: - raise ValueError( - "betas[0] as a Tensor is not supported for capturable=False and foreach=True" - ) - if betas[0].numel() != 1: - raise ValueError("Tensor betas[0] must be 1-element") - if isinstance(betas[1], Tensor): - if not capturable and foreach: - raise ValueError( - "betas[1] as a Tensor is not supported for capturable=False and foreach=True" - ) - if betas[1].numel() != 1: - raise ValueError("Tensor betas[1] must be 1-element") - - defaults = dict( - lr=lr, - betas=betas, - eps=eps, - weight_decay=weight_decay, - amsgrad=amsgrad, + super().__init__( + params, + lr, + betas, + eps, + weight_decay, + amsgrad, foreach=foreach, maximize=maximize, capturable=capturable, differentiable=differentiable, fused=fused, + decoupled_weight_decay=True, ) - super().__init__(params, defaults) - - if fused: - if differentiable: - raise RuntimeError("`fused` does not support `differentiable`") - self._step_supports_amp_scaling = True - if foreach: - raise RuntimeError("`fused` and `foreach` cannot be `True` together.") - - def __setstate__(self, state): - super().__setstate__(state) - for group in self.param_groups: - group.setdefault("amsgrad", False) - group.setdefault("maximize", False) - group.setdefault("foreach", None) - group.setdefault("capturable", False) - group.setdefault("differentiable", False) - fused = group.setdefault("fused", None) - for p in group["params"]: - p_state = self.state.get(p, []) - if len(p_state) != 0 and not torch.is_tensor(p_state["step"]): - step_val = float(p_state["step"]) - p_state["step"] = ( - torch.tensor( - step_val, - dtype=_get_scalar_dtype(is_fused=fused), - device=p.device, - ) - if group["capturable"] or group["fused"] - else torch.tensor(step_val, dtype=_get_scalar_dtype()) - ) - - def _init_group( - self, - group, - params_with_grad, - grads, - amsgrad, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - ): - has_complex = False - for p in group["params"]: - if p.grad is None: - continue - has_complex |= torch.is_complex(p) - params_with_grad.append(p) - if p.grad.is_sparse: - raise RuntimeError("AdamW does not support sparse gradients") - grads.append(p.grad) - - state = self.state[p] - - # State initialization - if len(state) == 0: - if group["fused"]: - _device_dtype_check_for_fused(p) - # note(crcrpar): Deliberately host `step` on CPU if both capturable and fused are off. - # This is because kernel launches are costly on CUDA and XLA. - state["step"] = ( - torch.zeros( - (), - dtype=_get_scalar_dtype(is_fused=group["fused"]), - device=p.device, - ) - if group["capturable"] or group["fused"] - else torch.tensor(0.0, dtype=_get_scalar_dtype()) - ) - # Exponential moving average of gradient values - state["exp_avg"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) - # Exponential moving average of squared gradient values - state["exp_avg_sq"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) - if amsgrad: - # Maintains max of all exp. moving avg. of sq. grad. values - state["max_exp_avg_sq"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) - - exp_avgs.append(state["exp_avg"]) - exp_avg_sqs.append(state["exp_avg_sq"]) - - if group["amsgrad"]: - max_exp_avg_sqs.append(state["max_exp_avg_sq"]) - if group["differentiable"] and state["step"].requires_grad: - raise RuntimeError( - "`requires_grad` is not supported for `step` in differentiable mode" - ) - - # Foreach without capturable does not support a tensor lr - if ( - group["foreach"] - and isinstance(group["lr"], Tensor) - and not group["capturable"] - ): - raise RuntimeError( - "lr as a Tensor is not supported for capturable=False and foreach=True" - ) - - state_steps.append(state["step"]) - return has_complex - - @_use_grad_for_differentiable - def step(self, closure=None): - """Perform a single optimization step. - - Args: - closure (Callable, optional): A closure that reevaluates the model - and returns the loss. - """ - self._cuda_graph_capture_health_check() - - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - params_with_grad: List[Tensor] = [] - grads: List[Tensor] = [] - exp_avgs: List[Tensor] = [] - exp_avg_sqs: List[Tensor] = [] - max_exp_avg_sqs: List[Tensor] = [] - state_steps: List[Tensor] = [] - amsgrad: bool = group["amsgrad"] - beta1, beta2 = cast(Tuple[float, float], group["betas"]) - - has_complex = self._init_group( - group, - params_with_grad, - grads, - amsgrad, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - ) - - adamw( - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - amsgrad=amsgrad, - beta1=beta1, - beta2=beta2, - lr=group["lr"], - weight_decay=group["weight_decay"], - eps=group["eps"], - maximize=group["maximize"], - foreach=group["foreach"], - capturable=group["capturable"], - differentiable=group["differentiable"], - fused=group["fused"], - grad_scale=getattr(self, "grad_scale", None), - found_inf=getattr(self, "found_inf", None), - has_complex=has_complex, - ) - - return loss AdamW.__doc__ = ( @@ -334,478 +119,7 @@ AdamW.__doc__ = ( ) -def _single_tensor_adamw( - params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - max_exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - grad_scale: Optional[Tensor], - found_inf: Optional[Tensor], - *, - amsgrad: bool, - beta1: Union[Tensor, float], - beta2: Union[Tensor, float], - lr: Union[Tensor, float], - weight_decay: float, - eps: float, - maximize: bool, - capturable: bool, - differentiable: bool, - has_complex: bool, -): - assert grad_scale is None and found_inf is None - - if torch.jit.is_scripting(): - # this assert is due to JIT being dumb and not realizing that the ops below - # have overloads to handle both float and Tensor lrs, so we just assert it's - # a float since most people using JIT are using floats - assert isinstance(lr, float) - assert isinstance(beta1, float) - assert isinstance(beta2, float) - - # We only shuffle around the beta when it is a Tensor, otherwise, we prefer - # treating it as a scalar. - # Note: ensure type declaration is under conditional check for isinstance - # or else torchscript will get cranky about the DeviceDict type. - if isinstance(beta1, Tensor): - beta1_dict: Optional[DeviceDtypeDict] = {(beta1.device, beta1.dtype): beta1} - else: - beta1_dict = None - - for i, param in enumerate(params): - grad = grads[i] if not maximize else -grads[i] - exp_avg = exp_avgs[i] - exp_avg_sq = exp_avg_sqs[i] - step_t = state_steps[i] - - # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: - capturable_supported_devices = _get_capturable_supported_devices() - assert ( - param.device.type == step_t.device.type - and param.device.type in capturable_supported_devices - ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." - - if torch.is_complex(param): - grad = torch.view_as_real(grad) - exp_avg = torch.view_as_real(exp_avg) - exp_avg_sq = torch.view_as_real(exp_avg_sq) - if amsgrad: - max_exp_avg_sqs[i] = torch.view_as_real(max_exp_avg_sqs[i]) - param = torch.view_as_real(param) - - # update step - step_t += 1 - - # Perform stepweight decay - param.mul_(1 - lr * weight_decay) - - device = param.device - - device = param.device - dtype = param.dtype - - if beta1_dict is not None: - dtype = param.dtype # type: ignore[union-attr] - - # cast to workaround https://github.com/pytorch/pytorch/issues/140601 - key = (device, dtype) - if key not in beta1_dict: - beta1_dict[key] = beta1.to(device=device, dtype=dtype, non_blocking=True) # type: ignore[union-attr] - - device_beta1: Union[float, Tensor] = beta1_dict[key] - else: - device_beta1 = beta1 - - # Decay the first and second moment running average coefficient - exp_avg.lerp_(grad, 1 - device_beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - - if capturable or differentiable: - step = step_t - - bias_correction1 = 1 - beta1**step - bias_correction2 = 1 - beta2**step - - step_size = lr / bias_correction1 - step_size_neg = step_size.neg() - - bias_correction2_sqrt = bias_correction2.sqrt() - - if amsgrad: - # Maintains the maximum of all 2nd moment running avg. till now - if differentiable: - max_exp_avg_sq = max_exp_avg_sqs[i].clone() - else: - max_exp_avg_sq = max_exp_avg_sqs[i] - - max_exp_avg_sqs[i].copy_(torch.maximum(max_exp_avg_sq, exp_avg_sq)) - - # Uses the max. for normalizing running avg. of gradient - # Folds in (admittedly ugly) 1-elem step_size math here to avoid extra param-set-sized read+write - # (can't fold it into addcdiv_ below because addcdiv_ requires value is a Number, not a Tensor) - denom = ( - max_exp_avg_sqs[i].sqrt() / (bias_correction2_sqrt * step_size_neg) - ).add_(eps / step_size_neg) - else: - denom = ( - exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg) - ).add_(eps / step_size_neg) - - param.addcdiv_(exp_avg, denom) - else: - step = _get_value(step_t) - - bias_correction1 = 1 - beta1**step - bias_correction2 = 1 - beta2**step - - step_size = lr / bias_correction1 - - bias_correction2_sqrt = bias_correction2**0.5 - - if amsgrad: - # Maintains the maximum of all 2nd moment running avg. till now - torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i]) - - # Use the max. for normalizing running avg. of gradient - denom = (max_exp_avg_sqs[i].sqrt() / bias_correction2_sqrt).add_(eps) - else: - denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) - - param.addcdiv_(exp_avg, denom, value=-step_size) - - # Lastly, switch back to complex view - if amsgrad and torch.is_complex(params[i]): - max_exp_avg_sqs[i] = torch.view_as_complex(max_exp_avg_sqs[i]) - - -def _multi_tensor_adamw( - params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - max_exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - grad_scale: Optional[Tensor], - found_inf: Optional[Tensor], - *, - amsgrad: bool, - beta1: Union[Tensor, float], - beta2: Union[Tensor, float], - lr: Union[Tensor, float], - weight_decay: float, - eps: float, - maximize: bool, - capturable: bool, - differentiable: bool, - has_complex: bool, -): - if len(params) == 0: - return - - if isinstance(lr, Tensor) and not capturable: - raise RuntimeError( - "lr as a Tensor is not supported for capturable=False and foreach=True" - ) - - if isinstance(beta1, Tensor): - if not capturable: - raise ValueError( - "beta1 as a Tensor is not supported for capturable=False and foreach=True" - ) - if beta1.numel() != 1: - raise ValueError("Tensor beta1 must be 1-element") - - if isinstance(beta2, Tensor): - if not capturable: - raise ValueError( - "beta2 as a Tensor is not supported for capturable=False and foreach=True" - ) - if beta2.numel() != 1: - raise ValueError("Tensor beta2 must be 1-element") - - # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] - if not torch.compiler.is_compiling() and capturable: - capturable_supported_devices = _get_capturable_supported_devices( - supports_xla=False - ) - assert all( - p.device.type == step.device.type - and p.device.type in capturable_supported_devices - for p, step in zip(params, state_steps) - ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." - - assert not differentiable, "_foreach ops don't support autograd" - - assert grad_scale is None and found_inf is None - - grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( - [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps] # type: ignore[list-item] - ) - - # We only shuffle around the beta when it is a Tensor and on CUDA, otherwise, we prefer - # treating it as a scalar. - beta1_dict: Optional[DeviceDict] = ( # type: ignore[attr-defined] - {beta1.device: beta1} - if isinstance(beta1, Tensor) and str(beta1.device) != "cpu" - else None - ) - - for ( - device_params_, - device_grads_, - device_exp_avgs_, - device_exp_avg_sqs_, - device_max_exp_avg_sqs_, - device_state_steps_, - ), _ in grouped_tensors.values(): - device_params = cast(List[Tensor], device_params_) - device_grads = cast(List[Tensor], device_grads_) - device_exp_avgs = cast(List[Tensor], device_exp_avgs_) - device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_) - device_state_steps = cast(List[Tensor], device_state_steps_) - - device = device_params[0].device - if beta1_dict is not None and device not in beta1_dict: - beta1_dict[device] = beta1.to(device=device, non_blocking=True) # type: ignore[union-attr] - - device_beta1 = beta1_dict[device] if beta1_dict else beta1 - - if has_complex: - if amsgrad: - device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_) - _view_as_real( - device_params, - device_grads, - device_exp_avgs, - device_exp_avg_sqs, - device_max_exp_avg_sqs, - ) - else: - _view_as_real( - device_params, device_grads, device_exp_avgs, device_exp_avg_sqs - ) - - if maximize: - device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment] - - # Update steps - # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over - # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just - # wrapped it once now. The alpha is required to assure we go to the right overload. - if not torch.compiler.is_compiling() and device_state_steps[0].is_cpu: - torch._foreach_add_( - device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0 - ) - else: - torch._foreach_add_(device_state_steps, 1) - - # Perform stepweight decay - if weight_decay != 0: - torch._foreach_mul_(device_params, 1 - lr * weight_decay) - - # Decay the first and second moment running average coefficient - torch._foreach_lerp_(device_exp_avgs, device_grads, 1 - device_beta1) - - torch._foreach_mul_(device_exp_avg_sqs, beta2) - # Due to the strictness of the _foreach_addcmul API, we can't have a single - # tensor scalar as the scalar arg (only python number is supported there) - # as a result, separate out the value mul - # Filed https://github.com/pytorch/pytorch/issues/139795 - if isinstance(beta2, torch.Tensor): - scaled_device_grads = torch._foreach_mul(device_grads, 1 - beta2) # type: ignore[assignment] - value = 1.0 - else: - scaled_device_grads = device_grads # type: ignore[assignment] - value = 1 - beta2 - - torch._foreach_addcmul_( - device_exp_avg_sqs, scaled_device_grads, device_grads, value - ) - - # Delete the local intermediate(s) since they won't be used anymore to save on peak memory - del device_grads - del scaled_device_grads - - bias_correction1: Union[Tuple[Tensor, ...], List[Tensor]] - bias_correction2: Union[Tuple[Tensor, ...], List[Tensor]] - bias_correction2_sqrt: Union[Tuple[Tensor, ...], List[Tensor]] - - if capturable: - bias_correction1 = torch._foreach_pow(beta1, device_state_steps) # type: ignore[arg-type] - bias_correction2 = torch._foreach_pow(beta2, device_state_steps) # type: ignore[arg-type] - # foreach_sub doesn't allow a scalar as the first arg - torch._foreach_sub_(bias_correction1, 1) - torch._foreach_sub_(bias_correction2, 1) - # we do not negate bias_correction1 as it'll need to be negated later anyway - torch._foreach_neg_(bias_correction2) - - # foreach_div doesn't allow a scalar as the first arg - torch._foreach_div_(bias_correction1, lr) - torch._foreach_reciprocal_(bias_correction1) - - torch._foreach_sqrt_(bias_correction2) - - # Re-assign for clarity as we maintain minimal intermediates: we'll have - # step_size = - lr / (1 - beta1 ^ t) where t = num_steps - # bias_correction2_sqrt = sqrt(1 - beta2 ^ t) - step_size = bias_correction1 - bias_correction2_sqrt = bias_correction2 - - if amsgrad: - device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_) - - # Maintains the maximum of all 2nd moment running avg. till now - torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs) - - # Use the max. for normalizing running avg. of gradient - exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs) - else: - exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs) - - torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt) - torch._foreach_add_(exp_avg_sq_sqrt, eps) - torch._foreach_div_(exp_avg_sq_sqrt, step_size) - - # at this point, exp_avg_sq_sqrt = - (1 - beta^t) * [sqrt(exp_avg_sq / (1 - beta2^t)) + eps] / lr - torch._foreach_addcdiv_(device_params, device_exp_avgs, exp_avg_sq_sqrt) - else: - bias_correction1 = [ - 1 - beta1 ** _get_value(step) for step in device_state_steps - ] - bias_correction2 = [ - 1 - beta2 ** _get_value(step) for step in device_state_steps - ] - - step_size = _stack_if_compiling([(lr / bc) * -1 for bc in bias_correction1]) - - bias_correction2_sqrt = [ - bc**0.5 for bc in bias_correction2 # type: ignore[arg-type] - ] - - if amsgrad: - device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_) - - # Maintains the maximum of all 2nd moment running avg. till now - torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs) - - # Use the max. for normalizing running avg. of gradient - exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs) - else: - exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs) - - torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt) - torch._foreach_add_(exp_avg_sq_sqrt, eps) - torch._foreach_addcdiv_( - device_params, - device_exp_avgs, - exp_avg_sq_sqrt, - step_size, # type: ignore[arg-type] - ) - - -def _fused_adamw( - params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - max_exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - grad_scale: Optional[Tensor], - found_inf: Optional[Tensor], - *, - amsgrad: bool, - beta1: float, - beta2: float, - lr: Union[Tensor, float], - weight_decay: float, - eps: float, - maximize: bool, - capturable: bool, # Needed for consistency. - differentiable: bool, - has_complex: bool, # Needed for consistency. -) -> None: - if not params: - return - if differentiable: - raise RuntimeError("Adam with fused=True does not support differentiable=True") - - grad_scale_dict: DeviceDict = ( - {grad_scale.device: grad_scale} if grad_scale is not None else {} - ) - found_inf_dict: DeviceDict = ( - {found_inf.device: found_inf} if found_inf is not None else {} - ) - - # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer - # treating it as a scalar. - lr_dict: Optional[DeviceDict] = ( - {lr.device: lr} if isinstance(lr, Tensor) and str(lr.device) != "cpu" else None - ) - - grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( - [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps] # type: ignore[list-item] - ) - for (device, _), ( - ( - device_params_, - device_grads_, - device_exp_avgs_, - device_exp_avg_sqs_, - device_max_exp_avg_sqs, - device_state_steps_, - ), - _, - ) in grouped_tensors.items(): - device_params = cast(List[Tensor], device_params_) - device_grads = cast(List[Tensor], device_grads_) - device_exp_avgs = cast(List[Tensor], device_exp_avgs_) - device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_) - device_state_steps = cast(List[Tensor], device_state_steps_) - - if device.type == "mps": # type: ignore[union-attr] - assert found_inf is None and grad_scale is None - - device_grad_scale, device_found_inf = None, None - if grad_scale is not None: - device_grad_scale = grad_scale_dict.setdefault( - device, grad_scale.to(device, non_blocking=True) - ) - if found_inf is not None: - device_found_inf = found_inf_dict.setdefault( - device, found_inf.to(device, non_blocking=True) - ) - if lr_dict is not None and device not in lr_dict: - lr = lr_dict.setdefault( - device, lr.to(device=device, non_blocking=True) # type: ignore[union-attr] - ) - torch._foreach_add_(device_state_steps, 1) - torch._fused_adamw_( - device_params, - device_grads, - device_exp_avgs, - device_exp_avg_sqs, - device_max_exp_avg_sqs, # type: ignore[arg-type] - device_state_steps, - amsgrad=amsgrad, - lr=lr, # type: ignore[arg-type] - beta1=beta1, - beta2=beta2, - weight_decay=weight_decay, - eps=eps, - maximize=maximize, - grad_scale=device_grad_scale, - found_inf=device_found_inf, - ) - if device_found_inf is not None: - torch._foreach_sub_( - device_state_steps, [device_found_inf] * len(device_state_steps) - ) - - -@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adamw) +# @_disable_dynamo_if_unsupported logic occurs in the decorator that's applied to F.adam def adamw( params: List[Tensor], grads: List[Tensor], @@ -835,48 +149,20 @@ def adamw( See :class:`~torch.optim.AdamW` for details. """ - if not torch.compiler.is_compiling() and not all( - isinstance(t, torch.Tensor) for t in state_steps - ): - raise RuntimeError( - "API has changed, `state_steps` argument must contain a list of singleton tensors" - ) - - # Respect when the user inputs False/True for foreach or fused. We only want to change - # the default when neither have been user-specified. Note that we default to foreach - # and pass False to use_fused. This is not a mistake--we want to give the fused impl - # bake-in time before making it the default, even if it is typically faster. - if fused is None and foreach is None: - _, foreach = _default_to_fused_or_foreach( - params, differentiable, use_fused=False - ) - # Do not flip on foreach for the unsupported case where lr is a Tensor and capturable=False. - if foreach and isinstance(lr, Tensor) and not capturable: - foreach = False - if fused is None: - fused = False - if foreach is None: - foreach = False - - if foreach and torch.jit.is_scripting(): - raise RuntimeError("torch.jit.script not supported with foreach optimizers") - if fused and torch.jit.is_scripting(): - raise RuntimeError("torch.jit.script not supported with fused optimizers") - - if fused and not torch.jit.is_scripting(): - func = _fused_adamw - elif foreach and not torch.jit.is_scripting(): - func = _multi_tensor_adamw - else: - func = _single_tensor_adamw - - func( + adam( params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, + foreach=foreach, + capturable=capturable, + differentiable=differentiable, + fused=fused, + grad_scale=grad_scale, + found_inf=found_inf, + has_complex=has_complex, amsgrad=amsgrad, beta1=beta1, beta2=beta2, @@ -884,9 +170,5 @@ def adamw( weight_decay=weight_decay, eps=eps, maximize=maximize, - capturable=capturable, - differentiable=differentiable, - grad_scale=grad_scale, - found_inf=found_inf, - has_complex=has_complex, + decoupled_weight_decay=True, )