mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Refactor AdamW into Adam (heavily inspired by tfsingh) (#143710)
Fixes #104899 Refactors AdamW into Adam by making AdamW a subclass of Adam. Additionally adds a test to assert that the added parameter `decoupled_weight_decay` is True in AdamW and also updates test_defaults_changed_to_foreach to account for the differences in module location for AdamW. Heavily heavily inspired by #118857 by @tfsingh Pull Request resolved: https://github.com/pytorch/pytorch/pull/143710 Approved by: https://github.com/janeyx99
This commit is contained in:
parent
4271a95590
commit
6ccb8ed186
3 changed files with 68 additions and 751 deletions
|
|
@ -665,6 +665,17 @@ class TestOptimRenewed(TestCase):
|
|||
self.assertTrue(a1_grad_imags.all_popped())
|
||||
self.assertTrue(losses.all_popped())
|
||||
|
||||
def test_adamw_serialization(self, device):
|
||||
model = torch.nn.Linear(5, 5).to(device)
|
||||
optim = torch.optim.AdamW(model.parameters())
|
||||
|
||||
loaded_dict = optim.state_dict()
|
||||
|
||||
new_optim = torch.optim.Adam(model.parameters())
|
||||
new_optim.load_state_dict(loaded_dict)
|
||||
|
||||
self.assertTrue(new_optim.param_groups[0]["decoupled_weight_decay"])
|
||||
|
||||
def _compare_between(
|
||||
self, inputs, models, optimizers, assert_eq_kwargs=None, assert_step_dtype=None
|
||||
):
|
||||
|
|
@ -2150,6 +2161,8 @@ class TestOptimRenewed(TestCase):
|
|||
def test_defaults_changed_to_foreach(self, device, dtype, optim_info):
|
||||
# Test that the default implementations for optimizers are changed to foreach
|
||||
# except Adafactor, which defaults to the single tensor impl for memory efficiency.
|
||||
from torch.optim import Adam, AdamW
|
||||
|
||||
optim_cls = optim_info.optim_cls
|
||||
model = torch.nn.Linear(5, 5)
|
||||
model.to(dtype=dtype, device=device)
|
||||
|
|
@ -2157,7 +2170,13 @@ class TestOptimRenewed(TestCase):
|
|||
|
||||
import inspect
|
||||
|
||||
module = inspect.getmodule(optim_cls)
|
||||
# AdamW dispatches to superclass' adam
|
||||
if optim_cls is AdamW:
|
||||
module = inspect.getmodule(Adam)
|
||||
module_name = "_multi_tensor_adam"
|
||||
else:
|
||||
module = inspect.getmodule(optim_cls)
|
||||
module_name = f"_multi_tensor_{optim_cls.__name__.lower()}"
|
||||
|
||||
for optim_input in optim_info.optim_inputs_func(device=device):
|
||||
optim = optim_cls(model.parameters(), **optim_input.kwargs)
|
||||
|
|
@ -2165,9 +2184,7 @@ class TestOptimRenewed(TestCase):
|
|||
output = model(inpt)
|
||||
loss = output.sum()
|
||||
loss.backward()
|
||||
with patch.object(
|
||||
module, f"_multi_tensor_{optim_cls.__name__.lower()}"
|
||||
) as mocked_foreach_impl:
|
||||
with patch.object(module, module_name) as mocked_foreach_impl:
|
||||
optim.step()
|
||||
self.assertTrue(mocked_foreach_impl.called)
|
||||
|
||||
|
|
|
|||
|
|
@ -46,6 +46,7 @@ class Adam(Optimizer):
|
|||
capturable: bool = False,
|
||||
differentiable: bool = False,
|
||||
fused: Optional[bool] = None,
|
||||
decoupled_weight_decay: bool = False,
|
||||
):
|
||||
if isinstance(lr, Tensor):
|
||||
if foreach and not capturable:
|
||||
|
|
@ -95,6 +96,7 @@ class Adam(Optimizer):
|
|||
capturable=capturable,
|
||||
differentiable=differentiable,
|
||||
fused=fused,
|
||||
decoupled_weight_decay=decoupled_weight_decay,
|
||||
)
|
||||
super().__init__(params, defaults)
|
||||
|
||||
|
|
@ -117,6 +119,7 @@ class Adam(Optimizer):
|
|||
group.setdefault("foreach", None)
|
||||
group.setdefault("capturable", False)
|
||||
group.setdefault("differentiable", False)
|
||||
group.setdefault("decoupled_weight_decay", False)
|
||||
fused = group.setdefault("fused", None)
|
||||
for p in group["params"]:
|
||||
p_state = self.state.get(p, [])
|
||||
|
|
@ -262,6 +265,7 @@ class Adam(Optimizer):
|
|||
fused=group["fused"],
|
||||
grad_scale=getattr(self, "grad_scale", None),
|
||||
found_inf=getattr(self, "found_inf", None),
|
||||
decoupled_weight_decay=group["decoupled_weight_decay"],
|
||||
)
|
||||
|
||||
return loss
|
||||
|
|
@ -355,6 +359,7 @@ def _single_tensor_adam(
|
|||
maximize: bool,
|
||||
capturable: bool,
|
||||
differentiable: bool,
|
||||
decoupled_weight_decay: bool,
|
||||
):
|
||||
assert grad_scale is None and found_inf is None
|
||||
|
||||
|
|
@ -393,7 +398,11 @@ def _single_tensor_adam(
|
|||
step_t += 1
|
||||
|
||||
if weight_decay != 0:
|
||||
grad = grad.add(param, alpha=weight_decay)
|
||||
if decoupled_weight_decay:
|
||||
# Perform stepweight decay
|
||||
param.mul_(1 - lr * weight_decay)
|
||||
else:
|
||||
grad = grad.add(param, alpha=weight_decay)
|
||||
|
||||
if torch.is_complex(param):
|
||||
grad = torch.view_as_real(grad)
|
||||
|
|
@ -500,6 +509,7 @@ def _multi_tensor_adam(
|
|||
maximize: bool,
|
||||
capturable: bool,
|
||||
differentiable: bool,
|
||||
decoupled_weight_decay: bool,
|
||||
):
|
||||
if len(params) == 0:
|
||||
return
|
||||
|
|
@ -603,13 +613,17 @@ def _multi_tensor_adam(
|
|||
torch._foreach_add_(device_state_steps, 1)
|
||||
|
||||
if weight_decay != 0:
|
||||
# Re-use the intermediate memory (device_grads) already allocated for maximize
|
||||
if maximize:
|
||||
torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
|
||||
if decoupled_weight_decay:
|
||||
# Perform stepweight decay
|
||||
torch._foreach_mul_(device_params, 1 - lr * weight_decay)
|
||||
else:
|
||||
device_grads = torch._foreach_add( # type: ignore[assignment]
|
||||
device_grads, device_params, alpha=weight_decay
|
||||
)
|
||||
# Re-use the intermediate memory (device_grads) already allocated for maximize
|
||||
if maximize:
|
||||
torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
|
||||
else:
|
||||
device_grads = torch._foreach_add( # type: ignore[assignment]
|
||||
device_grads, device_params, alpha=weight_decay
|
||||
)
|
||||
|
||||
# Decay the first and second moment running average coefficient
|
||||
# Use device beta1 if beta1 is a tensor to ensure all
|
||||
|
|
@ -727,6 +741,7 @@ def _fused_adam(
|
|||
maximize: bool,
|
||||
capturable: bool, # Needed for consistency.
|
||||
differentiable: bool,
|
||||
decoupled_weight_decay: bool,
|
||||
) -> None:
|
||||
if not params:
|
||||
return
|
||||
|
|
@ -781,7 +796,8 @@ def _fused_adam(
|
|||
lr_dict[device] = lr.to(device=device, non_blocking=True) # type: ignore[union-attr]
|
||||
lr = lr_dict[device]
|
||||
torch._foreach_add_(device_state_steps, 1)
|
||||
torch._fused_adam_(
|
||||
func = torch._fused_adam_ if not decoupled_weight_decay else torch._fused_adamw_
|
||||
func(
|
||||
device_params,
|
||||
device_grads,
|
||||
device_exp_avgs,
|
||||
|
|
@ -821,6 +837,7 @@ def adam(
|
|||
grad_scale: Optional[Tensor] = None,
|
||||
found_inf: Optional[Tensor] = None,
|
||||
has_complex: bool = False,
|
||||
decoupled_weight_decay: bool = False,
|
||||
*,
|
||||
amsgrad: bool,
|
||||
beta1: float,
|
||||
|
|
@ -890,4 +907,5 @@ def adam(
|
|||
differentiable=differentiable,
|
||||
grad_scale=grad_scale,
|
||||
found_inf=found_inf,
|
||||
decoupled_weight_decay=decoupled_weight_decay,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,29 +1,17 @@
|
|||
# mypy: allow-untyped-decorators
|
||||
# mypy: allow-untyped-defs
|
||||
from typing import cast, List, Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from .adam import Adam, adam
|
||||
from .optimizer import (
|
||||
_capturable_doc,
|
||||
_default_to_fused_or_foreach,
|
||||
_device_dtype_check_for_fused,
|
||||
_differentiable_doc,
|
||||
_disable_dynamo_if_unsupported,
|
||||
_foreach_doc,
|
||||
_fused_doc,
|
||||
_get_capturable_supported_devices,
|
||||
_get_scalar_dtype,
|
||||
_get_value,
|
||||
_maximize_doc,
|
||||
_params_doc,
|
||||
_stack_if_compiling,
|
||||
_use_grad_for_differentiable,
|
||||
_view_as_real,
|
||||
DeviceDict,
|
||||
DeviceDtypeDict,
|
||||
Optimizer,
|
||||
ParamsT,
|
||||
)
|
||||
|
||||
|
|
@ -31,7 +19,7 @@ from .optimizer import (
|
|||
__all__ = ["AdamW", "adamw"]
|
||||
|
||||
|
||||
class AdamW(Optimizer):
|
||||
class AdamW(Adam):
|
||||
def __init__(
|
||||
self,
|
||||
params: ParamsT,
|
||||
|
|
@ -47,223 +35,20 @@ class AdamW(Optimizer):
|
|||
differentiable: bool = False,
|
||||
fused: Optional[bool] = None,
|
||||
):
|
||||
if isinstance(lr, Tensor):
|
||||
if foreach and not capturable:
|
||||
raise ValueError(
|
||||
"lr as a Tensor is not supported for capturable=False and foreach=True"
|
||||
)
|
||||
if lr.numel() != 1:
|
||||
raise ValueError("Tensor lr must be 1-element")
|
||||
|
||||
if not 0.0 <= lr:
|
||||
raise ValueError(f"Invalid learning rate: {lr}")
|
||||
if not 0.0 <= eps:
|
||||
raise ValueError(f"Invalid epsilon value: {eps}")
|
||||
if not 0.0 <= betas[0] < 1.0:
|
||||
raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
|
||||
if not 0.0 <= betas[1] < 1.0:
|
||||
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
|
||||
if not 0.0 <= weight_decay:
|
||||
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
|
||||
if not (
|
||||
(isinstance(betas[0], float) and isinstance(betas[1], float))
|
||||
or (isinstance(betas[0], Tensor) and isinstance(betas[1], Tensor))
|
||||
):
|
||||
raise ValueError("betas must be either both floats or both Tensors")
|
||||
if isinstance(betas[0], Tensor):
|
||||
if not capturable and foreach:
|
||||
raise ValueError(
|
||||
"betas[0] as a Tensor is not supported for capturable=False and foreach=True"
|
||||
)
|
||||
if betas[0].numel() != 1:
|
||||
raise ValueError("Tensor betas[0] must be 1-element")
|
||||
if isinstance(betas[1], Tensor):
|
||||
if not capturable and foreach:
|
||||
raise ValueError(
|
||||
"betas[1] as a Tensor is not supported for capturable=False and foreach=True"
|
||||
)
|
||||
if betas[1].numel() != 1:
|
||||
raise ValueError("Tensor betas[1] must be 1-element")
|
||||
|
||||
defaults = dict(
|
||||
lr=lr,
|
||||
betas=betas,
|
||||
eps=eps,
|
||||
weight_decay=weight_decay,
|
||||
amsgrad=amsgrad,
|
||||
super().__init__(
|
||||
params,
|
||||
lr,
|
||||
betas,
|
||||
eps,
|
||||
weight_decay,
|
||||
amsgrad,
|
||||
foreach=foreach,
|
||||
maximize=maximize,
|
||||
capturable=capturable,
|
||||
differentiable=differentiable,
|
||||
fused=fused,
|
||||
decoupled_weight_decay=True,
|
||||
)
|
||||
super().__init__(params, defaults)
|
||||
|
||||
if fused:
|
||||
if differentiable:
|
||||
raise RuntimeError("`fused` does not support `differentiable`")
|
||||
self._step_supports_amp_scaling = True
|
||||
if foreach:
|
||||
raise RuntimeError("`fused` and `foreach` cannot be `True` together.")
|
||||
|
||||
def __setstate__(self, state):
|
||||
super().__setstate__(state)
|
||||
for group in self.param_groups:
|
||||
group.setdefault("amsgrad", False)
|
||||
group.setdefault("maximize", False)
|
||||
group.setdefault("foreach", None)
|
||||
group.setdefault("capturable", False)
|
||||
group.setdefault("differentiable", False)
|
||||
fused = group.setdefault("fused", None)
|
||||
for p in group["params"]:
|
||||
p_state = self.state.get(p, [])
|
||||
if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
|
||||
step_val = float(p_state["step"])
|
||||
p_state["step"] = (
|
||||
torch.tensor(
|
||||
step_val,
|
||||
dtype=_get_scalar_dtype(is_fused=fused),
|
||||
device=p.device,
|
||||
)
|
||||
if group["capturable"] or group["fused"]
|
||||
else torch.tensor(step_val, dtype=_get_scalar_dtype())
|
||||
)
|
||||
|
||||
def _init_group(
|
||||
self,
|
||||
group,
|
||||
params_with_grad,
|
||||
grads,
|
||||
amsgrad,
|
||||
exp_avgs,
|
||||
exp_avg_sqs,
|
||||
max_exp_avg_sqs,
|
||||
state_steps,
|
||||
):
|
||||
has_complex = False
|
||||
for p in group["params"]:
|
||||
if p.grad is None:
|
||||
continue
|
||||
has_complex |= torch.is_complex(p)
|
||||
params_with_grad.append(p)
|
||||
if p.grad.is_sparse:
|
||||
raise RuntimeError("AdamW does not support sparse gradients")
|
||||
grads.append(p.grad)
|
||||
|
||||
state = self.state[p]
|
||||
|
||||
# State initialization
|
||||
if len(state) == 0:
|
||||
if group["fused"]:
|
||||
_device_dtype_check_for_fused(p)
|
||||
# note(crcrpar): Deliberately host `step` on CPU if both capturable and fused are off.
|
||||
# This is because kernel launches are costly on CUDA and XLA.
|
||||
state["step"] = (
|
||||
torch.zeros(
|
||||
(),
|
||||
dtype=_get_scalar_dtype(is_fused=group["fused"]),
|
||||
device=p.device,
|
||||
)
|
||||
if group["capturable"] or group["fused"]
|
||||
else torch.tensor(0.0, dtype=_get_scalar_dtype())
|
||||
)
|
||||
# Exponential moving average of gradient values
|
||||
state["exp_avg"] = torch.zeros_like(
|
||||
p, memory_format=torch.preserve_format
|
||||
)
|
||||
# Exponential moving average of squared gradient values
|
||||
state["exp_avg_sq"] = torch.zeros_like(
|
||||
p, memory_format=torch.preserve_format
|
||||
)
|
||||
if amsgrad:
|
||||
# Maintains max of all exp. moving avg. of sq. grad. values
|
||||
state["max_exp_avg_sq"] = torch.zeros_like(
|
||||
p, memory_format=torch.preserve_format
|
||||
)
|
||||
|
||||
exp_avgs.append(state["exp_avg"])
|
||||
exp_avg_sqs.append(state["exp_avg_sq"])
|
||||
|
||||
if group["amsgrad"]:
|
||||
max_exp_avg_sqs.append(state["max_exp_avg_sq"])
|
||||
if group["differentiable"] and state["step"].requires_grad:
|
||||
raise RuntimeError(
|
||||
"`requires_grad` is not supported for `step` in differentiable mode"
|
||||
)
|
||||
|
||||
# Foreach without capturable does not support a tensor lr
|
||||
if (
|
||||
group["foreach"]
|
||||
and isinstance(group["lr"], Tensor)
|
||||
and not group["capturable"]
|
||||
):
|
||||
raise RuntimeError(
|
||||
"lr as a Tensor is not supported for capturable=False and foreach=True"
|
||||
)
|
||||
|
||||
state_steps.append(state["step"])
|
||||
return has_complex
|
||||
|
||||
@_use_grad_for_differentiable
|
||||
def step(self, closure=None):
|
||||
"""Perform a single optimization step.
|
||||
|
||||
Args:
|
||||
closure (Callable, optional): A closure that reevaluates the model
|
||||
and returns the loss.
|
||||
"""
|
||||
self._cuda_graph_capture_health_check()
|
||||
|
||||
loss = None
|
||||
if closure is not None:
|
||||
with torch.enable_grad():
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
params_with_grad: List[Tensor] = []
|
||||
grads: List[Tensor] = []
|
||||
exp_avgs: List[Tensor] = []
|
||||
exp_avg_sqs: List[Tensor] = []
|
||||
max_exp_avg_sqs: List[Tensor] = []
|
||||
state_steps: List[Tensor] = []
|
||||
amsgrad: bool = group["amsgrad"]
|
||||
beta1, beta2 = cast(Tuple[float, float], group["betas"])
|
||||
|
||||
has_complex = self._init_group(
|
||||
group,
|
||||
params_with_grad,
|
||||
grads,
|
||||
amsgrad,
|
||||
exp_avgs,
|
||||
exp_avg_sqs,
|
||||
max_exp_avg_sqs,
|
||||
state_steps,
|
||||
)
|
||||
|
||||
adamw(
|
||||
params_with_grad,
|
||||
grads,
|
||||
exp_avgs,
|
||||
exp_avg_sqs,
|
||||
max_exp_avg_sqs,
|
||||
state_steps,
|
||||
amsgrad=amsgrad,
|
||||
beta1=beta1,
|
||||
beta2=beta2,
|
||||
lr=group["lr"],
|
||||
weight_decay=group["weight_decay"],
|
||||
eps=group["eps"],
|
||||
maximize=group["maximize"],
|
||||
foreach=group["foreach"],
|
||||
capturable=group["capturable"],
|
||||
differentiable=group["differentiable"],
|
||||
fused=group["fused"],
|
||||
grad_scale=getattr(self, "grad_scale", None),
|
||||
found_inf=getattr(self, "found_inf", None),
|
||||
has_complex=has_complex,
|
||||
)
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
AdamW.__doc__ = (
|
||||
|
|
@ -334,478 +119,7 @@ AdamW.__doc__ = (
|
|||
)
|
||||
|
||||
|
||||
def _single_tensor_adamw(
|
||||
params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
exp_avgs: List[Tensor],
|
||||
exp_avg_sqs: List[Tensor],
|
||||
max_exp_avg_sqs: List[Tensor],
|
||||
state_steps: List[Tensor],
|
||||
grad_scale: Optional[Tensor],
|
||||
found_inf: Optional[Tensor],
|
||||
*,
|
||||
amsgrad: bool,
|
||||
beta1: Union[Tensor, float],
|
||||
beta2: Union[Tensor, float],
|
||||
lr: Union[Tensor, float],
|
||||
weight_decay: float,
|
||||
eps: float,
|
||||
maximize: bool,
|
||||
capturable: bool,
|
||||
differentiable: bool,
|
||||
has_complex: bool,
|
||||
):
|
||||
assert grad_scale is None and found_inf is None
|
||||
|
||||
if torch.jit.is_scripting():
|
||||
# this assert is due to JIT being dumb and not realizing that the ops below
|
||||
# have overloads to handle both float and Tensor lrs, so we just assert it's
|
||||
# a float since most people using JIT are using floats
|
||||
assert isinstance(lr, float)
|
||||
assert isinstance(beta1, float)
|
||||
assert isinstance(beta2, float)
|
||||
|
||||
# We only shuffle around the beta when it is a Tensor, otherwise, we prefer
|
||||
# treating it as a scalar.
|
||||
# Note: ensure type declaration is under conditional check for isinstance
|
||||
# or else torchscript will get cranky about the DeviceDict type.
|
||||
if isinstance(beta1, Tensor):
|
||||
beta1_dict: Optional[DeviceDtypeDict] = {(beta1.device, beta1.dtype): beta1}
|
||||
else:
|
||||
beta1_dict = None
|
||||
|
||||
for i, param in enumerate(params):
|
||||
grad = grads[i] if not maximize else -grads[i]
|
||||
exp_avg = exp_avgs[i]
|
||||
exp_avg_sq = exp_avg_sqs[i]
|
||||
step_t = state_steps[i]
|
||||
|
||||
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
|
||||
if not torch.compiler.is_compiling() and capturable:
|
||||
capturable_supported_devices = _get_capturable_supported_devices()
|
||||
assert (
|
||||
param.device.type == step_t.device.type
|
||||
and param.device.type in capturable_supported_devices
|
||||
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
|
||||
|
||||
if torch.is_complex(param):
|
||||
grad = torch.view_as_real(grad)
|
||||
exp_avg = torch.view_as_real(exp_avg)
|
||||
exp_avg_sq = torch.view_as_real(exp_avg_sq)
|
||||
if amsgrad:
|
||||
max_exp_avg_sqs[i] = torch.view_as_real(max_exp_avg_sqs[i])
|
||||
param = torch.view_as_real(param)
|
||||
|
||||
# update step
|
||||
step_t += 1
|
||||
|
||||
# Perform stepweight decay
|
||||
param.mul_(1 - lr * weight_decay)
|
||||
|
||||
device = param.device
|
||||
|
||||
device = param.device
|
||||
dtype = param.dtype
|
||||
|
||||
if beta1_dict is not None:
|
||||
dtype = param.dtype # type: ignore[union-attr]
|
||||
|
||||
# cast to workaround https://github.com/pytorch/pytorch/issues/140601
|
||||
key = (device, dtype)
|
||||
if key not in beta1_dict:
|
||||
beta1_dict[key] = beta1.to(device=device, dtype=dtype, non_blocking=True) # type: ignore[union-attr]
|
||||
|
||||
device_beta1: Union[float, Tensor] = beta1_dict[key]
|
||||
else:
|
||||
device_beta1 = beta1
|
||||
|
||||
# Decay the first and second moment running average coefficient
|
||||
exp_avg.lerp_(grad, 1 - device_beta1)
|
||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
|
||||
|
||||
if capturable or differentiable:
|
||||
step = step_t
|
||||
|
||||
bias_correction1 = 1 - beta1**step
|
||||
bias_correction2 = 1 - beta2**step
|
||||
|
||||
step_size = lr / bias_correction1
|
||||
step_size_neg = step_size.neg()
|
||||
|
||||
bias_correction2_sqrt = bias_correction2.sqrt()
|
||||
|
||||
if amsgrad:
|
||||
# Maintains the maximum of all 2nd moment running avg. till now
|
||||
if differentiable:
|
||||
max_exp_avg_sq = max_exp_avg_sqs[i].clone()
|
||||
else:
|
||||
max_exp_avg_sq = max_exp_avg_sqs[i]
|
||||
|
||||
max_exp_avg_sqs[i].copy_(torch.maximum(max_exp_avg_sq, exp_avg_sq))
|
||||
|
||||
# Uses the max. for normalizing running avg. of gradient
|
||||
# Folds in (admittedly ugly) 1-elem step_size math here to avoid extra param-set-sized read+write
|
||||
# (can't fold it into addcdiv_ below because addcdiv_ requires value is a Number, not a Tensor)
|
||||
denom = (
|
||||
max_exp_avg_sqs[i].sqrt() / (bias_correction2_sqrt * step_size_neg)
|
||||
).add_(eps / step_size_neg)
|
||||
else:
|
||||
denom = (
|
||||
exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg)
|
||||
).add_(eps / step_size_neg)
|
||||
|
||||
param.addcdiv_(exp_avg, denom)
|
||||
else:
|
||||
step = _get_value(step_t)
|
||||
|
||||
bias_correction1 = 1 - beta1**step
|
||||
bias_correction2 = 1 - beta2**step
|
||||
|
||||
step_size = lr / bias_correction1
|
||||
|
||||
bias_correction2_sqrt = bias_correction2**0.5
|
||||
|
||||
if amsgrad:
|
||||
# Maintains the maximum of all 2nd moment running avg. till now
|
||||
torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])
|
||||
|
||||
# Use the max. for normalizing running avg. of gradient
|
||||
denom = (max_exp_avg_sqs[i].sqrt() / bias_correction2_sqrt).add_(eps)
|
||||
else:
|
||||
denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
|
||||
|
||||
param.addcdiv_(exp_avg, denom, value=-step_size)
|
||||
|
||||
# Lastly, switch back to complex view
|
||||
if amsgrad and torch.is_complex(params[i]):
|
||||
max_exp_avg_sqs[i] = torch.view_as_complex(max_exp_avg_sqs[i])
|
||||
|
||||
|
||||
def _multi_tensor_adamw(
|
||||
params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
exp_avgs: List[Tensor],
|
||||
exp_avg_sqs: List[Tensor],
|
||||
max_exp_avg_sqs: List[Tensor],
|
||||
state_steps: List[Tensor],
|
||||
grad_scale: Optional[Tensor],
|
||||
found_inf: Optional[Tensor],
|
||||
*,
|
||||
amsgrad: bool,
|
||||
beta1: Union[Tensor, float],
|
||||
beta2: Union[Tensor, float],
|
||||
lr: Union[Tensor, float],
|
||||
weight_decay: float,
|
||||
eps: float,
|
||||
maximize: bool,
|
||||
capturable: bool,
|
||||
differentiable: bool,
|
||||
has_complex: bool,
|
||||
):
|
||||
if len(params) == 0:
|
||||
return
|
||||
|
||||
if isinstance(lr, Tensor) and not capturable:
|
||||
raise RuntimeError(
|
||||
"lr as a Tensor is not supported for capturable=False and foreach=True"
|
||||
)
|
||||
|
||||
if isinstance(beta1, Tensor):
|
||||
if not capturable:
|
||||
raise ValueError(
|
||||
"beta1 as a Tensor is not supported for capturable=False and foreach=True"
|
||||
)
|
||||
if beta1.numel() != 1:
|
||||
raise ValueError("Tensor beta1 must be 1-element")
|
||||
|
||||
if isinstance(beta2, Tensor):
|
||||
if not capturable:
|
||||
raise ValueError(
|
||||
"beta2 as a Tensor is not supported for capturable=False and foreach=True"
|
||||
)
|
||||
if beta2.numel() != 1:
|
||||
raise ValueError("Tensor beta2 must be 1-element")
|
||||
|
||||
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
|
||||
if not torch.compiler.is_compiling() and capturable:
|
||||
capturable_supported_devices = _get_capturable_supported_devices(
|
||||
supports_xla=False
|
||||
)
|
||||
assert all(
|
||||
p.device.type == step.device.type
|
||||
and p.device.type in capturable_supported_devices
|
||||
for p, step in zip(params, state_steps)
|
||||
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
|
||||
|
||||
assert not differentiable, "_foreach ops don't support autograd"
|
||||
|
||||
assert grad_scale is None and found_inf is None
|
||||
|
||||
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
|
||||
[params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps] # type: ignore[list-item]
|
||||
)
|
||||
|
||||
# We only shuffle around the beta when it is a Tensor and on CUDA, otherwise, we prefer
|
||||
# treating it as a scalar.
|
||||
beta1_dict: Optional[DeviceDict] = ( # type: ignore[attr-defined]
|
||||
{beta1.device: beta1}
|
||||
if isinstance(beta1, Tensor) and str(beta1.device) != "cpu"
|
||||
else None
|
||||
)
|
||||
|
||||
for (
|
||||
device_params_,
|
||||
device_grads_,
|
||||
device_exp_avgs_,
|
||||
device_exp_avg_sqs_,
|
||||
device_max_exp_avg_sqs_,
|
||||
device_state_steps_,
|
||||
), _ in grouped_tensors.values():
|
||||
device_params = cast(List[Tensor], device_params_)
|
||||
device_grads = cast(List[Tensor], device_grads_)
|
||||
device_exp_avgs = cast(List[Tensor], device_exp_avgs_)
|
||||
device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_)
|
||||
device_state_steps = cast(List[Tensor], device_state_steps_)
|
||||
|
||||
device = device_params[0].device
|
||||
if beta1_dict is not None and device not in beta1_dict:
|
||||
beta1_dict[device] = beta1.to(device=device, non_blocking=True) # type: ignore[union-attr]
|
||||
|
||||
device_beta1 = beta1_dict[device] if beta1_dict else beta1
|
||||
|
||||
if has_complex:
|
||||
if amsgrad:
|
||||
device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_)
|
||||
_view_as_real(
|
||||
device_params,
|
||||
device_grads,
|
||||
device_exp_avgs,
|
||||
device_exp_avg_sqs,
|
||||
device_max_exp_avg_sqs,
|
||||
)
|
||||
else:
|
||||
_view_as_real(
|
||||
device_params, device_grads, device_exp_avgs, device_exp_avg_sqs
|
||||
)
|
||||
|
||||
if maximize:
|
||||
device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment]
|
||||
|
||||
# Update steps
|
||||
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
|
||||
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
|
||||
# wrapped it once now. The alpha is required to assure we go to the right overload.
|
||||
if not torch.compiler.is_compiling() and device_state_steps[0].is_cpu:
|
||||
torch._foreach_add_(
|
||||
device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
|
||||
)
|
||||
else:
|
||||
torch._foreach_add_(device_state_steps, 1)
|
||||
|
||||
# Perform stepweight decay
|
||||
if weight_decay != 0:
|
||||
torch._foreach_mul_(device_params, 1 - lr * weight_decay)
|
||||
|
||||
# Decay the first and second moment running average coefficient
|
||||
torch._foreach_lerp_(device_exp_avgs, device_grads, 1 - device_beta1)
|
||||
|
||||
torch._foreach_mul_(device_exp_avg_sqs, beta2)
|
||||
# Due to the strictness of the _foreach_addcmul API, we can't have a single
|
||||
# tensor scalar as the scalar arg (only python number is supported there)
|
||||
# as a result, separate out the value mul
|
||||
# Filed https://github.com/pytorch/pytorch/issues/139795
|
||||
if isinstance(beta2, torch.Tensor):
|
||||
scaled_device_grads = torch._foreach_mul(device_grads, 1 - beta2) # type: ignore[assignment]
|
||||
value = 1.0
|
||||
else:
|
||||
scaled_device_grads = device_grads # type: ignore[assignment]
|
||||
value = 1 - beta2
|
||||
|
||||
torch._foreach_addcmul_(
|
||||
device_exp_avg_sqs, scaled_device_grads, device_grads, value
|
||||
)
|
||||
|
||||
# Delete the local intermediate(s) since they won't be used anymore to save on peak memory
|
||||
del device_grads
|
||||
del scaled_device_grads
|
||||
|
||||
bias_correction1: Union[Tuple[Tensor, ...], List[Tensor]]
|
||||
bias_correction2: Union[Tuple[Tensor, ...], List[Tensor]]
|
||||
bias_correction2_sqrt: Union[Tuple[Tensor, ...], List[Tensor]]
|
||||
|
||||
if capturable:
|
||||
bias_correction1 = torch._foreach_pow(beta1, device_state_steps) # type: ignore[arg-type]
|
||||
bias_correction2 = torch._foreach_pow(beta2, device_state_steps) # type: ignore[arg-type]
|
||||
# foreach_sub doesn't allow a scalar as the first arg
|
||||
torch._foreach_sub_(bias_correction1, 1)
|
||||
torch._foreach_sub_(bias_correction2, 1)
|
||||
# we do not negate bias_correction1 as it'll need to be negated later anyway
|
||||
torch._foreach_neg_(bias_correction2)
|
||||
|
||||
# foreach_div doesn't allow a scalar as the first arg
|
||||
torch._foreach_div_(bias_correction1, lr)
|
||||
torch._foreach_reciprocal_(bias_correction1)
|
||||
|
||||
torch._foreach_sqrt_(bias_correction2)
|
||||
|
||||
# Re-assign for clarity as we maintain minimal intermediates: we'll have
|
||||
# step_size = - lr / (1 - beta1 ^ t) where t = num_steps
|
||||
# bias_correction2_sqrt = sqrt(1 - beta2 ^ t)
|
||||
step_size = bias_correction1
|
||||
bias_correction2_sqrt = bias_correction2
|
||||
|
||||
if amsgrad:
|
||||
device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_)
|
||||
|
||||
# Maintains the maximum of all 2nd moment running avg. till now
|
||||
torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)
|
||||
|
||||
# Use the max. for normalizing running avg. of gradient
|
||||
exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
|
||||
else:
|
||||
exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
|
||||
|
||||
torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
|
||||
torch._foreach_add_(exp_avg_sq_sqrt, eps)
|
||||
torch._foreach_div_(exp_avg_sq_sqrt, step_size)
|
||||
|
||||
# at this point, exp_avg_sq_sqrt = - (1 - beta^t) * [sqrt(exp_avg_sq / (1 - beta2^t)) + eps] / lr
|
||||
torch._foreach_addcdiv_(device_params, device_exp_avgs, exp_avg_sq_sqrt)
|
||||
else:
|
||||
bias_correction1 = [
|
||||
1 - beta1 ** _get_value(step) for step in device_state_steps
|
||||
]
|
||||
bias_correction2 = [
|
||||
1 - beta2 ** _get_value(step) for step in device_state_steps
|
||||
]
|
||||
|
||||
step_size = _stack_if_compiling([(lr / bc) * -1 for bc in bias_correction1])
|
||||
|
||||
bias_correction2_sqrt = [
|
||||
bc**0.5 for bc in bias_correction2 # type: ignore[arg-type]
|
||||
]
|
||||
|
||||
if amsgrad:
|
||||
device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_)
|
||||
|
||||
# Maintains the maximum of all 2nd moment running avg. till now
|
||||
torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)
|
||||
|
||||
# Use the max. for normalizing running avg. of gradient
|
||||
exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
|
||||
else:
|
||||
exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
|
||||
|
||||
torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
|
||||
torch._foreach_add_(exp_avg_sq_sqrt, eps)
|
||||
torch._foreach_addcdiv_(
|
||||
device_params,
|
||||
device_exp_avgs,
|
||||
exp_avg_sq_sqrt,
|
||||
step_size, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
|
||||
def _fused_adamw(
|
||||
params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
exp_avgs: List[Tensor],
|
||||
exp_avg_sqs: List[Tensor],
|
||||
max_exp_avg_sqs: List[Tensor],
|
||||
state_steps: List[Tensor],
|
||||
grad_scale: Optional[Tensor],
|
||||
found_inf: Optional[Tensor],
|
||||
*,
|
||||
amsgrad: bool,
|
||||
beta1: float,
|
||||
beta2: float,
|
||||
lr: Union[Tensor, float],
|
||||
weight_decay: float,
|
||||
eps: float,
|
||||
maximize: bool,
|
||||
capturable: bool, # Needed for consistency.
|
||||
differentiable: bool,
|
||||
has_complex: bool, # Needed for consistency.
|
||||
) -> None:
|
||||
if not params:
|
||||
return
|
||||
if differentiable:
|
||||
raise RuntimeError("Adam with fused=True does not support differentiable=True")
|
||||
|
||||
grad_scale_dict: DeviceDict = (
|
||||
{grad_scale.device: grad_scale} if grad_scale is not None else {}
|
||||
)
|
||||
found_inf_dict: DeviceDict = (
|
||||
{found_inf.device: found_inf} if found_inf is not None else {}
|
||||
)
|
||||
|
||||
# We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
|
||||
# treating it as a scalar.
|
||||
lr_dict: Optional[DeviceDict] = (
|
||||
{lr.device: lr} if isinstance(lr, Tensor) and str(lr.device) != "cpu" else None
|
||||
)
|
||||
|
||||
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
|
||||
[params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps] # type: ignore[list-item]
|
||||
)
|
||||
for (device, _), (
|
||||
(
|
||||
device_params_,
|
||||
device_grads_,
|
||||
device_exp_avgs_,
|
||||
device_exp_avg_sqs_,
|
||||
device_max_exp_avg_sqs,
|
||||
device_state_steps_,
|
||||
),
|
||||
_,
|
||||
) in grouped_tensors.items():
|
||||
device_params = cast(List[Tensor], device_params_)
|
||||
device_grads = cast(List[Tensor], device_grads_)
|
||||
device_exp_avgs = cast(List[Tensor], device_exp_avgs_)
|
||||
device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_)
|
||||
device_state_steps = cast(List[Tensor], device_state_steps_)
|
||||
|
||||
if device.type == "mps": # type: ignore[union-attr]
|
||||
assert found_inf is None and grad_scale is None
|
||||
|
||||
device_grad_scale, device_found_inf = None, None
|
||||
if grad_scale is not None:
|
||||
device_grad_scale = grad_scale_dict.setdefault(
|
||||
device, grad_scale.to(device, non_blocking=True)
|
||||
)
|
||||
if found_inf is not None:
|
||||
device_found_inf = found_inf_dict.setdefault(
|
||||
device, found_inf.to(device, non_blocking=True)
|
||||
)
|
||||
if lr_dict is not None and device not in lr_dict:
|
||||
lr = lr_dict.setdefault(
|
||||
device, lr.to(device=device, non_blocking=True) # type: ignore[union-attr]
|
||||
)
|
||||
torch._foreach_add_(device_state_steps, 1)
|
||||
torch._fused_adamw_(
|
||||
device_params,
|
||||
device_grads,
|
||||
device_exp_avgs,
|
||||
device_exp_avg_sqs,
|
||||
device_max_exp_avg_sqs, # type: ignore[arg-type]
|
||||
device_state_steps,
|
||||
amsgrad=amsgrad,
|
||||
lr=lr, # type: ignore[arg-type]
|
||||
beta1=beta1,
|
||||
beta2=beta2,
|
||||
weight_decay=weight_decay,
|
||||
eps=eps,
|
||||
maximize=maximize,
|
||||
grad_scale=device_grad_scale,
|
||||
found_inf=device_found_inf,
|
||||
)
|
||||
if device_found_inf is not None:
|
||||
torch._foreach_sub_(
|
||||
device_state_steps, [device_found_inf] * len(device_state_steps)
|
||||
)
|
||||
|
||||
|
||||
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adamw)
|
||||
# @_disable_dynamo_if_unsupported logic occurs in the decorator that's applied to F.adam
|
||||
def adamw(
|
||||
params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
|
|
@ -835,48 +149,20 @@ def adamw(
|
|||
|
||||
See :class:`~torch.optim.AdamW` for details.
|
||||
"""
|
||||
if not torch.compiler.is_compiling() and not all(
|
||||
isinstance(t, torch.Tensor) for t in state_steps
|
||||
):
|
||||
raise RuntimeError(
|
||||
"API has changed, `state_steps` argument must contain a list of singleton tensors"
|
||||
)
|
||||
|
||||
# Respect when the user inputs False/True for foreach or fused. We only want to change
|
||||
# the default when neither have been user-specified. Note that we default to foreach
|
||||
# and pass False to use_fused. This is not a mistake--we want to give the fused impl
|
||||
# bake-in time before making it the default, even if it is typically faster.
|
||||
if fused is None and foreach is None:
|
||||
_, foreach = _default_to_fused_or_foreach(
|
||||
params, differentiable, use_fused=False
|
||||
)
|
||||
# Do not flip on foreach for the unsupported case where lr is a Tensor and capturable=False.
|
||||
if foreach and isinstance(lr, Tensor) and not capturable:
|
||||
foreach = False
|
||||
if fused is None:
|
||||
fused = False
|
||||
if foreach is None:
|
||||
foreach = False
|
||||
|
||||
if foreach and torch.jit.is_scripting():
|
||||
raise RuntimeError("torch.jit.script not supported with foreach optimizers")
|
||||
if fused and torch.jit.is_scripting():
|
||||
raise RuntimeError("torch.jit.script not supported with fused optimizers")
|
||||
|
||||
if fused and not torch.jit.is_scripting():
|
||||
func = _fused_adamw
|
||||
elif foreach and not torch.jit.is_scripting():
|
||||
func = _multi_tensor_adamw
|
||||
else:
|
||||
func = _single_tensor_adamw
|
||||
|
||||
func(
|
||||
adam(
|
||||
params,
|
||||
grads,
|
||||
exp_avgs,
|
||||
exp_avg_sqs,
|
||||
max_exp_avg_sqs,
|
||||
state_steps,
|
||||
foreach=foreach,
|
||||
capturable=capturable,
|
||||
differentiable=differentiable,
|
||||
fused=fused,
|
||||
grad_scale=grad_scale,
|
||||
found_inf=found_inf,
|
||||
has_complex=has_complex,
|
||||
amsgrad=amsgrad,
|
||||
beta1=beta1,
|
||||
beta2=beta2,
|
||||
|
|
@ -884,9 +170,5 @@ def adamw(
|
|||
weight_decay=weight_decay,
|
||||
eps=eps,
|
||||
maximize=maximize,
|
||||
capturable=capturable,
|
||||
differentiable=differentiable,
|
||||
grad_scale=grad_scale,
|
||||
found_inf=found_inf,
|
||||
has_complex=has_complex,
|
||||
decoupled_weight_decay=True,
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in a new issue