From 3908be676c7406da52824c83b804ae45e1daff7a Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Thu, 16 Jan 2025 16:44:00 +0000 Subject: [PATCH] Fix loading older state_dict into AdamW after refactor (#144972) Pull Request resolved: https://github.com/pytorch/pytorch/pull/144972 Approved by: https://github.com/albanD --- test/test_optim.py | 19 +++++++++++++++++++ torch/optim/adamw.py | 8 ++++++++ 2 files changed, 27 insertions(+) diff --git a/test/test_optim.py b/test/test_optim.py index 0beae5b1cd7..da6ddd6f724 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -671,9 +671,28 @@ class TestOptimRenewed(TestCase): loaded_dict = optim.state_dict() + # Test that Adam respects the decoupled_weight_decay key new_optim = torch.optim.Adam(model.parameters()) new_optim.load_state_dict(loaded_dict) + self.assertTrue(new_optim.param_groups[0]["decoupled_weight_decay"]) + # Test that decoupled_weight_decay is always True for AdamW + adam_optim = torch.optim.Adam(model.parameters()) + adam_state_dict = adam_optim.state_dict() + self.assertFalse(adam_state_dict["param_groups"][0]["decoupled_weight_decay"]) + + new_optim = torch.optim.AdamW(model.parameters()) + new_optim.load_state_dict(adam_state_dict) + self.assertTrue(new_optim.param_groups[0]["decoupled_weight_decay"]) + + # Test that state_dicts from the old AdamW (with no decoupled_weight_decay key) + # will have decoupled_weight_decay=True in new AdamW: + old_adamw_dict = deepcopy(loaded_dict) + del old_adamw_dict["param_groups"][0]["decoupled_weight_decay"] + self.assertFalse("decoupled_weight_decay" in old_adamw_dict["param_groups"][0]) + + new_optim = torch.optim.AdamW(model.parameters()) + new_optim.load_state_dict(old_adamw_dict) self.assertTrue(new_optim.param_groups[0]["decoupled_weight_decay"]) def _compare_between( diff --git a/torch/optim/adamw.py b/torch/optim/adamw.py index 5588a2bc51a..63984b1e932 100644 --- a/torch/optim/adamw.py +++ b/torch/optim/adamw.py @@ -49,6 +49,14 @@ class AdamW(Adam): 
decoupled_weight_decay=True, ) + # Backwards compatibility: always force decoupled_weight_decay=True after state is set. + # State dicts saved by the pre-refactor AdamW lack this key, and ones saved by Adam may + # carry False; either way AdamW must keep applying decoupled weight decay. + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group["decoupled_weight_decay"] = True + AdamW.__doc__ = ( r"""Implements AdamW algorithm, where weight decay does not accumulate in the momentum nor variance.