Make FusedAdam mathematically equivalent to Transformers AdamW (#9343)

baijumeswani 2021-10-18 16:03:18 -07:00 committed by GitHub
parent 5b65f1cb44
commit 5da4e07daa
5 changed files with 76 additions and 82 deletions
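
For reference, the update this commit targets: Transformers AdamW (linked from the docstring below) keeps epsilon on the raw second moment, folds the bias correction into the step size, and applies decoupled weight decay after the Adam step. The following is a minimal single-tensor Python sketch of that update based on the linked optimization.py; the helper name and standalone signature are illustrative and not part of this commit.

import math

def adamw_reference_step(p, grad, m, v, step, lr=1e-3, betas=(0.9, 0.999),
                         eps=1e-6, weight_decay=0.0, correct_bias=True):
    # p, grad, m, v are torch tensors of the same shape; updated in place.
    beta1, beta2 = betas
    m.mul_(beta1).add_(grad, alpha=1 - beta1)
    v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    denom = v.sqrt().add_(eps)                 # epsilon added to the raw moment
    step_size = lr
    if correct_bias:
        step_size *= math.sqrt(1.0 - beta2 ** step) / (1.0 - beta1 ** step)
    p.addcdiv_(m, denom, value=-step_size)
    if weight_decay > 0.0:                     # decoupled decay, after the Adam step
        p.add_(p, alpha=-lr * weight_decay)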

@@ -1,3 +1,5 @@
from .config import _OptimizerConfig, AdamConfig, LambConfig, SGDConfig
from .lr_scheduler import _LRScheduler, ConstantWarmupLRScheduler, CosineWarmupLRScheduler,\
LinearWarmupLRScheduler, PolyWarmupLRScheduler
from .fused_adam import FusedAdam
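
The new re-export makes the optimizer importable from the package root, which the updated test below relies on. Illustrative usage (the model construction is assumed, not part of this commit):

from onnxruntime.training.optim import FusedAdam   # new package-level import
# previously: from onnxruntime.training.optim.fused_adam import FusedAdam
optimizer = FusedAdam(model.parameters(), lr=1e-3)  # model: an existing ORTModule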

@@ -11,14 +11,15 @@ This file is adapted from fused adam in NVIDIA/apex, commit a109f85
'''
import torch
import importlib
from .multi_tensor_apply import MultiTensorApply
multi_tensor_applier = MultiTensorApply(2048 * 32)
from ._multi_tensor_apply import MultiTensorApply
class FusedAdam(torch.optim.Optimizer):
"""Implements Adam algorithm.
The algorithmic implementation is mathematically equivalent to Transformers/AdamW
as defined here: https://github.com/huggingface/transformers/blob/61f64262692ac7dc90e2e0bdeb7e79d9cd607a66/src/transformers/optimization.py#L349-L370
Currently GPU-only.
This version of fused Adam implements 2 fusions.
@@ -56,11 +57,14 @@ class FusedAdam(torch.optim.Optimizer):
bias_correction=True,
betas=(0.9,
0.999),
eps=1e-8,
eps=1e-6,
adam_w_mode=True,
weight_decay=0.,
amsgrad=False,
set_grad_none=True):
set_grad_none=False):
# The FusedAdam implementation is mathematically equivalent to
# transformers AdamW. The input arguments also have the same defaults.
if amsgrad:
raise RuntimeError('FusedAdam does not support the AMSGrad variant.')
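
With the defaults now aligned (eps=1e-6, adam_w_mode=True, weight_decay=0., no AMSGrad), constructing the two optimizers with identical hyperparameters should yield identical parameter trajectories. A hedged sketch; the tiny Linear module and the hyperparameter values are illustrative only:

import copy
import torch
from transformers import AdamW
from onnxruntime.training.optim import FusedAdam

pt_model = torch.nn.Linear(4, 2).cuda()     # FusedAdam is GPU-only
ort_model = copy.deepcopy(pt_model)         # would be ORTModule(...) in real use

hparams = dict(lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01)
pt_opt = AdamW(pt_model.parameters(), correct_bias=True, **hparams)
ort_opt = FusedAdam(ort_model.parameters(), bias_correction=True, **hparams)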
@@ -70,29 +74,25 @@ class FusedAdam(torch.optim.Optimizer):
eps=eps,
weight_decay=weight_decay)
super(FusedAdam, self).__init__(params, defaults)
self.adam_w_mode = 1 if adam_w_mode else 0
self.set_grad_none = set_grad_none
self._adam_w_mode = 1 if adam_w_mode else 0
self._set_grad_none = set_grad_none
# Skip buffer
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
from onnxruntime.training.ortmodule.torch_cpp_extensions import adam_optimizer
self.multi_tensor_adam = adam_optimizer.multi_tensor_adam
self._multi_tensor_adam = adam_optimizer.multi_tensor_adam
self._multi_tensor_applier = MultiTensorApply(2048 * 32)
def zero_grad(self):
if self.set_grad_none:
if self._set_grad_none:
for group in self.param_groups:
for p in group['params']:
p.grad = None
else:
super(FusedAdam, self).zero_grad()
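
The set_grad_none default flips to False, so zero_grad() now falls back to the base-class behavior unless the caller opts in. A short sketch of the difference; `model` is assumed, and the base-class behavior described is that of the torch releases current at the time:

optimizer = FusedAdam(model.parameters(), set_grad_none=True)
optimizer.zero_grad()    # every p.grad becomes None, freeing the buffers

optimizer = FusedAdam(model.parameters())   # set_grad_none=False (new default)
optimizer.zero_grad()    # defers to torch.optim.Optimizer.zero_grad(), which
                         # zeroes the existing .grad buffers in place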
def step(self,
closure=None,
grads=None,
output_params=None,
scale=None,
grad_norms=None):
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
@@ -101,10 +101,6 @@ class FusedAdam(torch.optim.Optimizer):
The remaining arguments are deprecated, and are only retained (for the moment) for error-checking purposes.
"""
if any(p is not None for p in [grads, output_params, scale, grad_norms]):
raise RuntimeError(
'FusedAdam has been updated. Simply initialize it identically to torch.optim.Adam, and call step() with no arguments.'
)
loss = None
if closure is not None:
loss = closure()
@@ -154,34 +150,34 @@ class FusedAdam(torch.optim.Optimizer):
raise RuntimeError('FusedAdam only support fp16 and fp32.')
if (len(g_16) > 0):
multi_tensor_applier(self.multi_tensor_adam,
self._dummy_overflow_buf,
[g_16,
p_16,
m_16,
v_16],
group['lr'],
beta1,
beta2,
group['eps'],
group['step'],
self.adam_w_mode,
bias_correction,
group['weight_decay'])
self._multi_tensor_applier(self._multi_tensor_adam,
self._dummy_overflow_buf,
[g_16,
p_16,
m_16,
v_16],
group['lr'],
beta1,
beta2,
group['eps'],
group['step'],
self._adam_w_mode,
bias_correction,
group['weight_decay'])
if (len(g_32) > 0):
multi_tensor_applier(self.multi_tensor_adam,
self._dummy_overflow_buf,
[g_32,
p_32,
m_32,
v_32],
group['lr'],
beta1,
beta2,
group['eps'],
group['step'],
self.adam_w_mode,
bias_correction,
group['weight_decay'])
self._multi_tensor_applier(self._multi_tensor_adam,
self._dummy_overflow_buf,
[g_32,
p_32,
m_32,
v_32],
group['lr'],
beta1,
beta2,
group['eps'],
group['step'],
self._adam_w_mode,
bias_correction,
group['weight_decay'])
return loss
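
With the deprecated apex-style arguments gone, the optimizer is driven exactly like torch.optim.Adam; per dtype it hands the [g, p, m, v] lists to MultiTensorApply in a few fused launches. An illustrative training loop (optimizer, ort_model, and data_loader are assumed to exist and are not part of this commit):

for batch in data_loader:
    optimizer.zero_grad()
    loss = ort_model(batch).sum()   # any scalar loss
    loss.backward()
    optimizer.step()                # no grads/scale/grad_norms arguments anymore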

@@ -18,6 +18,7 @@
#include "multi_tensor_apply.cuh"
#include "type_shim.h"
#include <cmath>
#define BLOCK_SIZE 512
#define ILP 4
@@ -36,22 +37,14 @@ struct AdamFunctor {
TensorListMetadata<4>& tl,
const float beta1,
const float beta2,
const float beta1_correction,
const float beta2_correction,
const float epsilon,
const float lr,
const float lr_corrected,
adamMode_t mode,
const float decay)
{
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
// potentially use to pass in list of scalar
// int tensor_num = tl.start_tensor_this_launch + tensor_loc;
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
@@ -96,19 +89,15 @@ struct AdamFunctor {
r_g[ii] = r_g[ii] + (decay * r_p[ii]);
r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
MATH_T update = next_m_unbiased / denom;
r_p[ii] = r_p[ii] - (lr * update);
MATH_T denom = sqrtf(r_v[ii]) + epsilon;
r_p[ii] = r_p[ii] - (lr_corrected * r_m[ii] / denom);
} else { // weight decay
// Adapted to be mathematically equivalent to transformers AdamW
r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]);
r_p[ii] = r_p[ii] - (lr * update);
MATH_T denom = sqrtf(r_v[ii]) + epsilon;
r_p[ii] = r_p[ii] - (lr_corrected * r_m[ii] / denom);
r_p[ii] = r_p[ii] - (lr * decay * r_p[ii]);
}
}
#pragma unroll
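
The per-element arithmetic after this change, restated as a pure-Python stand-in for one element of the functor above (a sketch, not the actual kernel): the raw moments feed the denominator, the bias correction lives in lr_corrected, and in the weight-decay mode the decay is decoupled and applied after the Adam step.

import math

def fused_adam_element(p, g, m, v, lr, lr_corrected, beta1, beta2, eps, decay, adam_w_mode):
    if not adam_w_mode:                     # ADAM_MODE_0: L2 regularization
        g = g + decay * p                   # decay folded into the gradient
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        p = p - lr_corrected * m / (math.sqrt(v) + eps)
    else:                                   # weight-decay mode, as in AdamW
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        p = p - lr_corrected * m / (math.sqrt(v) + eps)
        p = p - lr * decay * p              # decoupled decay after the Adam step
    return p, m, v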
@@ -139,10 +128,12 @@ void multi_tensor_adam_cuda(int chunk_size,
using namespace at;
// Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
double bias_correction1 = 1.0, bias_correction2 = 1.0;
float lr_corrected = lr;
if (bias_correction == 1) {
bias_correction1 = 1 - std::pow(beta1, step);
bias_correction2 = 1 - std::pow(beta2, step);
lr_corrected *= std::sqrt(bias_correction2) / bias_correction1;
}
// Assume single type across p,g,m1,m2 now
@@ -156,10 +147,9 @@ void multi_tensor_adam_cuda(int chunk_size,
AdamFunctor<scalar_t_0>(),
beta1,
beta2,
bias_correction1,
bias_correction2,
epsilon,
lr,
lr_corrected,
(adamMode_t)mode,
weight_decay);)
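
The host-side change above folds both bias corrections into a single lr_corrected scale. A small numeric check (values are illustrative) that this matches the Transformers formulation exactly, while the old apex-style form, which unbiased the moments before adding epsilon, differs by a small epsilon-dependent amount:

import math

beta1, beta2, eps, lr, step = 0.9, 0.999, 1e-6, 1e-3, 10
m, v = 0.02, 0.0005                          # illustrative moment values
bc1, bc2 = 1 - beta1 ** step, 1 - beta2 ** step

old = lr * (m / bc1) / (math.sqrt(v / bc2) + eps)           # previous kernel math
lr_corrected = lr * math.sqrt(bc2) / bc1
new = lr_corrected * m / (math.sqrt(v) + eps)               # this commit
transformers_style = (lr * math.sqrt(bc2) / bc1) * m / (math.sqrt(v) + eps)

print(old, new, transformers_style)   # new == transformers_style; old is slightly off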

@@ -29,7 +29,7 @@ from onnxruntime.training.ortmodule import (ORTModule,
_fallback,
_graph_execution_manager)
from onnxruntime.training.optim.fused_adam import FusedAdam
from onnxruntime.training.optim import FusedAdam
from transformers import AdamW
import _test_helpers
@@ -4103,20 +4103,23 @@ def test_override_pytorch_exporter_kwargs_using_ortmodule_extension():
def test_ortmodule_fused_adam_optimizer_correctness():
torch.manual_seed(8888)
device = 'cuda'
N, D_in, H, D_out = 32, 128, 500, 10
pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
transformers_adamw_optimizer = AdamW(pt_model.parameters())
transformers_adamw_optimizer = AdamW(pt_model.parameters(), lr=1)
ort_model = ORTModule(copy.deepcopy(pt_model))
ort_fused_adam_optimizer = FusedAdam(ort_model.parameters())
ort_fused_adam_optimizer = FusedAdam(ort_model.parameters(), lr=1)
def run_step(model, x):
prediction = model(x)
loss = prediction.sum()
loss.backward()
return prediction, loss
return loss
def run_optim_step(optimizer):
optimizer.step()
@@ -4126,22 +4129,25 @@ def test_ortmodule_fused_adam_optimizer_correctness():
pt_model.zero_grad()
ort_model.zero_grad()
for step in range(10):
x = torch.randn(N, D_in, device=device)
for step in range(1000):
x1 = torch.randn(N, D_in, device=device, dtype=torch.float32)
x2 = copy.deepcopy(x1)
_, pt_loss = run_step(pt_model, x)
_, ort_loss = run_step(ort_model, x)
pt_loss = run_step(pt_model, x1)
ort_loss = run_step(ort_model, x2)
_test_helpers.assert_values_are_close(pt_loss, ort_loss, rtol=1e-4)
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)
for pt_param, ort_param in zip(pt_model.parameters(), ort_model.parameters()):
ort_param.grad = copy.deepcopy(pt_param.grad)
_test_helpers.assert_values_are_close(pt_loss, ort_loss)
_test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model, reset_gradient=False)
if (step+1) % ga_steps == 0:
run_optim_step(transformers_adamw_optimizer)
run_optim_step(ort_fused_adam_optimizer)
for pt_param, ort_param in zip(pt_model.parameters(), ort_model.parameters()):
_test_helpers.assert_values_are_close(pt_param, ort_param)
for pt_param, ort_param in zip(pt_model.parameters(), ort_model.parameters()):
_test_helpers.assert_values_are_close(pt_param, ort_param, atol=1e-4, rtol=1e-5)
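
Note on the updated test: the gradients computed by the PyTorch model are copied into the ORTModule model every step, so both optimizers consume identical gradients across the 1000 iterations and any parameter drift is attributable to the optimizer arithmetic alone (presumably also why lr=1 is used, to make such drift visible quickly). The copy loop could be factored into a helper like the following sketch (the name is hypothetical):

def sync_grads(pt_model, ort_model):
    # Hand the ORT model the exact gradients PyTorch computed.
    for pt_param, ort_param in zip(pt_model.parameters(), ort_model.parameters()):
        ort_param.grad = None if pt_param.grad is None else pt_param.grad.detach().clone()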
def test_sigmoid_grad():
class NeuralNetSigmoid(torch.nn.Module):