diff --git a/orttraining/orttraining/python/training/optim/fused_adam.py b/orttraining/orttraining/python/training/optim/fused_adam.py index 52c6cb623e..ae8654080e 100644 --- a/orttraining/orttraining/python/training/optim/fused_adam.py +++ b/orttraining/orttraining/python/training/optim/fused_adam.py @@ -56,6 +56,7 @@ class FusedAdam(torch.optim.Optimizer): (AdamWMode.ADAMW_TORCH) (default: AdamWMode.ADAMW_TRANSFORMERS) set_grad_none (bool, optional): whether set grad to None when zero_grad() method is called. (default: True) + zero_grad() also accepts PyTorch's set_to_none argument (default: True); grads are set to None when either set_grad_none or set_to_none is True. .. _Adam - A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980 @@ -91,8 +92,8 @@ class FusedAdam(torch.optim.Optimizer): self._multi_tensor_applier = MultiTensorApply(2048 * 32) self._TorchTensorVector = fused_ops.TorchTensorVector - def zero_grad(self): - if self._set_grad_none: + def zero_grad(self, set_to_none=True): + if self._set_grad_none or set_to_none: for group in self.param_groups: for p in group["params"]: p.grad = None