Mirror of https://github.com/saymrwulf/onnxruntime.git (synced 2026-05-16 21:00:14 +00:00)
Make FusedAdam mathematically equivalent to Transformers AdamW (#9343)
Parent: 5b65f1cb44
Commit: 5da4e07daa
5 changed files with 76 additions and 82 deletions
@@ -1,3 +1,5 @@
from .config import _OptimizerConfig, AdamConfig, LambConfig, SGDConfig
from .lr_scheduler import _LRScheduler, ConstantWarmupLRScheduler, CosineWarmupLRScheduler,\
    LinearWarmupLRScheduler, PolyWarmupLRScheduler

from .fused_adam import FusedAdam
@@ -11,14 +11,15 @@ This file is adapted from fused adam in NVIDIA/apex, commit a109f85
'''

import torch
import importlib
from .multi_tensor_apply import MultiTensorApply
multi_tensor_applier = MultiTensorApply(2048 * 32)
from ._multi_tensor_apply import MultiTensorApply


class FusedAdam(torch.optim.Optimizer):
    """Implements Adam algorithm.

    The algorithmic implementation is mathematically equivalent to Transformers/AdamW
    as defined here: https://github.com/huggingface/transformers/blob/61f64262692ac7dc90e2e0bdeb7e79d9cd607a66/src/transformers/optimization.py#L349-L370

    Currently GPU-only.

    This version of fused Adam implements 2 fusions.
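For reference, the Transformers AdamW update that the fused kernel is now meant to reproduce can be written as a plain per-tensor step. The sketch below is an illustration only, using the defaults this commit adopts (betas=(0.9, 0.999), eps=1e-6); the function name and arguments are placeholders, not part of the diff.

    import math
    import torch

    def adamw_reference_step(p, g, m, v, t, lr, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
        # Moment updates.
        m.mul_(beta1).add_(g, alpha=1 - beta1)
        v.mul_(beta2).addcmul_(g, g, value=1 - beta2)
        # Transformers-style formulation: eps is added to sqrt(v) (not sqrt(v_hat)),
        # and the bias correction is folded into the step size.
        denom = v.sqrt().add_(eps)
        step_size = lr * math.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)
        p.addcdiv_(m, denom, value=-step_size)
        # Decoupled weight decay, applied after the Adam update and without bias correction.
        if weight_decay > 0.0:
            p.add_(p, alpha=-lr * weight_decay)
        return p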
@@ -56,11 +57,14 @@ class FusedAdam(torch.optim.Optimizer):
                 bias_correction=True,
                 betas=(0.9, 0.999),
                 eps=1e-8,
                 eps=1e-6,
                 adam_w_mode=True,
                 weight_decay=0.,
                 amsgrad=False,
                 set_grad_none=True):
                 set_grad_none=False):

        # The FusedAdam implementation is mathematically equivalent to
        # transformers AdamW. The input arguments also have the same defaults.

        if amsgrad:
            raise RuntimeError('FusedAdam does not support the AMSGrad variant.')
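A quick way to read the new defaults: constructing FusedAdam with no extra arguments should now line up with transformers' AdamW defaults (eps=1e-6, weight decay 0, bias correction on). A minimal sketch under that assumption; the placeholder model and learning rate are illustrative only.

    import torch
    from transformers import AdamW
    from onnxruntime.training.optim import FusedAdam

    model = torch.nn.Linear(4, 4).cuda()   # placeholder model, just for illustration
    # Default hyperparameters line up: betas=(0.9, 0.999), eps=1e-6, weight_decay=0,
    # bias correction enabled (correct_bias=True / bias_correction=True).
    pt_optimizer = AdamW(model.parameters(), lr=1e-3)
    ort_optimizer = FusedAdam(model.parameters(), lr=1e-3)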
@@ -70,29 +74,25 @@ class FusedAdam(torch.optim.Optimizer):
                        eps=eps,
                        weight_decay=weight_decay)
        super(FusedAdam, self).__init__(params, defaults)
        self.adam_w_mode = 1 if adam_w_mode else 0
        self.set_grad_none = set_grad_none
        self._adam_w_mode = 1 if adam_w_mode else 0
        self._set_grad_none = set_grad_none

        # Skip buffer
        self._dummy_overflow_buf = torch.cuda.IntTensor([0])

        from onnxruntime.training.ortmodule.torch_cpp_extensions import adam_optimizer
        self.multi_tensor_adam = adam_optimizer.multi_tensor_adam
        self._multi_tensor_adam = adam_optimizer.multi_tensor_adam
        self._multi_tensor_applier = MultiTensorApply(2048 * 32)

    def zero_grad(self):
        if self.set_grad_none:
        if self._set_grad_none:
            for group in self.param_groups:
                for p in group['params']:
                    p.grad = None
        else:
            super(FusedAdam, self).zero_grad()

    def step(self,
             closure=None,
             grads=None,
             output_params=None,
             scale=None,
             grad_norms=None):
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
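With set_grad_none now defaulting to False, zero_grad() falls through to the base-class behaviour of zeroing gradient tensors in place, and step() no longer takes the apex-style extra arguments. A minimal usage sketch under those assumptions; the placeholder model and loop are illustrative, not part of the diff.

    import torch
    from onnxruntime.training.optim import FusedAdam

    model = torch.nn.Linear(8, 8).cuda()               # placeholder model
    optimizer = FusedAdam(model.parameters(), lr=1e-3)

    for _ in range(3):                                  # placeholder "data loader"
        x = torch.randn(2, 8, device='cuda')
        loss = model(x).sum()
        loss.backward()
        optimizer.step()        # no grads/output_params/scale/grad_norms arguments
        optimizer.zero_grad()   # zeroes .grad in place by default (set_grad_none=False)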
@@ -101,10 +101,6 @@ class FusedAdam(torch.optim.Optimizer):

            The remaining arguments are deprecated, and are only retained (for the moment) for error-checking purposes.
        """
        if any(p is not None for p in [grads, output_params, scale, grad_norms]):
            raise RuntimeError(
                'FusedAdam has been updated. Simply initialize it identically to torch.optim.Adam, and call step() with no arguments.'
            )
        loss = None
        if closure is not None:
            loss = closure()
@@ -154,34 +150,34 @@ class FusedAdam(torch.optim.Optimizer):
                raise RuntimeError('FusedAdam only support fp16 and fp32.')

            if (len(g_16) > 0):
                multi_tensor_applier(self.multi_tensor_adam,
                                     self._dummy_overflow_buf,
                                     [g_16, p_16, m_16, v_16],
                                     group['lr'], beta1, beta2, group['eps'],
                                     group['step'], self.adam_w_mode,
                                     bias_correction, group['weight_decay'])
                self._multi_tensor_applier(self._multi_tensor_adam,
                                           self._dummy_overflow_buf,
                                           [g_16, p_16, m_16, v_16],
                                           group['lr'], beta1, beta2, group['eps'],
                                           group['step'], self._adam_w_mode,
                                           bias_correction, group['weight_decay'])
            if (len(g_32) > 0):
                multi_tensor_applier(self.multi_tensor_adam,
                                     self._dummy_overflow_buf,
                                     [g_32, p_32, m_32, v_32],
                                     group['lr'], beta1, beta2, group['eps'],
                                     group['step'], self.adam_w_mode,
                                     bias_correction, group['weight_decay'])
                self._multi_tensor_applier(self._multi_tensor_adam,
                                           self._dummy_overflow_buf,
                                           [g_32, p_32, m_32, v_32],
                                           group['lr'], beta1, beta2, group['eps'],
                                           group['step'], self._adam_w_mode,
                                           bias_correction, group['weight_decay'])

        return loss
@@ -18,6 +18,7 @@

#include "multi_tensor_apply.cuh"
#include "type_shim.h"
#include <cmath>

#define BLOCK_SIZE 512
#define ILP 4
@@ -36,22 +37,14 @@ struct AdamFunctor {
      TensorListMetadata<4>& tl,
      const float beta1,
      const float beta2,
      const float beta1_correction,
      const float beta2_correction,
      const float epsilon,
      const float lr,
      const float lr_corrected,
      adamMode_t mode,
      const float decay)
  {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;

    int tensor_loc = tl.block_to_tensor[blockIdx.x];

    // potentially use to pass in list of scalar
    // int tensor_num = tl.start_tensor_this_launch + tensor_loc;

    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];
@@ -96,19 +89,15 @@ struct AdamFunctor {
          r_g[ii] = r_g[ii] + (decay * r_p[ii]);
          r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
          r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
          MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
          MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
          MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
          MATH_T update = next_m_unbiased / denom;
          r_p[ii] = r_p[ii] - (lr * update);
          MATH_T denom = sqrtf(r_v[ii]) + epsilon;
          r_p[ii] = r_p[ii] - (lr_corrected * r_m[ii] / denom);
        } else { // weight decay
          // Adapted to be mathematically equivalent to transformers AdamW
          r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
          r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
          MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
          MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
          MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
          MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]);
          r_p[ii] = r_p[ii] - (lr * update);
          MATH_T denom = sqrtf(r_v[ii]) + epsilon;
          r_p[ii] = r_p[ii] - (lr_corrected * r_m[ii] / denom);
          r_p[ii] = r_p[ii] - (lr * decay * r_p[ii]);
        }
      }
#pragma unroll
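The kernel change above is the heart of the commit: instead of bias-correcting m and v and then adding eps (the apex formulation), the bias correction is folded into the learning rate and eps is added to sqrt(v) directly, which is how transformers' AdamW writes the update. The two forms differ only in where eps enters, so they agree closely but not bit-for-bit. A scalar sketch for illustration, with made-up moment values:

    import math

    m, v = 0.05, 0.002                 # first/second moment after a few steps (illustrative)
    lr, eps = 1e-3, 1e-6
    beta1, beta2, t = 0.9, 0.999, 10
    bc1, bc2 = 1 - beta1 ** t, 1 - beta2 ** t

    # Old (apex-style): bias-correct the moments, add eps to sqrt(v_hat).
    old_update = lr * (m / bc1) / (math.sqrt(v / bc2) + eps)

    # New (transformers-style): fold bias correction into the learning rate,
    # add eps to sqrt(v) directly -- this is what the rewritten kernel computes.
    lr_corrected = lr * math.sqrt(bc2) / bc1
    new_update = lr_corrected * m / (math.sqrt(v) + eps)

    print(old_update, new_update)      # close, but not identical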
@@ -139,10 +128,12 @@ void multi_tensor_adam_cuda(int chunk_size,
  using namespace at;

  // Handle bias correction mode
  float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
  double bias_correction1 = 1.0, bias_correction2 = 1.0;
  float lr_corrected = lr;
  if (bias_correction == 1) {
    bias_correction1 = 1 - std::pow(beta1, step);
    bias_correction2 = 1 - std::pow(beta2, step);
    lr_corrected *= std::sqrt(bias_correction2) / bias_correction1;
  }

  // Assume single type across p,g,m1,m2 now
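On the host side, the bias corrections are now computed in double and collapsed into a single lr_corrected that is handed to the kernel. A small sketch of how that effective step size evolves with the default betas (the values printed are purely illustrative):

    import math

    beta1, beta2, lr = 0.9, 0.999, 1e-3
    for t in (1, 10, 100, 1000, 10000):
        bc1 = 1.0 - beta1 ** t
        bc2 = 1.0 - beta2 ** t
        print(t, lr * math.sqrt(bc2) / bc1)   # tends to lr as t grows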
@@ -156,10 +147,9 @@ void multi_tensor_adam_cuda(int chunk_size,
      AdamFunctor<scalar_t_0>(),
      beta1,
      beta2,
      bias_correction1,
      bias_correction2,
      epsilon,
      lr,
      lr_corrected,
      (adamMode_t)mode,
      weight_decay);)
@@ -29,7 +29,7 @@ from onnxruntime.training.ortmodule import (ORTModule,
                                            _fallback,
                                            _graph_execution_manager)

from onnxruntime.training.optim.fused_adam import FusedAdam
from onnxruntime.training.optim import FusedAdam
from transformers import AdamW

import _test_helpers
@@ -4103,20 +4103,23 @@ def test_override_pytorch_exporter_kwargs_using_ortmodule_extension():

def test_ortmodule_fused_adam_optimizer_correctness():

    torch.manual_seed(8888)

    device = 'cuda'
    N, D_in, H, D_out = 32, 128, 500, 10

    pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
    transformers_adamw_optimizer = AdamW(pt_model.parameters())
    transformers_adamw_optimizer = AdamW(pt_model.parameters(), lr=1)

    ort_model = ORTModule(copy.deepcopy(pt_model))
    ort_fused_adam_optimizer = FusedAdam(ort_model.parameters())
    ort_fused_adam_optimizer = FusedAdam(ort_model.parameters(), lr=1)

    def run_step(model, x):
        prediction = model(x)
        loss = prediction.sum()
        loss.backward()
        return prediction, loss
        return loss

    def run_optim_step(optimizer):
        optimizer.step()
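The test pins lr=1 so each optimizer step applies the full update magnitude, which makes any divergence between the two implementations surface quickly in the final parameter comparison. The same idea in miniature, stripped of ORTModule, is sketched below; the single parameter tensor and torch.testing.assert_close are stand-ins for the repo's model and _test_helpers, not part of the diff.

    import torch
    from torch.testing import assert_close
    from transformers import AdamW
    from onnxruntime.training.optim import FusedAdam

    w_ref = torch.nn.Parameter(torch.randn(4, 4, device='cuda'))
    w = torch.nn.Parameter(w_ref.detach().clone())

    pt_opt = AdamW([w_ref], lr=1)
    ort_opt = FusedAdam([w], lr=1)

    for _ in range(100):
        g = torch.randn_like(w_ref)
        w_ref.grad, w.grad = g, g.clone()   # feed both optimizers identical gradients
        pt_opt.step()
        ort_opt.step()
        pt_opt.zero_grad()
        ort_opt.zero_grad()

    assert_close(w, w_ref, atol=1e-4, rtol=1e-5)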
@@ -4126,22 +4129,25 @@ def test_ortmodule_fused_adam_optimizer_correctness():
    pt_model.zero_grad()
    ort_model.zero_grad()

    for step in range(10):
        x = torch.randn(N, D_in, device=device)
    for step in range(1000):
        x1 = torch.randn(N, D_in, device=device, dtype=torch.float32)
        x2 = copy.deepcopy(x1)

        _, pt_loss = run_step(pt_model, x)
        _, ort_loss = run_step(ort_model, x)
        pt_loss = run_step(pt_model, x1)
        ort_loss = run_step(ort_model, x2)

        _test_helpers.assert_values_are_close(pt_loss, ort_loss, rtol=1e-4)
        _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)
        for pt_param, ort_param in zip(pt_model.parameters(), ort_model.parameters()):
            ort_param.grad = copy.deepcopy(pt_param.grad)

        _test_helpers.assert_values_are_close(pt_loss, ort_loss)
        _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model, reset_gradient=False)

        if (step+1) % ga_steps == 0:
            run_optim_step(transformers_adamw_optimizer)
            run_optim_step(ort_fused_adam_optimizer)

    for pt_param, ort_param in zip(pt_model.parameters(), ort_model.parameters()):
        _test_helpers.assert_values_are_close(pt_param, ort_param)

    for pt_param, ort_param in zip(pt_model.parameters(), ort_model.parameters()):
        _test_helpers.assert_values_are_close(pt_param, ort_param, atol=1e-4, rtol=1e-5)

def test_sigmoid_grad():
    class NeuralNetSigmoid(torch.nn.Module):
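Two details of the loop above are worth calling out: the PyTorch model's gradients are copied onto the ORT model's parameters before the optimizer step, so both optimizers consume identical gradients and the final comparison isolates the optimizer math; and updates only happen every ga_steps iterations, i.e. a gradient-accumulation schedule. A generic sketch of that schedule for context, not a reproduction of the test (ga_steps, model, batches and the zero_grad call here are illustrative):

    ga_steps = 2
    for step, batch in enumerate(batches):
        loss = model(batch).sum()
        loss.backward()                      # gradients accumulate across iterations
        if (step + 1) % ga_steps == 0:       # update only every ga_steps iterations
            optimizer.step()
            optimizer.zero_grad()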