This PR fixes several bugs, listed in priority order:

1. `load_state_dict` with a non-tensor step was incorrect for the capturable and fused implementations, because `__setstate__` did not create the step tensors on the right device. This has been fixed. (See the hedged sketch after this description.)
2. The most recently added capturable implementations were missing the eager-mode check that all tensors must be on CUDA. Those checks have now been added. (A sketch of that kind of check also follows the description.)
3. The most recent change in Adamax only adds capturable support for the foreach path, so the forloop/single-tensor path would be silently incorrect. It now errors instead, and the tests have been modified with many, many skips to match. Honestly, this PR has only further cemented my preference that we land the single-tensor and multi-tensor capturable implementations together in the future. @mlazos
4. The conditional for adding CUDA-supported configs to the optimizer infos was incorrect, so we had not actually been testing capturable! This has been rectified and was the trigger for this PR in the first place.
5. Similarly, the conditional in `_get_optim_inputs_including_global_cliquey_kwargs` was sometimes incorrect as well. This has also been corrected.

The following is not a bug, but simplifies life by removing the need to handle `None`s: `optim_input_funcs` must now mandatorily take a `device`, which can be a string or a `torch.device`. (A sketch of the new signature appears after the description as well.)

Details for posterity:

4. Running the `test_foreach_matches_forloop` test and printing the configs that get run shows capturable being included, which is correct.

```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (5d50138f)]$ python test/test_optim.py -k test_foreach_matches_forloop_AdamW_cuda
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
params=None, kwargs={}, desc=default
params=None, kwargs={'lr': 0.01}, desc=non-default lr
params=None, kwargs={'weight_decay': 0.1}, desc=nonzero weight_decay
params=None, kwargs={'weight_decay': 0.1, 'maximize': True}, desc=maximize
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True}, desc=amsgrad
params=None, kwargs={'capturable': True}, desc=capturable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True}, desc=capturable, amsgrad
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True}, desc=Tensor lr with capturable and amsgrad
.
----------------------------------------------------------------------
Ran 1 test in 19.229s

OK
```

5. Running the `test_optimizer_can_be_printed` test (which calls `_get_optim_inputs_including_global_cliquey_kwargs`) and printing what gets run is also now correct.
```
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
params=None, kwargs={'differentiable': False}, desc=default
params=None, kwargs={'differentiable': True}, desc=default & differentiable
params=None, kwargs={'lr': 0.01, 'differentiable': False}, desc=non-default lr
params=None, kwargs={'lr': 0.01, 'differentiable': True}, desc=non-default lr & differentiable
params=None, kwargs={'weight_decay': 0.1, 'differentiable': False}, desc=nonzero weight_decay
params=None, kwargs={'weight_decay': 0.1, 'differentiable': True}, desc=nonzero weight_decay & differentiable
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'differentiable': False}, desc=maximize
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'differentiable': True}, desc=maximize & differentiable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'differentiable': False}, desc=amsgrad
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'differentiable': True}, desc=amsgrad & differentiable
.params=None, kwargs={'foreach': False, 'differentiable': False, 'fused': False}, desc=default
params=None, kwargs={'foreach': True, 'differentiable': False, 'fused': False}, desc=default & foreach
params=None, kwargs={'foreach': False, 'differentiable': True, 'fused': False}, desc=default & differentiable
params=None, kwargs={'foreach': False, 'differentiable': False, 'fused': True}, desc=default & fused
params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': False, 'fused': False}, desc=non-default lr
params=None, kwargs={'lr': 0.01, 'foreach': True, 'differentiable': False, 'fused': False}, desc=non-default lr & foreach
params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': True, 'fused': False}, desc=non-default lr & differentiable
params=None, kwargs={'lr': 0.01, 'foreach': False, 'differentiable': False, 'fused': True}, desc=non-default lr & fused
params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': False, 'fused': False}, desc=nonzero weight_decay
params=None, kwargs={'weight_decay': 0.1, 'foreach': True, 'differentiable': False, 'fused': False}, desc=nonzero weight_decay & foreach
params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': True, 'fused': False}, desc=nonzero weight_decay & differentiable
params=None, kwargs={'weight_decay': 0.1, 'foreach': False, 'differentiable': False, 'fused': True}, desc=nonzero weight_decay & fused
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=maximize
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=maximize & foreach
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=maximize & differentiable
params=None, kwargs={'weight_decay': 0.1, 'maximize': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=maximize & fused
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=amsgrad
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=amsgrad & foreach
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=amsgrad & differentiable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=amsgrad & fused
params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=capturable
params=None, kwargs={'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=capturable & foreach
params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=capturable & differentiable
params=None, kwargs={'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=capturable & fused
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=capturable, amsgrad
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=capturable, amsgrad & foreach
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=capturable, amsgrad & differentiable
params=None, kwargs={'weight_decay': 0.1, 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=capturable, amsgrad & fused
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': False}, desc=Tensor lr with capturable and amsgrad
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': True, 'differentiable': False, 'fused': False}, desc=Tensor lr with capturable and amsgrad & foreach
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': True, 'fused': False}, desc=Tensor lr with capturable and amsgrad & differentiable
params=None, kwargs={'lr': tensor(0.0010), 'amsgrad': True, 'capturable': True, 'foreach': False, 'differentiable': False, 'fused': True}, desc=Tensor lr with capturable and amsgrad & fused
.
----------------------------------------------------------------------
Ran 2 tests in 11.112s

OK
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118326
Approved by: https://github.com/mlazos
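To make item 1 concrete, here is a minimal, hedged sketch (not the PR's test code) of the scenario it fixes, using only the public optimizer API and assuming a CUDA device is available:

```python
import torch

# Build a capturable Adam state on CUDA (capturable requires CUDA in eager mode).
param = torch.nn.Parameter(torch.randn(2, 2, device="cuda"))
opt = torch.optim.Adam([param], lr=1e-3, capturable=True)
param.grad = torch.randn_like(param)
opt.step()

# Simulate an older checkpoint in which `step` was saved as a plain Python number
# rather than a tensor.
sd = opt.state_dict()
sd["state"][0]["step"] = float(sd["state"][0]["step"])

# Reloading should recreate `step` as a tensor on the parameter's device; the bug
# was that `__setstate__` did not do this for the capturable/fused paths.
opt2 = torch.optim.Adam([param], lr=1e-3, capturable=True)
opt2.load_state_dict(sd)
assert opt2.state[param]["step"].device.type == "cuda"
opt2.step()
```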
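Item 2 refers to an eager-mode guard inside the optimizer implementations. The helper below is only an illustrative sketch of that kind of check; the name `_check_capturable_tensors_on_cuda` and the error message are assumptions, not the PR's actual code:

```python
import torch


def _check_capturable_tensors_on_cuda(optimizer: torch.optim.Optimizer) -> None:
    # The capturable (and fused) eager kernels assume every parameter and its
    # `step` state tensor live on a CUDA device; otherwise results can be
    # silently wrong, so we raise instead.
    for group in optimizer.param_groups:
        if not group.get("capturable", False):
            continue
        for p in group["params"]:
            step = optimizer.state.get(p, {}).get("step")
            if not p.is_cuda or (torch.is_tensor(step) and not step.is_cuda):
                raise RuntimeError(
                    "capturable=True requires params and state steps to be CUDA tensors"
                )
```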
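Item 3 describes erroring out rather than silently running an unsupported path. Roughly like the fragment below; the exact location in Adamax and the message are assumptions:

```python
# At dispatch time for Adamax, something along these lines:
if capturable and not foreach:
    raise RuntimeError(
        "capturable is not supported for the single-tensor Adamax implementation; "
        "pass foreach=True to use the capturable foreach path"
    )
```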
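Finally, the new `optim_input_funcs` convention from the last paragraph might look like the sketch below. The `OptimizerInput` fields match what is printed above (`params`, `kwargs`, `desc`); the function name and the exact import path are illustrative assumptions:

```python
import torch
from torch.testing._internal.common_optimizers import OptimizerInput  # assumed location


def example_optim_inputs_func(device):
    # `device` is now mandatory and may be a plain string or a torch.device;
    # the func no longer has to handle a None device.
    on_cuda = torch.device(device).type == "cuda"

    inputs = [
        OptimizerInput(params=None, kwargs={}, desc="default"),
        OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
    ]
    if on_cuda:
        # capturable configs are only meaningful on CUDA-capable devices -- the
        # very conditional that item 4 says had been wrong before this PR.
        inputs.append(
            OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable")
        )
    return inputs
```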
"""
|
|
PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes
|
|
with test_adam in OptimizerTests)
|
|
"""
|
|
import functools
|
|
|
|
# Owner(s): ["module: dynamo"]
|
|
|
|
import inspect
|
|
|
|
import torch
|
|
|
|
import torch._dynamo
|
|
import torch._dynamo.test_case
|
|
import torch._dynamo.testing
|
|
from torch.nn import Parameter
|
|
|
|
# Build a tiny model and populate .grad once at import time so every optimizer
# constructed below has gradients available when step() is called.
input = torch.ones([10, 10])
model = torch.nn.Sequential(*[torch.nn.Linear(10, 10) for _ in range(2)])
model(input).sum().backward()

def get_optimizer_step(opt, closure=None):
    # run the patcher so that step has the expected structure
    torch._dynamo.eval_frame.TorchPatcher.patch()

    # unwrap step TWICE to avoid a deliberate graph break due to a limitation of
    # functionalization/no_grad detection--see the [Note on graph break] in optimizer.py
    # This ignores the _use_grad_if_differentiable wrapper, which is fine for now as
    # dynamo does not support differentiable optimizers anyway.
    # This _also_ ignores the outer profiling hook wrapper, which may NOT be fine.
    step_fn = opt.step.__wrapped__.__wrapped__
    if closure is not None:

        def fn():
            step_fn(opt, closure)

    else:

        def fn():
            step_fn(opt)

    return fn

def make_test(optim_cls, closure=None, **kwargs):
    # Remove this conditional when #118230 is fixed
    if optim_cls.__name__ == "Adamax":
        kwargs["foreach"] = True

    opt = optim_cls(model.parameters(), **kwargs)

    def test_fn(self):
        nonlocal opt

        fn = get_optimizer_step(opt, closure=closure)

        with torch.set_grad_enabled(False):
            torch.compile(fn, backend="eager", fullgraph=True)()

    return test_fn

class OptimizerTests(torch._dynamo.test_case.TestCase):
    test_sgd = make_test(torch.optim.SGD, lr=0.01)
    # LBFGS has data-dependent control and internally iterates,
    # calling the closure
    # TODO mlazos: re-enable once we have latest pytorch with FakeTensor fix #497
    # test_lbfgs = make_test(
    #     torch.optim.LBFGS, exp_frame_cnt=3, closure=lambda: model(input).sum()
    # )

    # Has data dependent control for rectification (needs symint)
    # RAdam has data-dependent control which breaks the graph;
    # furthermore, the break is inside a for loop, so we bail on the frame
    # entirely. This is basically an xfail; if the frame count goes up
    # you done good
    # test_radam = unittest.skipIf(IS_FBCODE, "TypeError: _use_grad() missing")(
    #     make_test(torch.optim.RAdam, exp_graph_count=0)
    # )

# exclude SparseAdam because other areas of the stack don't support it yet
# the others are handled specially above
exclude = {
    "SGD",  # Handled above
    "Optimizer",
    "SparseAdam",  # Unsupported
    "LBFGS",  # Unsupported
    "RAdam",  # Has data dependent control for rectification (needs symint)
}

optimizers = [
    opt
    for opt in torch.optim.__dict__.values()
    if inspect.isclass(opt)
    and issubclass(opt, torch.optim.Optimizer)
    and opt.__name__ not in exclude
]


for opt in optimizers:
    setattr(OptimizerTests, "test_" + opt.__name__.lower(), make_test(opt))

class MyOptimizer(torch.optim.Optimizer):
    # Minimal optimizer used by End2EndTests.test_init_group: _init_group collects
    # the params and reports whether any of them is complex, and step() nudges the
    # first param so eager and compiled runs can be compared.
    def __init__(self, params):
        super().__init__(params, {})

    def _init_group(self, params, group):
        any_complex = False
        for p in group["params"]:
            params.append(p)
            any_complex |= p.is_complex()
        return any_complex

    def step(self):
        for group in self.param_groups:
            params = []
            any_complex = self._init_group(params, group)
            if any_complex:
                params[0] -= 1
            else:
                params[0] += 1

class End2EndTests(torch._dynamo.test_case.TestCase):
    # https://github.com/pytorch/torchdynamo/issues/1604
    def test_optimizing_over_tensor_with_requires_grad(self):
        class Net(torch.nn.Module):
            def forward(self, x, y):
                z = torch.bmm(x, y)
                z = torch.flatten(z, 1)
                return z

        def training_iter_fn(batch, model, optimizer):
            optimizer.zero_grad()
            out = model(**batch)
            target = torch.tensor([0, 7])
            loss = torch.nn.CrossEntropyLoss()(out, target)
            loss.backward()
            optimizer.step()
            return loss

        net = Net()
        input1 = torch.randn(2, 1, 4)
        input2 = torch.randn(2, 4, 8, requires_grad=True)
        optimizer = torch.optim.Adam([input2], lr=0.1)

        cnts = torch._dynamo.testing.CompileCounter()
        opt_training_iter_fn = torch._dynamo.optimize(cnts)(training_iter_fn)
        batch = {"x": input1, "y": input2}
        for _ in range(2):
            opt_training_iter_fn(batch, net, optimizer)
        self.assertEqual(cnts.frame_count, 2)

    def test_state_dict(self):
        @torch.compile(backend="eager")
        def _test_state_dict(weight, bias, input):
            def fn_base(optimizer, weight, bias):
                optimizer.zero_grad()
                i = input
                loss = (weight.mv(i) + bias).pow(2).sum()
                loss.backward()
                return loss

            optimizer = torch.optim.Adagrad([weight, bias])
            fn = functools.partial(fn_base, optimizer, weight, bias)
            return optimizer, fn

        optimizer, fn = _test_state_dict(
            Parameter(torch.randn(10, 5)),
            Parameter(torch.randn(10)),
            torch.randn(5, requires_grad=True),
        )
        optimizer.step(fn)

    def test_init_group(self):
        for dtype in [torch.float32, torch.cfloat]:
            tensor = torch.randn(5, 5, dtype=dtype)
            params = Parameter(tensor.detach().clone(), requires_grad=False)
            opt_params = Parameter(tensor.detach().clone(), requires_grad=False)
            print(params, opt_params)

            optim = MyOptimizer([params])
            optim.step()

            opt_optim = MyOptimizer([opt_params])
            opt_step = torch.compile(backend="eager", fullgraph=True)(opt_optim.step)
            opt_step()
            print(params, opt_params)

            self.assertEqual(params, opt_params)

if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()