Add support for FusedAdam to be mathematically equivalent to pytorch/AdamW (#10106)

This commit is contained in:
Baiju Meswani 2022-01-21 13:37:59 -08:00 committed by GitHub
parent 13e277525c
commit 141606534c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 89 additions and 20 deletions

View file

@ -2,5 +2,5 @@ from .config import _OptimizerConfig, AdamConfig, LambConfig, SGDConfig
from .lr_scheduler import _LRScheduler, ConstantWarmupLRScheduler, CosineWarmupLRScheduler,\
LinearWarmupLRScheduler, PolyWarmupLRScheduler
from .fused_adam import FusedAdam
from .fused_adam import FusedAdam, AdamWMode
from .fp16_optimizer import FP16_Optimizer

View file

@ -12,13 +12,22 @@ This file is adapted from fused adam in NVIDIA/apex, commit a109f85
import torch
from ._multi_tensor_apply import MultiTensorApply
from enum import IntEnum
class AdamWMode(IntEnum):
ADAM_L2_REGULARIZATION = 0 # Adam with L2 regularization
ADAMW_TRANSFORMERS = 1 # Adam with weight decay implemented to be equivalent to Transformers/AdamW
ADAMW_TORCH = 2 # Adam with weight decay implemented to be equivalent to torch/AdamW
class FusedAdam(torch.optim.Optimizer):
"""Implements Adam algorithm.
The algorithmic implementation is mathematically equivalent to Transformers/AdamW
as defined here: https://github.com/huggingface/transformers/blob/61f64262692ac7dc90e2e0bdeb7e79d9cd607a66/src/transformers/optimization.py#L349-L370
The algorithmic implementation is mathematically equivalent to
`Transformers/AdamW <https://github.com/huggingface/transformers/blob/61f64262692ac7dc90e2e0bdeb7e79d9cd607a66/src/transformers/optimization.py#L349-L370>`_
when adam_w_mode = 1 and `torch/Adam <https://github.com/pytorch/pytorch/blob/a217a62e73fd30b658743af8a69966f90327f018/torch/optim/adamw.py#L6>`_
when adam_w_mode = 2
Currently GPU-only.
@ -27,7 +36,7 @@ class FusedAdam(torch.optim.Optimizer):
* Fusion of the Adam update's elementwise operations
* A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.
Adam was been proposed in `Adam: A Method for Stochastic Optimization`_.
Adam was proposed in `Adam: A Method for Stochastic Optimization`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
@ -38,11 +47,11 @@ class FusedAdam(torch.optim.Optimizer):
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False) NOT SUPPORTED in FusedAdam!
adam_w_mode (boolean, optional): Apply L2 regularization or weight decay
True for decoupled weight decay(also known as AdamW) (default: True)
adam_w_mode (AdamWMode, optional): Apply L2 regularization or weight decay
(AdamWMode.ADAM_L2_REGULARIZATION), decoupled weight decay with
transformers/AdamW mathematical implementation (AdamWMode.ADAMW_TRANSFORMERS)
or decoupled weight decay with transformers/AdamW implementation
(AdamWMode.ADAMW_TORCH) (default: AdamWMode.ADAMW_TRANSFORMERS)
set_grad_none (bool, optional): whether set grad to None when zero_grad()
method is called. (default: True)
@ -58,23 +67,20 @@ class FusedAdam(torch.optim.Optimizer):
betas=(0.9,
0.999),
eps=1e-6,
adam_w_mode=True,
adam_w_mode=AdamWMode.ADAMW_TRANSFORMERS,
weight_decay=0.,
amsgrad=False,
set_grad_none=False):
set_grad_none=True):
# The FusedAdam implementation is mathematically equivalent to
# transformers AdamW. The input arguments also have the same defaults.
if amsgrad:
raise RuntimeError('FusedAdam does not support the AMSGrad variant.')
defaults = dict(lr=lr,
bias_correction=bias_correction,
betas=betas,
eps=eps,
weight_decay=weight_decay)
super(FusedAdam, self).__init__(params, defaults)
self._adam_w_mode = 1 if adam_w_mode else 0
self._adam_w_mode = adam_w_mode
self._set_grad_none = set_grad_none
# Skip buffer

View file

@ -24,8 +24,9 @@
#define ILP 4
typedef enum {
ADAM_MODE_0 = 0, // L2 regularization mode
ADAM_MODE_1 = 1 // Decoupled weight decay mode(AdamW)
ADAM_MODE_0 = 0, // L2 regularization mode
ADAM_MODE_1 = 1, // Decoupled weight decay mode (AdamW) as implemented in transformers/AdamW
ADAM_MODE_2 = 2 // Decoupled weight decay mode (AdamW) as implemented in pytorch/AdamW
} adamMode_t;
using MATH_T = float;
@ -40,6 +41,8 @@ struct AdamFunctor {
const float epsilon,
const float lr,
const float lr_corrected,
const float bias_correction1,
const float bias_correction2,
adamMode_t mode,
const float decay)
{
@ -91,13 +94,20 @@ struct AdamFunctor {
r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
MATH_T denom = sqrtf(r_v[ii]) + epsilon;
r_p[ii] = r_p[ii] - (lr_corrected * r_m[ii] / denom);
} else { // weight decay
} else if (mode == ADAM_MODE_1) { // weight decay
// Adapted to be mathematically equivalent to transformers AdamW
r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
MATH_T denom = sqrtf(r_v[ii]) + epsilon;
r_p[ii] = r_p[ii] - (lr_corrected * r_m[ii] / denom);
r_p[ii] = r_p[ii] - (lr * decay * r_p[ii]);
} else if (mode == ADAM_MODE_2) {
// Adapted to be mathematically equivalent to torch AdamW
r_p[ii] = r_p[ii] - (r_p[ii] * lr * decay);
r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
MATH_T denom = (sqrtf(r_v[ii]) / sqrtf(bias_correction2)) + epsilon;
r_p[ii] = r_p[ii] - (lr * r_m[ii]) / (bias_correction1 * denom);
}
}
#pragma unroll
@ -128,7 +138,7 @@ void multi_tensor_adam_cuda(int chunk_size,
using namespace at;
// Handle bias correction mode
double bias_correction1 = 1.0, bias_correction2 = 1.0;
float bias_correction1 = 1.0, bias_correction2 = 1.0;
float lr_corrected = lr;
if (bias_correction == 1) {
bias_correction1 = 1 - std::pow(beta1, step);
@ -150,6 +160,8 @@ void multi_tensor_adam_cuda(int chunk_size,
epsilon,
lr,
lr_corrected,
bias_correction1,
bias_correction2,
(adamMode_t)mode,
weight_decay);)

View file

@ -30,7 +30,7 @@ from onnxruntime.training.ortmodule import (ORTModule,
_graph_execution_manager)
import onnxruntime.training.ortmodule as ortmodule_module
from onnxruntime.training.optim import FusedAdam
from onnxruntime.training.optim import FusedAdam, AdamWMode
from transformers import AdamW
import _test_helpers
@ -4428,6 +4428,57 @@ def test_ortmodule_fused_adam_optimizer_correctness():
for pt_param, ort_param in zip(pt_model.parameters(), ort_model.parameters()):
_test_helpers.assert_values_are_close(pt_param, ort_param, atol=1e-4, rtol=1e-5)
def test_ortmodule_fused_adam_optimizer_correctness_torch():
torch.manual_seed(8888)
device = 'cuda'
N, D_in, H, D_out = 4, 4, 8, 4
pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
adamw_optimizer = torch.optim.AdamW(pt_model.parameters(), lr=1e-3)
ort_model = ORTModule(copy.deepcopy(pt_model))
ort_fused_adam_optimizer = FusedAdam(ort_model.parameters(), lr=1e-3,
adam_w_mode=AdamWMode.ADAMW_TORCH,
weight_decay=0.01,
eps=1e-8)
def run_step(model, x):
prediction = model(x)
loss = prediction.sum()
loss.backward()
return loss
def run_optim_step(optimizer):
optimizer.step()
optimizer.zero_grad()
ga_steps = 2
pt_model.zero_grad()
ort_model.zero_grad()
for step in range(1000):
x1 = torch.randn(N, D_in, device=device, dtype=torch.float32)
x2 = copy.deepcopy(x1)
pt_loss = run_step(pt_model, x1)
ort_loss = run_step(ort_model, x2)
for pt_param, ort_param in zip(pt_model.parameters(), ort_model.parameters()):
ort_param.grad = copy.deepcopy(pt_param.grad)
_test_helpers.assert_values_are_close(pt_loss, ort_loss, atol=1e-4, rtol=1e-5)
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model, atol=1e-4, rtol=1e-5, reset_gradient=False)
if (step+1) % ga_steps == 0:
run_optim_step(adamw_optimizer)
run_optim_step(ort_fused_adam_optimizer)
for pt_param, ort_param in zip(pt_model.parameters(), ort_model.parameters()):
_test_helpers.assert_values_are_close(pt_param, ort_param, atol=1e-4, rtol=1e-5)
def test_sigmoid_grad():
class NeuralNetSigmoid(torch.nn.Module):
def __init__(self, input_size, hidden_size, num_classes):