From 92d8965082bf77315b0d059e37786aa98d268a44 Mon Sep 17 00:00:00 2001
From: emmettbicker <88175920+EmmettBicker@users.noreply.github.com>
Date: Mon, 30 Dec 2024 01:11:55 +0000
Subject: [PATCH] Adding support for differentiable lr, weight_decay, and betas in Adam/AdamW (#143726)

Third PR in a series of PRs to broaden differentiable optimizer support w/ @janeyx99 (sorry for pinging over the holidays! I just wanted to put this one out, and I am definitely not asking for review or anything like that right now). This will probably also be my last PR before the holidays!

Note: this is a branch of #143710. I've never worked on a branch of a branch before and wasn't sure about the protocol, so I thought I'd just make the PR and wait until that one gets merged.

This adds support for differentiable lr, weight_decay, and betas to Adam and AdamW (though after the refactor of AdamW into an Adam subclass, it really only changes code in torch/optim/adam.py).

One main thing I was wondering about: Adam already has a `differentiable` flag built in, so I have code like this

```py
if differentiable and isinstance(beta2, Tensor):
    if beta2.requires_grad:
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj().mul(1 - beta2))
    else:
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
else:
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
```

which I could simplify to just

```py
if differentiable and isinstance(beta2, Tensor):
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj().mul(1 - beta2))
else:
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
```

The simpler version would be a little slower in the case where the optimizer is differentiable but beta2 doesn't need a grad, but it would also be a lot clearer, so I'm weighing speed against future code usability.

The line from the first example,

```py
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj().mul(1 - beta2))
```

also concerned me because it is considerably more expensive than `value=1 - beta2`, but I couldn't think of a better way to do it.
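As a quick sanity check (illustrative only, not part of the patch; it assumes real-valued tensors so the `conj()` can be dropped), the in-place `lerp_` form that the diff below settles on matches the usual `mul_`/`addcmul_` update numerically:

```py
import torch

torch.manual_seed(0)
grad = torch.rand(10, dtype=torch.float64)
exp_avg_sq = torch.rand(10, dtype=torch.float64)
beta2 = 0.999

# Usual update: exp_avg_sq <- exp_avg_sq * beta2 + grad^2 * (1 - beta2)
ref = exp_avg_sq.clone().mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

# lerp form: a.lerp_(b, w) computes a * (1 - w) + b * w, so with w = 1 - beta2
# it is the same update, and w may be a Tensor (unlike addcmul_'s value).
out = exp_avg_sq.clone().lerp_(grad.square(), 1 - beta2)

print(torch.allclose(ref, out))  # True
```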
Further work on #141832

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143726
Approved by: https://github.com/janeyx99
---
 test/optim/test_optim.py | 292 ++++++++++++++++++++++++++++++++++++++-
 torch/optim/adam.py      |  50 ++++++-
 2 files changed, 331 insertions(+), 11 deletions(-)

diff --git a/test/optim/test_optim.py b/test/optim/test_optim.py
index 0fd270338d4..00f5db1478c 100644
--- a/test/optim/test_optim.py
+++ b/test/optim/test_optim.py
@@ -76,12 +76,22 @@ def _multistep_backprop_diff_hyperparams_fn(
     # This copy is necessary so the update on line 78 doesn't overwrite the original kwargs values
     kwargs = kwargs.copy()
+
+    # Have to pass in beta1 and beta2 separately
+    # so they're passed in as Tensors (not a tuple) and recognized by gradcheck
+    if "beta1" in kwargs or "beta2" in kwargs:
+        # Prevent just one beta kwarg from being passed in
+        assert (
+            "beta1" in kwargs and "beta2" in kwargs
+        ), "Both betas should be defined in kwargs"
+        kwargs.update({"betas": (kwargs.pop("beta1"), kwargs.pop("beta2"))})
+
     kwargs.update(
         {k: v.clone() if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
     )
     differentiable_kwargs = [
         v for v in kwargs.values() if isinstance(v, torch.Tensor) and v.requires_grad
-    ]
+    ] + (list(kwargs["betas"]) if "betas" in kwargs else [])

     criterion = nn.MSELoss()

@@ -104,6 +114,10 @@ def _multistep_backprop_diff_hyperparams_fn(
     meta_loss = loss
     meta_loss.backward(inputs=(*differentiable_kwargs,), create_graph=True)

+    # Extra check to make sure the test properly computed a gradient for all kwargs
+    for kwarg in differentiable_kwargs:
+        assert kwarg.grad is not None
+
     return (
         (meta_loss,)
         + tuple(
@@ -111,11 +125,7 @@ def _multistep_backprop_diff_hyperparams_fn(
             v.clone()
             for v in optimizer.state[params].values()
             if isinstance(v, torch.Tensor) and v.requires_grad
         )
-        + tuple(
-            v
-            for v in kwargs.values()
-            if isinstance(v, torch.Tensor) and v.requires_grad
-        )
+        + tuple(differentiable_kwargs)
     )

@@ -404,6 +414,276 @@ class TestDifferentiableOptimizer(TestCase):
             ),
         )

+    def test_adam_differentiable_lr(self):
+        params = torch.rand(10, requires_grad=True, dtype=torch.float64)
+        grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
+        lr = torch.tensor(0.001, requires_grad=True, dtype=torch.float64)
+
+        state = {}
+        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
+        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
+        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
+        state["max_exp_avg_sq"] = torch.rand(
+            10, requires_grad=True, dtype=torch.float64
+        )
+        kwargs: dict[str, Any] = {"lr": lr, "differentiable": True}
+
+        gradcheck(
+            _multistep_backprop_diff_hyperparams_fn,
+            (
+                params,
+                grad,
+                state,
+                Adam,
+                kwargs,  # includes lr
+                *state.values(),
+                *kwargs.values(),
+            ),
+        )
+
+    def test_adam_differentiable_weight_decay(self):
+        params = torch.rand(10, requires_grad=True, dtype=torch.float64)
+        grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
+        weight_decay = torch.tensor(0.999, requires_grad=True, dtype=torch.float64)
+
+        state = {}
+        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
+        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
+        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
+        state["max_exp_avg_sq"] = torch.rand(
+            10, requires_grad=True, dtype=torch.float64
+        )
+        kwargs: dict[str, Any] = {"weight_decay": weight_decay, "differentiable": True}
+
+        gradcheck(
+            _multistep_backprop_diff_hyperparams_fn,
+            (
+                params,
+                grad,
+                state,
+                Adam,
+                kwargs,  # includes weight_decay
+                *state.values(),
+                *kwargs.values(),
+            ),
+        )
+
+    def test_adam_differentiable_betas(self):
+        params = torch.rand(10, requires_grad=True, dtype=torch.float64)
+        grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
+
+        lr = torch.tensor([0.001], requires_grad=True, dtype=torch.float64)
+        betas = (
+            torch.tensor(0.9, requires_grad=True, dtype=torch.float64),
+            torch.tensor(0.999, requires_grad=True, dtype=torch.float64),
+        )
+        state = {}
+        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
+        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
+        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
+        state["max_exp_avg_sq"] = torch.rand(
+            10, requires_grad=True, dtype=torch.float64
+        )
+
+        # Have to pass in beta1 and beta2 separately
+        # so they're passed in as Tensors (not a tuple) and recognized by gradcheck.
+        # In the test, this is called: kwargs.update({betas: (beta1, beta2)})
+        kwargs: dict[str, Any] = {
+            "beta1": betas[0],
+            "beta2": betas[1],
+            "lr": lr,
+            "differentiable": True,
+        }
+
+        gradcheck(
+            _multistep_backprop_diff_hyperparams_fn,
+            (
+                params,
+                grad,
+                state,
+                Adam,
+                kwargs,  # includes betas
+                *state.values(),
+                *kwargs.values(),
+            ),
+        )
+
+    def test_adam_differentiable_all_hyperparams(self):
+        params = torch.rand(10, requires_grad=True, dtype=torch.float64)
+        grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
+
+        lr = torch.tensor(0.001, requires_grad=True, dtype=torch.float64)
+        weight_decay = torch.tensor(0.999, requires_grad=True, dtype=torch.float64)
+        betas = (
+            torch.tensor(0.9, requires_grad=True, dtype=torch.float64),
+            torch.tensor(0.999, requires_grad=True, dtype=torch.float64),
+        )
+        state = {}
+        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
+        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
+        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
+        state["max_exp_avg_sq"] = torch.rand(
+            10, requires_grad=True, dtype=torch.float64
+        )
+
+        # Have to pass in beta1 and beta2 separately
+        # so they're passed in as Tensors (not a tuple) and recognized by gradcheck.
+        # In the test, this is called: kwargs.update({betas: (beta1, beta2)})
+        kwargs: dict[str, Any] = {
+            "lr": lr,
+            "weight_decay": weight_decay,
+            "beta1": betas[0],
+            "beta2": betas[1],
+            "differentiable": True,
+        }
+
+        gradcheck(
+            _multistep_backprop_diff_hyperparams_fn,
+            (
+                params,
+                grad,
+                state,
+                Adam,
+                kwargs,  # includes betas
+                *state.values(),
+                *kwargs.values(),
+            ),
+        )
+
+    def test_adamw_differentiable_lr(self):
+        params = torch.rand(10, requires_grad=True, dtype=torch.float64)
+        grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
+        lr = torch.tensor(0.001, requires_grad=True, dtype=torch.float64)
+
+        state = {}
+        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
+        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
+        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
+        state["max_exp_avg_sq"] = torch.rand(
+            10, requires_grad=True, dtype=torch.float64
+        )
+        kwargs: dict[str, Any] = {"lr": lr, "differentiable": True}
+
+        gradcheck(
+            _multistep_backprop_diff_hyperparams_fn,
+            (
+                params,
+                grad,
+                state,
+                AdamW,
+                kwargs,  # includes lr
+                *state.values(),
+                *kwargs.values(),
+            ),
+        )
+
+    def test_adamw_differentiable_weight_decay(self):
+        params = torch.rand(10, requires_grad=True, dtype=torch.float64)
+        grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
+        weight_decay = torch.tensor(0.999, requires_grad=True, dtype=torch.float64)
+
+        state = {}
+        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
+        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
+        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
+        state["max_exp_avg_sq"] = torch.rand(
+            10, requires_grad=True, dtype=torch.float64
+        )
+        kwargs: dict[str, Any] = {"weight_decay": weight_decay, "differentiable": True}
+
+        gradcheck(
+            _multistep_backprop_diff_hyperparams_fn,
+            (
+                params,
+                grad,
+                state,
+                AdamW,
+                kwargs,  # includes weight_decay
+                *state.values(),
+                *kwargs.values(),
+            ),
+        )
+
+    def test_adamw_differentiable_betas(self):
+        params = torch.rand(10, requires_grad=True, dtype=torch.float64)
+        grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
+
+        betas = (
+            torch.tensor(0.9, requires_grad=True, dtype=torch.float64),
+            torch.tensor(0.999, requires_grad=True, dtype=torch.float64),
+        )
+        state = {}
+        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
+        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
+        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
+        state["max_exp_avg_sq"] = torch.rand(
+            10, requires_grad=True, dtype=torch.float64
+        )
+
+        # Have to pass in beta1 and beta2 separately
+        # so they're passed in as Tensors (not a tuple) and recognized by gradcheck.
+        # In the test, this is called: kwargs.update({betas: (beta1, beta2)})
+        kwargs: dict[str, Any] = {
+            "beta1": betas[0],
+            "beta2": betas[1],
+            "differentiable": True,
+        }
+
+        gradcheck(
+            _multistep_backprop_diff_hyperparams_fn,
+            (
+                params,
+                grad,
+                state,
+                AdamW,
+                kwargs,  # includes betas
+                *state.values(),
+                *kwargs.values(),
+            ),
+        )
+
+    def test_adamw_differentiable_all_hyperparams(self):
+        params = torch.rand(10, requires_grad=True, dtype=torch.float64)
+        grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
+
+        lr = torch.tensor(0.001, requires_grad=True, dtype=torch.float64)
+        weight_decay = torch.tensor(0.999, requires_grad=True, dtype=torch.float64)
+        betas = (
+            torch.tensor(0.9, requires_grad=True, dtype=torch.float64),
+            torch.tensor(0.999, requires_grad=True, dtype=torch.float64),
+        )
+        state = {}
+        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
+        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
+        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
+        state["max_exp_avg_sq"] = torch.rand(
+            10, requires_grad=True, dtype=torch.float64
+        )
+
+        # Have to pass in beta1 and beta2 separately
+        # so they're passed in as Tensors (not a tuple) and recognized by gradcheck.
+        # In the test, this is called: kwargs.update({betas: (beta1, beta2)})
+        kwargs: dict[str, Any] = {
+            "lr": lr,
+            "weight_decay": weight_decay,
+            "beta1": betas[0],
+            "beta2": betas[1],
+            "differentiable": True,
+        }
+
+        gradcheck(
+            _multistep_backprop_diff_hyperparams_fn,
+            (
+                params,
+                grad,
+                state,
+                AdamW,
+                kwargs,  # includes betas
+                *state.values(),
+                *kwargs.values(),
+            ),
+        )
+
     def test_differentiable_lr(self):
         params = torch.rand(10, requires_grad=True, dtype=torch.float64)
         grad = torch.rand_like(params, requires_grad=True, dtype=torch.float64)
diff --git a/torch/optim/adam.py b/torch/optim/adam.py
index e3a628cb576..536c0c271c5 100644
--- a/torch/optim/adam.py
+++ b/torch/optim/adam.py
@@ -402,7 +402,14 @@ def _single_tensor_adam(
                 # Perform stepweight decay
                 param.mul_(1 - lr * weight_decay)
             else:
-                grad = grad.add(param, alpha=weight_decay)
+                # Nested if is necessary to bypass jitscript rules
+                if differentiable and isinstance(weight_decay, Tensor):
+                    if weight_decay.requires_grad:
+                        grad = grad.addcmul_(param.clone(), weight_decay)
+                    else:
+                        grad = grad.add(param, alpha=weight_decay)
+                else:
+                    grad = grad.add(param, alpha=weight_decay)

         if torch.is_complex(param):
             grad = torch.view_as_real(grad)
@@ -429,13 +436,43 @@
         # Decay the first and second moment running average coefficient
         exp_avg.lerp_(grad, 1 - device_beta1)
-        exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
+        # Nested if is necessary to bypass jitscript rules
+        if differentiable and isinstance(beta2, Tensor):
+            if beta2.requires_grad:
+                # Using lerp to only use 2 operations bc addcmul's value cannot be a tensor
+                # Showing equivalence of differentiable path and nondifferentiable path
+                # expavg * b2 + grad^2 * (1-b2)
+                # add expavg * (1-b2) - expavg * (1-b2) = 0
+                # expavg * b2 + expavg * (1-b2) - expavg * (1-b2) + grad^2 * (1-b2)
+                # expavg - expavg * (1-b2) + grad^2 * (1-b2)
+                # expavg + (grad^2 - expavg) * (1-b2)
+                # expavg.lerp(grad^2, 1-beta2)
+                exp_avg_sq.lerp_(torch.square(grad), weight=1 - beta2)
+            else:
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+        else:
+            exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

         if capturable or differentiable:
             step = step_t

-            bias_correction1 = 1 - beta1**step
-            bias_correction2 = 1 - beta2**step
+            # Nested if is necessary to bypass jitscript rules
+            if differentiable and isinstance(beta1, Tensor):
+                if beta1.requires_grad:
+                    bias_correction1 = 1 - beta1 ** step.clone()
+                else:
+                    bias_correction1 = 1 - beta1**step
+            else:
+                bias_correction1 = 1 - beta1**step
+
+            # Nested if is necessary to bypass jitscript rules
+            if differentiable and isinstance(beta2, Tensor):
+                if beta2.requires_grad:
+                    bias_correction2 = 1 - beta2 ** step.clone()
+                else:
+                    bias_correction2 = 1 - beta2**step
+            else:
+                bias_correction2 = 1 - beta2**step

             step_size = lr / bias_correction1
             step_size_neg = step_size.neg()
@@ -462,7 +499,10 @@
                     exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg)
                 ).add_(eps / step_size_neg)

-            param.addcdiv_(exp_avg, denom)
+            if differentiable:
+                param.addcdiv_(exp_avg.clone(), denom)
+            else:
+                param.addcdiv_(exp_avg, denom)
         else:
             step = _get_value(step_t)
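For context, here is a rough usage sketch of what this enables (not part of the patch; it mirrors the pattern in `_multistep_backprop_diff_hyperparams_fn` above, and the tensor names are just illustrative): backprop through a single Adam step to get a gradient with respect to the learning rate.

```py
import torch
from torch.optim import Adam

meta_params = torch.rand(10, dtype=torch.float64, requires_grad=True)
lr = torch.tensor(1e-3, dtype=torch.float64, requires_grad=True)

# Clone so the optimizer updates a non-leaf tensor and the step stays in the autograd graph.
params = meta_params.clone()
params.grad = torch.rand_like(params)  # stand-in gradient for the inner step

opt = Adam([params], lr=lr, differentiable=True)
opt.step()

meta_loss = params.pow(2).sum()
meta_loss.backward(inputs=(lr,))
print(lr.grad)  # d(meta_loss) / d(lr)
```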