pytorch/test/dynamo/test_optimizers.py
Jane Xu b5ba80828f [optim] Rectify capturable testing and fix bugs! (#118326)
This PR fixes several bugs, listed in priority:
1. `load_state_dict` with a nontensor step was incorrect for capturable and fused implementations since we don't create the tensors on the right device in `__setstate__`. This has been fixed.
2. The most recently added capturable implementations forgot the check that all tensors should be on CUDA for eager. We've now added those checks.
3. The most recent change in Adamax only adds capturable for foreach but will silently be incorrect for forloop/single-tensor. I've added an explicit error for that case and modified the tests, with many, many, many skips. Honestly, this PR has only further cemented my preference that we should implement the single-tensor and multi-tensor capturable paths together in the future. @mlazos
4. The conditional for adding cuda-supported configs for the optimizer infos was incorrect! So we hadn't been testing capturable! This also stands rectified and was the trigger for this PR in the first place.
5. In a similar way, the conditional for `_get_optim_inputs_including_global_cliquey_kwargs` was incorrect sometimes as well. This has also been corrected.

The following is not a bug fix, just a simplification that avoids having to handle Nones: `optim_input_funcs` must now take in a `device`, which can be a string or a torch.device.

Details for posterity:
4. Running the test_foreach_matches_forloop test and printing the configs it runs shows that capturable is now included, which is correct.
```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (5d50138f)]$ python test/test_optim.py -k test_foreach_matches_forloop_AdamW_cuda
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
params=None, kwargs={}, desc=default
params=None, kwargs={'lr': 0.01}, desc=non-default lr
params=None, kwargs={'weight_decay': 0.1}, desc=nonzero weight_decay
params=None, kwargs={'weight_decay': 0.1, 'maximize': True}, desc=maximize
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True}, desc=amsgrad
params=None, kwargs={'capturable': True}, desc=capturable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True}, desc=capturable, amsgrad
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True}, desc=Tensor lr with capturable and amsgrad
.
----------------------------------------------------------------------
Ran 1 test in 19.229s

OK
```
5. Running the test_optimizer_can_be_printed test (which calls `_get_optim_inputs_including_global_cliquey_kwargs`) and printing what gets run is also now correct.
```
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
params=None, kwargs={'differentiable': False}, desc=default
params=None, kwargs={'differentiable': True}, desc=default & differentiable
params=None, kwargs={'lr': 0.01, 'differentiable': False}, desc=non-default lr
params=None, kwargs={'lr': 0.01, 'differentiable': True}, desc=non-default lr & differentiable
params=None, kwargs={'weight_decay': 0.1, 'differentiable': False}, desc=nonzero weight_decay
params=None, kwargs={'weight_decay': 0.1, 'differentiable': True}, desc=nonzero weight_decay & differentiable
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'differentiable': False}, desc=maximize
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'differentiable': True}, desc=maximize & differentiable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'differentiable': False}, desc=amsgrad
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'differentiable': True}, desc=amsgrad & differentiable
.params=None, kwargs={'foreach': False, 'differentiable': False, 'fused': False}, desc=default
params=None, kwargs={'foreach': True, 'differentiable': False, 'fused': False}, desc=default & foreach
params=None, kwargs={'foreach': False, 'differentiable': True, 'fused': False}, desc=default & differentiable
params=None, kwargs={'foreach': False, 'differentiable': False, 'fused': True}, desc=default & fused
params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': False, 'fused': False}, desc=non-default lr
params=None, kwargs={'lr': 0.01, 'foreach': True, 'differentiable': False, 'fused': False}, desc=non-default lr & foreach
params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': True, 'fused': False}, desc=non-default lr & differentiable
params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': False, 'fused': True}, desc=non-default lr & fused
params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': False, 'fused': False}, desc=nonzero weight_decay
params=None, kwargs={'weight_decay': 0.1, 'foreach': True, 'differentiable': False, 'fused': False}, desc=nonzero weight_decay & foreach
params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': True, 'fused': False}, desc=nonzero weight_decay & differentiable
params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': False, 'fused': True}, desc=nonzero weight_decay & fused
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=maximize
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=maximize & foreach
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=maximize & differentiable
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=maximize & fused
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=amsgrad
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=amsgrad & foreach
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=amsgrad & differentiable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=amsgrad & fused
params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=capturable
params=None, kwargs={'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=capturable & foreach
params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=capturable & differentiable
params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=capturable & fused
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=capturable, amsgrad
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=capturable, amsgrad & foreach
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=capturable, amsgrad & differentiable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=capturable, amsgrad & fused
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=Tensor lr with capturable and amsgrad
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=Tensor lr with capturable and amsgrad & foreach
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=Tensor lr with capturable and amsgrad & differentiable
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=Tensor lr with capturable and amsgrad & fused
.
----------------------------------------------------------------------
Ran 2 tests in 11.112s

OK
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118326
Approved by: https://github.com/mlazos
2024-02-02 19:13:00 +00:00

199 lines
6.1 KiB
Python

"""
PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes
with test_adam in OptimizerTests)
"""
import functools
# Owner(s): ["module: dynamo"]
import inspect
import torch
import torch._dynamo
import torch._dynamo.test_case
import torch._dynamo.testing
from torch.nn import Parameter
# Shared module-level fixtures: a 10x10 batch of ones and a two-layer stack of
# Linear(10, 10) modules. A single forward/backward pass runs at import time so
# that every parameter of ``model`` already has a ``.grad`` populated.
input = torch.ones(10, 10)
model = torch.nn.Sequential(
    torch.nn.Linear(10, 10),
    torch.nn.Linear(10, 10),
)
model(input).sum().backward()
def get_optimizer_step(opt, closure=None):
    """Return a zero-argument callable that runs ``opt``'s unwrapped ``step``.

    Args:
        opt: the optimizer whose ``step`` is unwrapped and invoked.
        closure: optional callable forwarded to the step function. When it is
            ``None`` the step function is called with the optimizer only.

    Returns:
        A function taking no arguments that calls the unwrapped step function
        and returns ``None``.
    """
    # Run the patcher first so that ``step`` has the expected wrapper structure.
    torch._dynamo.eval_frame.TorchPatcher.patch()

    # Peel off TWO wrapper layers to avoid a deliberate graph break caused by a
    # limitation of functionalization/no_grad detection -- see the
    # [Note on graph break] in optimizer.py.
    # This skips the _use_grad_if_differentiable wrapper, which is fine for now
    # because dynamo does not support differentiable optimizers anyway.
    # It _also_ skips the outer profiling hook wrapper, which may NOT be fine.
    unwrapped_step = opt.step.__wrapped__.__wrapped__

    if closure is None:

        def fn():
            unwrapped_step(opt)

        return fn

    def fn():
        unwrapped_step(opt, closure)

    return fn
def make_test(optim_cls, closure=None, **kwargs):
    """Build a test method that compiles and runs one step of ``optim_cls``.

    The optimizer is constructed once, when ``make_test`` is called, over the
    parameters of the module-level ``model``; the returned test method reuses
    that instance.

    Args:
        optim_cls: optimizer class to instantiate.
        closure: optional closure forwarded to ``get_optimizer_step``.
        **kwargs: extra keyword arguments for the optimizer constructor.

    Returns:
        A function ``test_fn(self)`` suitable for use as a TestCase method.
    """
    # Remove this conditional when #118230 is fixed
    if optim_cls.__name__ == "Adamax":
        kwargs = {**kwargs, "foreach": True}

    optimizer = optim_cls(model.parameters(), **kwargs)

    def test_fn(self):
        step = get_optimizer_step(optimizer, closure=closure)

        with torch.set_grad_enabled(False):
            # fullgraph=True makes any graph break in the step an error.
            compiled_step = torch.compile(step, backend="eager", fullgraph=True)
            compiled_step()

    return test_fn
class OptimizerTests(torch._dynamo.test_case.TestCase):
    """Per-optimizer tests; each test method is produced by ``make_test``."""

    test_sgd = make_test(torch.optim.SGD, lr=0.01)

    # LBFGS has data-dependent control flow and internally iterates, calling
    # the closure.
    # TODO mlazos: re-enable once we have latest pytorch with FakeTensor fix #497
    # test_lbfgs = make_test(
    #    torch.optim.LBFGS, exp_frame_cnt=3, closure=lambda: model(input).sum()
    # )

    # RAdam has data-dependent control flow for rectification (needs symint),
    # which breaks the graph; furthermore, the break is inside a for loop, so
    # we bail on the frame entirely. This is basically an xfail; if the frame
    # count goes up you done good.
    # test_radam = unittest.skipIf(IS_FBCODE, "TypeError: _use_grad() missing")(
    #    make_test(torch.optim.RAdam, exp_graph_count=0)
    # )
# Optimizers that do not get an auto-generated test. SparseAdam is left out
# because other areas of the stack don't support it yet; the remaining entries
# are either covered explicitly on OptimizerTests or unsupported.
exclude = {
    "SGD",  # Handled above
    "Optimizer",
    "SparseAdam",  # Unsupported
    "LBFGS",  # Unsupported
    "RAdam",  # Has data dependent control for rectification (needs symint)
}

# Every Optimizer subclass exposed in the torch.optim namespace whose name is
# not listed in ``exclude``.
optimizers = [
    candidate
    for candidate in vars(torch.optim).values()
    if inspect.isclass(candidate)
    and issubclass(candidate, torch.optim.Optimizer)
    and candidate.__name__ not in exclude
]

# Attach one generated test per remaining optimizer, named test_<lowercased name>.
for opt in optimizers:
    setattr(OptimizerTests, f"test_{opt.__name__.lower()}", make_test(opt))
class MyOptimizer(torch.optim.Optimizer):
    """Minimal optimizer whose ``step`` branches on ``_init_group``'s result."""

    def __init__(self, params):
        # No hyperparameters: the defaults dict is empty.
        super().__init__(params, {})

    def _init_group(self, params, group):
        """Append the group's parameters to ``params``.

        Returns:
            ``True`` if any parameter in the group is complex, else ``False``.
        """
        found_complex = False
        for param in group["params"]:
            params.append(param)
            if param.is_complex():
                found_complex = True
        return found_complex

    def step(self):
        """Shift the first parameter of every group in place.

        The first parameter is decremented by 1 when the group contains a
        complex parameter and incremented by 1 otherwise.
        """
        for group in self.param_groups:
            collected = []
            has_complex = self._init_group(collected, group)
            if not has_complex:
                collected[0] += 1
            else:
                collected[0] -= 1
class End2EndTests(torch._dynamo.test_case.TestCase):
    """End-to-end tests that exercise optimizers together with dynamo."""

    # https://github.com/pytorch/torchdynamo/issues/1604
    def test_optimizing_over_tensor_with_requires_grad(self):
        """Compile a training iteration that optimizes an input tensor.

        The optimized tensor is one of the model inputs (``requires_grad=True``)
        rather than a module parameter. The iteration runs twice and the
        compile counter is expected to report two frames.
        """

        class Net(torch.nn.Module):
            def forward(self, x, y):
                z = torch.bmm(x, y)
                z = torch.flatten(z, 1)
                return z

        def training_iter_fn(batch, model, optimizer):
            optimizer.zero_grad()
            out = model(**batch)
            target = torch.tensor([0, 7])
            loss = torch.nn.CrossEntropyLoss()(out, target)
            loss.backward()
            optimizer.step()
            return loss

        net = Net()
        input1 = torch.randn(2, 1, 4)
        input2 = torch.randn(2, 4, 8, requires_grad=True)
        optimizer = torch.optim.Adam([input2], lr=0.1)

        cnts = torch._dynamo.testing.CompileCounter()
        opt_training_iter_fn = torch._dynamo.optimize(cnts)(training_iter_fn)
        batch = {"x": input1, "y": input2}
        for _ in range(2):
            opt_training_iter_fn(batch, net, optimizer)
        self.assertEqual(cnts.frame_count, 2)

    def test_state_dict(self):
        """Create an optimizer inside a compiled function, then step it eagerly.

        The compiled helper returns both the optimizer and a closure bound to
        it; the optimizer is then stepped with that closure outside the
        compiled region.
        """

        @torch.compile(backend="eager")
        def _test_state_dict(weight, bias, input):
            def fn_base(optimizer, weight, bias):
                optimizer.zero_grad()
                i = input
                loss = (weight.mv(i) + bias).pow(2).sum()
                loss.backward()
                return loss

            optimizer = torch.optim.Adagrad([weight, bias])
            fn = functools.partial(fn_base, optimizer, weight, bias)
            return optimizer, fn

        optimizer, fn = _test_state_dict(
            Parameter(torch.randn(10, 5)),
            Parameter(torch.randn(10)),
            torch.randn(5, requires_grad=True),
        )
        optimizer.step(fn)

    def test_init_group(self):
        """Check that a compiled ``MyOptimizer.step`` matches the eager one.

        Runs once for a real dtype and once for a complex dtype so that both
        branches selected by ``_init_group``'s return value are covered.
        """
        for dtype in [torch.float32, torch.cfloat]:
            # subTest reports which dtype failed and still runs the other one.
            with self.subTest(dtype=dtype):
                tensor = torch.randn(5, 5, dtype=dtype)
                params = Parameter(tensor.detach().clone(), requires_grad=False)
                opt_params = Parameter(tensor.detach().clone(), requires_grad=False)

                # Eager reference step.
                optim = MyOptimizer([params])
                optim.step()

                # Compiled step on an identical copy; fullgraph=True turns any
                # graph break into an error.
                opt_optim = MyOptimizer([opt_params])
                opt_step = torch.compile(backend="eager", fullgraph=True)(
                    opt_optim.step
                )
                opt_step()

                self.assertEqual(params, opt_params)
if __name__ == "__main__":
    # Hand control to dynamo's test runner when this file is run as a script.
    torch._dynamo.test_case.run_tests()