mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-24 22:17:32 +00:00
Adding this set_to_none flag to zero_grad to have signature parity with pytorch Adam (#16375)
### Description torch.optim Adam zero_grad() signature is zero_grad(set_to_none=True) https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam.zero_grad We set this flag in initialization, similar to deepspeed: https://deepspeed.readthedocs.io/en/latest/optimizers.html#deepspeed.ops.adam.FusedAdam Adding this flag to have signature parity with pytorch Adam ### Motivation and Context Easier model integration Co-authored-by: Jingyan Wang <jingywa@microsoft.com@orttrainingdev7.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
This commit is contained in:
parent
470d6c1cce
commit
5dcaf70501
1 changed files with 3 additions and 2 deletions
|
|
@ -56,6 +56,7 @@ class FusedAdam(torch.optim.Optimizer):
|
|||
(AdamWMode.ADAMW_TORCH) (default: AdamWMode.ADAMW_TRANSFORMERS)
|
||||
set_grad_none (bool, optional): whether set grad to None when zero_grad()
|
||||
method is called. (default: True)
|
||||
PyTorch Adam has set_to_none parameter in zero_grad(), supporting that too
|
||||
|
||||
.. _Adam - A Method for Stochastic Optimization:
|
||||
https://arxiv.org/abs/1412.6980
|
||||
|
|
@ -91,8 +92,8 @@ class FusedAdam(torch.optim.Optimizer):
|
|||
self._multi_tensor_applier = MultiTensorApply(2048 * 32)
|
||||
self._TorchTensorVector = fused_ops.TorchTensorVector
|
||||
|
||||
def zero_grad(self):
|
||||
if self._set_grad_none:
|
||||
def zero_grad(self, set_to_none=True):
|
||||
if self._set_grad_none or set_to_none:
|
||||
for group in self.param_groups:
|
||||
for p in group["params"]:
|
||||
p.grad = None
|
||||
|
|
|
|||
Loading…
Reference in a new issue