From 5dcaf70501e6346ca1e207a9aceee4ec07663f8a Mon Sep 17 00:00:00 2001
From: jingyanwangms <47403504+jingyanwangms@users.noreply.github.com>
Date: Mon, 19 Jun 2023 17:27:41 -0700
Subject: [PATCH] Adding this set_to_none flag to zero_grad to have signature parity with pytorch Adam (#16375)

### Description
torch.optim Adam zero_grad() signature is zero_grad(set_to_none=True)
https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam.zero_grad
We set this flag in initialization, similar to deepspeed:
https://deepspeed.readthedocs.io/en/latest/optimizers.html#deepspeed.ops.adam.FusedAdam
Adding this flag to have signature parity with pytorch Adam

### Motivation and Context
Easier model integration

Co-authored-by: Jingyan Wang
---
 orttraining/orttraining/python/training/optim/fused_adam.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/orttraining/orttraining/python/training/optim/fused_adam.py b/orttraining/orttraining/python/training/optim/fused_adam.py
index 52c6cb623e..ae8654080e 100644
--- a/orttraining/orttraining/python/training/optim/fused_adam.py
+++ b/orttraining/orttraining/python/training/optim/fused_adam.py
@@ -56,6 +56,7 @@ class FusedAdam(torch.optim.Optimizer):
             (AdamWMode.ADAMW_TORCH) (default: AdamWMode.ADAMW_TRANSFORMERS)
         set_grad_none (bool, optional): whether set grad to None when
             zero_grad() method is called. (default: True)
+            PyTorch Adam has set_to_none parameter in zero_grad(), supporting that too
 
     .. _Adam - A Method for Stochastic Optimization:
         https://arxiv.org/abs/1412.6980
@@ -91,8 +92,8 @@ class FusedAdam(torch.optim.Optimizer):
         self._multi_tensor_applier = MultiTensorApply(2048 * 32)
         self._TorchTensorVector = fused_ops.TorchTensorVector
 
-    def zero_grad(self):
-        if self._set_grad_none:
+    def zero_grad(self, set_to_none=True):
+        if self._set_grad_none or set_to_none:
             for group in self.param_groups:
                 for p in group["params"]:
                     p.grad = None