mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Add support for FusedAdam to be mathematically equivalent to pytorch/AdamW (#10106)
This commit is contained in:
parent
13e277525c
commit
141606534c
4 changed files with 89 additions and 20 deletions
|
|
@ -2,5 +2,5 @@ from .config import _OptimizerConfig, AdamConfig, LambConfig, SGDConfig
|
|||
from .lr_scheduler import _LRScheduler, ConstantWarmupLRScheduler, CosineWarmupLRScheduler,\
|
||||
LinearWarmupLRScheduler, PolyWarmupLRScheduler
|
||||
|
||||
from .fused_adam import FusedAdam
|
||||
from .fused_adam import FusedAdam, AdamWMode
|
||||
from .fp16_optimizer import FP16_Optimizer
|
||||
|
|
|
|||
|
|
@ -12,13 +12,22 @@ This file is adapted from fused adam in NVIDIA/apex, commit a109f85
|
|||
|
||||
import torch
|
||||
from ._multi_tensor_apply import MultiTensorApply
|
||||
from enum import IntEnum
|
||||
|
||||
|
||||
class AdamWMode(IntEnum):
|
||||
ADAM_L2_REGULARIZATION = 0 # Adam with L2 regularization
|
||||
ADAMW_TRANSFORMERS = 1 # Adam with weight decay implemented to be equivalent to Transformers/AdamW
|
||||
ADAMW_TORCH = 2 # Adam with weight decay implemented to be equivalent to torch/AdamW
|
||||
|
||||
|
||||
class FusedAdam(torch.optim.Optimizer):
|
||||
"""Implements Adam algorithm.
|
||||
|
||||
The algorithmic implementation is mathematically equivalent to Transformers/AdamW
|
||||
as defined here: https://github.com/huggingface/transformers/blob/61f64262692ac7dc90e2e0bdeb7e79d9cd607a66/src/transformers/optimization.py#L349-L370
|
||||
The algorithmic implementation is mathematically equivalent to
|
||||
`Transformers/AdamW <https://github.com/huggingface/transformers/blob/61f64262692ac7dc90e2e0bdeb7e79d9cd607a66/src/transformers/optimization.py#L349-L370>`_
|
||||
when adam_w_mode = 1 and `torch/Adam <https://github.com/pytorch/pytorch/blob/a217a62e73fd30b658743af8a69966f90327f018/torch/optim/adamw.py#L6>`_
|
||||
when adam_w_mode = 2
|
||||
|
||||
Currently GPU-only.
|
||||
|
||||
|
|
@ -27,7 +36,7 @@ class FusedAdam(torch.optim.Optimizer):
|
|||
* Fusion of the Adam update's elementwise operations
|
||||
* A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.
|
||||
|
||||
Adam was been proposed in `Adam: A Method for Stochastic Optimization`_.
|
||||
Adam was proposed in `Adam: A Method for Stochastic Optimization`_.
|
||||
|
||||
Arguments:
|
||||
params (iterable): iterable of parameters to optimize or dicts defining
|
||||
|
|
@ -38,11 +47,11 @@ class FusedAdam(torch.optim.Optimizer):
|
|||
eps (float, optional): term added to the denominator to improve
|
||||
numerical stability. (default: 1e-8)
|
||||
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
|
||||
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
|
||||
algorithm from the paper `On the Convergence of Adam and Beyond`_
|
||||
(default: False) NOT SUPPORTED in FusedAdam!
|
||||
adam_w_mode (boolean, optional): Apply L2 regularization or weight decay
|
||||
True for decoupled weight decay(also known as AdamW) (default: True)
|
||||
adam_w_mode (AdamWMode, optional): Apply L2 regularization or weight decay
|
||||
(AdamWMode.ADAM_L2_REGULARIZATION), decoupled weight decay with
|
||||
transformers/AdamW mathematical implementation (AdamWMode.ADAMW_TRANSFORMERS)
|
||||
or decoupled weight decay with transformers/AdamW implementation
|
||||
(AdamWMode.ADAMW_TORCH) (default: AdamWMode.ADAMW_TRANSFORMERS)
|
||||
set_grad_none (bool, optional): whether set grad to None when zero_grad()
|
||||
method is called. (default: True)
|
||||
|
||||
|
|
@ -58,23 +67,20 @@ class FusedAdam(torch.optim.Optimizer):
|
|||
betas=(0.9,
|
||||
0.999),
|
||||
eps=1e-6,
|
||||
adam_w_mode=True,
|
||||
adam_w_mode=AdamWMode.ADAMW_TRANSFORMERS,
|
||||
weight_decay=0.,
|
||||
amsgrad=False,
|
||||
set_grad_none=False):
|
||||
set_grad_none=True):
|
||||
|
||||
# The FusedAdam implementation is mathematically equivalent to
|
||||
# transformers AdamW. The input arguments also have the same defaults.
|
||||
|
||||
if amsgrad:
|
||||
raise RuntimeError('FusedAdam does not support the AMSGrad variant.')
|
||||
defaults = dict(lr=lr,
|
||||
bias_correction=bias_correction,
|
||||
betas=betas,
|
||||
eps=eps,
|
||||
weight_decay=weight_decay)
|
||||
super(FusedAdam, self).__init__(params, defaults)
|
||||
self._adam_w_mode = 1 if adam_w_mode else 0
|
||||
self._adam_w_mode = adam_w_mode
|
||||
self._set_grad_none = set_grad_none
|
||||
|
||||
# Skip buffer
|
||||
|
|
|
|||
|
|
@ -24,8 +24,9 @@
|
|||
#define ILP 4
|
||||
|
||||
typedef enum {
|
||||
ADAM_MODE_0 = 0, // L2 regularization mode
|
||||
ADAM_MODE_1 = 1 // Decoupled weight decay mode(AdamW)
|
||||
ADAM_MODE_0 = 0, // L2 regularization mode
|
||||
ADAM_MODE_1 = 1, // Decoupled weight decay mode (AdamW) as implemented in transformers/AdamW
|
||||
ADAM_MODE_2 = 2 // Decoupled weight decay mode (AdamW) as implemented in pytorch/AdamW
|
||||
} adamMode_t;
|
||||
|
||||
using MATH_T = float;
|
||||
|
|
@ -40,6 +41,8 @@ struct AdamFunctor {
|
|||
const float epsilon,
|
||||
const float lr,
|
||||
const float lr_corrected,
|
||||
const float bias_correction1,
|
||||
const float bias_correction2,
|
||||
adamMode_t mode,
|
||||
const float decay)
|
||||
{
|
||||
|
|
@ -91,13 +94,20 @@ struct AdamFunctor {
|
|||
r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
|
||||
MATH_T denom = sqrtf(r_v[ii]) + epsilon;
|
||||
r_p[ii] = r_p[ii] - (lr_corrected * r_m[ii] / denom);
|
||||
} else { // weight decay
|
||||
} else if (mode == ADAM_MODE_1) { // weight decay
|
||||
// Adapted to be mathematically equivalent to transformers AdamW
|
||||
r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
|
||||
r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
|
||||
MATH_T denom = sqrtf(r_v[ii]) + epsilon;
|
||||
r_p[ii] = r_p[ii] - (lr_corrected * r_m[ii] / denom);
|
||||
r_p[ii] = r_p[ii] - (lr * decay * r_p[ii]);
|
||||
} else if (mode == ADAM_MODE_2) {
|
||||
// Adapted to be mathematically equivalent to torch AdamW
|
||||
r_p[ii] = r_p[ii] - (r_p[ii] * lr * decay);
|
||||
r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
|
||||
r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
|
||||
MATH_T denom = (sqrtf(r_v[ii]) / sqrtf(bias_correction2)) + epsilon;
|
||||
r_p[ii] = r_p[ii] - (lr * r_m[ii]) / (bias_correction1 * denom);
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
|
|
@ -128,7 +138,7 @@ void multi_tensor_adam_cuda(int chunk_size,
|
|||
using namespace at;
|
||||
|
||||
// Handle bias correction mode
|
||||
double bias_correction1 = 1.0, bias_correction2 = 1.0;
|
||||
float bias_correction1 = 1.0, bias_correction2 = 1.0;
|
||||
float lr_corrected = lr;
|
||||
if (bias_correction == 1) {
|
||||
bias_correction1 = 1 - std::pow(beta1, step);
|
||||
|
|
@ -150,6 +160,8 @@ void multi_tensor_adam_cuda(int chunk_size,
|
|||
epsilon,
|
||||
lr,
|
||||
lr_corrected,
|
||||
bias_correction1,
|
||||
bias_correction2,
|
||||
(adamMode_t)mode,
|
||||
weight_decay);)
|
||||
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ from onnxruntime.training.ortmodule import (ORTModule,
|
|||
_graph_execution_manager)
|
||||
import onnxruntime.training.ortmodule as ortmodule_module
|
||||
|
||||
from onnxruntime.training.optim import FusedAdam
|
||||
from onnxruntime.training.optim import FusedAdam, AdamWMode
|
||||
from transformers import AdamW
|
||||
|
||||
import _test_helpers
|
||||
|
|
@ -4428,6 +4428,57 @@ def test_ortmodule_fused_adam_optimizer_correctness():
|
|||
for pt_param, ort_param in zip(pt_model.parameters(), ort_model.parameters()):
|
||||
_test_helpers.assert_values_are_close(pt_param, ort_param, atol=1e-4, rtol=1e-5)
|
||||
|
||||
def test_ortmodule_fused_adam_optimizer_correctness_torch():
|
||||
|
||||
torch.manual_seed(8888)
|
||||
|
||||
device = 'cuda'
|
||||
N, D_in, H, D_out = 4, 4, 8, 4
|
||||
|
||||
pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
|
||||
adamw_optimizer = torch.optim.AdamW(pt_model.parameters(), lr=1e-3)
|
||||
|
||||
ort_model = ORTModule(copy.deepcopy(pt_model))
|
||||
ort_fused_adam_optimizer = FusedAdam(ort_model.parameters(), lr=1e-3,
|
||||
adam_w_mode=AdamWMode.ADAMW_TORCH,
|
||||
weight_decay=0.01,
|
||||
eps=1e-8)
|
||||
|
||||
def run_step(model, x):
|
||||
prediction = model(x)
|
||||
loss = prediction.sum()
|
||||
loss.backward()
|
||||
|
||||
return loss
|
||||
|
||||
def run_optim_step(optimizer):
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
ga_steps = 2
|
||||
pt_model.zero_grad()
|
||||
ort_model.zero_grad()
|
||||
|
||||
for step in range(1000):
|
||||
x1 = torch.randn(N, D_in, device=device, dtype=torch.float32)
|
||||
x2 = copy.deepcopy(x1)
|
||||
|
||||
pt_loss = run_step(pt_model, x1)
|
||||
ort_loss = run_step(ort_model, x2)
|
||||
|
||||
for pt_param, ort_param in zip(pt_model.parameters(), ort_model.parameters()):
|
||||
ort_param.grad = copy.deepcopy(pt_param.grad)
|
||||
|
||||
_test_helpers.assert_values_are_close(pt_loss, ort_loss, atol=1e-4, rtol=1e-5)
|
||||
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model, atol=1e-4, rtol=1e-5, reset_gradient=False)
|
||||
|
||||
if (step+1) % ga_steps == 0:
|
||||
run_optim_step(adamw_optimizer)
|
||||
run_optim_step(ort_fused_adam_optimizer)
|
||||
|
||||
for pt_param, ort_param in zip(pt_model.parameters(), ort_model.parameters()):
|
||||
_test_helpers.assert_values_are_close(pt_param, ort_param, atol=1e-4, rtol=1e-5)
|
||||
|
||||
def test_sigmoid_grad():
|
||||
class NeuralNetSigmoid(torch.nn.Module):
|
||||
def __init__(self, input_size, hidden_size, num_classes):
|
||||
|
|
|
|||
Loading…
Reference in a new issue