Adding this set_to_none flag to zero_grad to have signature parity with pytorch Adam (#16375)

### Description
torch.optim Adam zero_grad() signature is
zero_grad(set_to_none=True)

https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam.zero_grad

We set this flag in initialization, similar to deepspeed:
https://deepspeed.readthedocs.io/en/latest/optimizers.html#deepspeed.ops.adam.FusedAdam

Adding this flag to have signature parity with pytorch Adam

### Motivation and Context
Easier model integration

Co-authored-by: Jingyan Wang <jingywa@microsoft.com@orttrainingdev7.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
This commit is contained in:
jingyanwangms 2023-06-19 17:27:41 -07:00 committed by GitHub
parent 470d6c1cce
commit 5dcaf70501
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -56,6 +56,7 @@ class FusedAdam(torch.optim.Optimizer):
(AdamWMode.ADAMW_TORCH) (default: AdamWMode.ADAMW_TRANSFORMERS)
set_grad_none (bool, optional): whether set grad to None when zero_grad()
method is called. (default: True)
PyTorch Adam has set_to_none parameter in zero_grad(), supporting that too
.. _Adam - A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
@ -91,8 +92,8 @@ class FusedAdam(torch.optim.Optimizer):
self._multi_tensor_applier = MultiTensorApply(2048 * 32)
self._TorchTensorVector = fused_ops.TorchTensorVector
def zero_grad(self):
if self._set_grad_none:
def zero_grad(self, set_to_none=True):
if self._set_grad_none or set_to_none:
for group in self.param_groups:
for p in group["params"]:
p.grad = None