diff --git a/orttraining/orttraining/python/training/optim/__init__.py b/orttraining/orttraining/python/training/optim/__init__.py
index d0541a609b..c01c1c1eab 100644
--- a/orttraining/orttraining/python/training/optim/__init__.py
+++ b/orttraining/orttraining/python/training/optim/__init__.py
@@ -1,3 +1,5 @@
 from .config import _OptimizerConfig, AdamConfig, LambConfig, SGDConfig
 from .lr_scheduler import _LRScheduler, ConstantWarmupLRScheduler, CosineWarmupLRScheduler,\
     LinearWarmupLRScheduler, PolyWarmupLRScheduler
+
+from .fused_adam import FusedAdam
diff --git a/orttraining/orttraining/python/training/optim/multi_tensor_apply.py b/orttraining/orttraining/python/training/optim/_multi_tensor_apply.py
similarity index 100%
rename from orttraining/orttraining/python/training/optim/multi_tensor_apply.py
rename to orttraining/orttraining/python/training/optim/_multi_tensor_apply.py
diff --git a/orttraining/orttraining/python/training/optim/fused_adam.py b/orttraining/orttraining/python/training/optim/fused_adam.py
index c6ab65ed5c..4ca5523bab 100644
--- a/orttraining/orttraining/python/training/optim/fused_adam.py
+++ b/orttraining/orttraining/python/training/optim/fused_adam.py
@@ -11,14 +11,15 @@ This file is adapted from fused adam in NVIDIA/apex, commit a109f85
 '''

 import torch
-import importlib
-from .multi_tensor_apply import MultiTensorApply

-multi_tensor_applier = MultiTensorApply(2048 * 32)
+from ._multi_tensor_apply import MultiTensorApply


 class FusedAdam(torch.optim.Optimizer):
     """Implements Adam algorithm.

+    The algorithmic implementation is mathematically equivalent to Transformers/AdamW
+    as defined here: https://github.com/huggingface/transformers/blob/61f64262692ac7dc90e2e0bdeb7e79d9cd607a66/src/transformers/optimization.py#L349-L370
+
     Currently GPU-only.

     This version of fused Adam implements 2 fusions.
@@ -56,11 +57,14 @@ class FusedAdam(torch.optim.Optimizer):
                  bias_correction=True,
                  betas=(0.9, 0.999),
-                 eps=1e-8,
+                 eps=1e-6,
                  adam_w_mode=True,
                  weight_decay=0.,
                  amsgrad=False,
-                 set_grad_none=True):
+                 set_grad_none=False):
+
+        # The FusedAdam implementation is mathematically equivalent to
+        # transformers AdamW. The input arguments also have the same defaults.

         if amsgrad:
             raise RuntimeError('FusedAdam does not support the AMSGrad variant.')
@@ -70,29 +74,25 @@ class FusedAdam(torch.optim.Optimizer):
                         eps=eps,
                         weight_decay=weight_decay)
         super(FusedAdam, self).__init__(params, defaults)
-        self.adam_w_mode = 1 if adam_w_mode else 0
-        self.set_grad_none = set_grad_none
+        self._adam_w_mode = 1 if adam_w_mode else 0
+        self._set_grad_none = set_grad_none

         # Skip buffer
         self._dummy_overflow_buf = torch.cuda.IntTensor([0])

         from onnxruntime.training.ortmodule.torch_cpp_extensions import adam_optimizer
-        self.multi_tensor_adam = adam_optimizer.multi_tensor_adam
+        self._multi_tensor_adam = adam_optimizer.multi_tensor_adam
+        self._multi_tensor_applier = MultiTensorApply(2048 * 32)

     def zero_grad(self):
-        if self.set_grad_none:
+        if self._set_grad_none:
             for group in self.param_groups:
                 for p in group['params']:
                     p.grad = None
         else:
             super(FusedAdam, self).zero_grad()

-    def step(self,
-             closure=None,
-             grads=None,
-             output_params=None,
-             scale=None,
-             grad_norms=None):
+    def step(self, closure=None):
         """Performs a single optimization step.

         Arguments:
@@ -101,10 +101,6 @@ class FusedAdam(torch.optim.Optimizer):
             closure (callable, optional): A closure that reevaluates the model
                 and returns the loss.

         The remaining arguments are deprecated, and are only retained (for the
         moment) for error-checking purposes.
         """
-        if any(p is not None for p in [grads, output_params, scale, grad_norms]):
-            raise RuntimeError(
-                'FusedAdam has been updated. Simply initialize it identically to torch.optim.Adam, and call step() with no arguments.'
-            )
         loss = None
         if closure is not None:
             loss = closure()
@@ -154,34 +150,34 @@ class FusedAdam(torch.optim.Optimizer):
                 raise RuntimeError('FusedAdam only support fp16 and fp32.')

             if (len(g_16) > 0):
-                multi_tensor_applier(self.multi_tensor_adam,
-                                     self._dummy_overflow_buf,
-                                     [g_16,
-                                      p_16,
-                                      m_16,
-                                      v_16],
-                                     group['lr'],
-                                     beta1,
-                                     beta2,
-                                     group['eps'],
-                                     group['step'],
-                                     self.adam_w_mode,
-                                     bias_correction,
-                                     group['weight_decay'])
+                self._multi_tensor_applier(self._multi_tensor_adam,
+                                           self._dummy_overflow_buf,
+                                           [g_16,
+                                            p_16,
+                                            m_16,
+                                            v_16],
+                                           group['lr'],
+                                           beta1,
+                                           beta2,
+                                           group['eps'],
+                                           group['step'],
+                                           self._adam_w_mode,
+                                           bias_correction,
+                                           group['weight_decay'])
             if (len(g_32) > 0):
-                multi_tensor_applier(self.multi_tensor_adam,
-                                     self._dummy_overflow_buf,
-                                     [g_32,
-                                      p_32,
-                                      m_32,
-                                      v_32],
-                                     group['lr'],
-                                     beta1,
-                                     beta2,
-                                     group['eps'],
-                                     group['step'],
-                                     self.adam_w_mode,
-                                     bias_correction,
-                                     group['weight_decay'])
+                self._multi_tensor_applier(self._multi_tensor_adam,
+                                           self._dummy_overflow_buf,
+                                           [g_32,
+                                            p_32,
+                                            m_32,
+                                            v_32],
+                                           group['lr'],
+                                           beta1,
+                                           beta2,
+                                           group['eps'],
+                                           group['step'],
+                                           self._adam_w_mode,
+                                           bias_correction,
+                                           group['weight_decay'])

         return loss
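A minimal usage sketch of the optimizer as now exported from the package root. This is illustrative only: FusedAdam is GPU-only and loads its fused kernel from onnxruntime.training.ortmodule.torch_cpp_extensions, so it assumes a CUDA build of onnxruntime-training, and the model and hyperparameters here are made up.

import torch
from onnxruntime.training.ortmodule import ORTModule
from onnxruntime.training.optim import FusedAdam

# Hypothetical model; FusedAdam is constructed like torch.optim.Adam, but its
# defaults now track transformers AdamW (eps=1e-6, decoupled weight decay).
model = ORTModule(torch.nn.Linear(128, 10).to('cuda'))
optimizer = FusedAdam(model.parameters(), lr=1e-3, weight_decay=0.01)

x = torch.randn(32, 128, device='cuda')
loss = model(x).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()  # zeroes in place now that set_grad_none defaults to False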
""" - if any(p is not None for p in [grads, output_params, scale, grad_norms]): - raise RuntimeError( - 'FusedAdam has been updated. Simply initialize it identically to torch.optim.Adam, and call step() with no arguments.' - ) loss = None if closure is not None: loss = closure() @@ -154,34 +150,34 @@ class FusedAdam(torch.optim.Optimizer): raise RuntimeError('FusedAdam only support fp16 and fp32.') if (len(g_16) > 0): - multi_tensor_applier(self.multi_tensor_adam, - self._dummy_overflow_buf, - [g_16, - p_16, - m_16, - v_16], - group['lr'], - beta1, - beta2, - group['eps'], - group['step'], - self.adam_w_mode, - bias_correction, - group['weight_decay']) + self._multi_tensor_applier(self._multi_tensor_adam, + self._dummy_overflow_buf, + [g_16, + p_16, + m_16, + v_16], + group['lr'], + beta1, + beta2, + group['eps'], + group['step'], + self._adam_w_mode, + bias_correction, + group['weight_decay']) if (len(g_32) > 0): - multi_tensor_applier(self.multi_tensor_adam, - self._dummy_overflow_buf, - [g_32, - p_32, - m_32, - v_32], - group['lr'], - beta1, - beta2, - group['eps'], - group['step'], - self.adam_w_mode, - bias_correction, - group['weight_decay']) + self._multi_tensor_applier(self._multi_tensor_adam, + self._dummy_overflow_buf, + [g_32, + p_32, + m_32, + v_32], + group['lr'], + beta1, + beta2, + group['eps'], + group['step'], + self._adam_w_mode, + bias_correction, + group['weight_decay']) return loss diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/adam_optimizer/multi_tensor_adam.cu b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/adam_optimizer/multi_tensor_adam.cu index be096f3b75..65601b2e19 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/adam_optimizer/multi_tensor_adam.cu +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/adam_optimizer/multi_tensor_adam.cu @@ -18,6 +18,7 @@ #include "multi_tensor_apply.cuh" #include "type_shim.h" +#include #define BLOCK_SIZE 512 #define ILP 4 @@ -36,22 +37,14 @@ struct AdamFunctor { TensorListMetadata<4>& tl, const float beta1, const float beta2, - const float beta1_correction, - const float beta2_correction, const float epsilon, const float lr, + const float lr_corrected, adamMode_t mode, const float decay) { - // I'd like this kernel to propagate infs/nans. 
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
index 894cd3bf8a..8b17df4a07 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
@@ -29,7 +29,7 @@ from onnxruntime.training.ortmodule import (ORTModule,
                                             _fallback,
                                             _graph_execution_manager)

-from onnxruntime.training.optim.fused_adam import FusedAdam
+from onnxruntime.training.optim import FusedAdam
 from transformers import AdamW

 import _test_helpers
@@ -4103,20 +4103,23 @@ def test_override_pytorch_exporter_kwargs_using_ortmodule_extension():

 def test_ortmodule_fused_adam_optimizer_correctness():
+    torch.manual_seed(8888)
+
     device = 'cuda'
     N, D_in, H, D_out = 32, 128, 500, 10

     pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device)
-    transformers_adamw_optimizer = AdamW(pt_model.parameters())
+    transformers_adamw_optimizer = AdamW(pt_model.parameters(), lr=1)

     ort_model = ORTModule(copy.deepcopy(pt_model))
-    ort_fused_adam_optimizer = FusedAdam(ort_model.parameters())
+    ort_fused_adam_optimizer = FusedAdam(ort_model.parameters(), lr=1)

     def run_step(model, x):
         prediction = model(x)
         loss = prediction.sum()
         loss.backward()
-        return prediction, loss
+
+        return loss

     def run_optim_step(optimizer):
         optimizer.step()
@@ -4126,22 +4129,25 @@ def test_ortmodule_fused_adam_optimizer_correctness():
         pt_model.zero_grad()
         ort_model.zero_grad()

-    for step in range(10):
-        x = torch.randn(N, D_in, device=device)
+    for step in range(1000):
+        x1 = torch.randn(N, D_in, device=device, dtype=torch.float32)
+        x2 = copy.deepcopy(x1)

-        _, pt_loss = run_step(pt_model, x)
-        _, ort_loss = run_step(ort_model, x)
+        pt_loss = run_step(pt_model, x1)
+        ort_loss = run_step(ort_model, x2)

-        _test_helpers.assert_values_are_close(pt_loss, ort_loss, rtol=1e-4)
-        _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model)
+        for pt_param, ort_param in zip(pt_model.parameters(), ort_model.parameters()):
+            ort_param.grad = copy.deepcopy(pt_param.grad)
+
+        _test_helpers.assert_values_are_close(pt_loss, ort_loss)
+        _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model, reset_gradient=False)

         if (step+1) % ga_steps == 0:
             run_optim_step(transformers_adamw_optimizer)
             run_optim_step(ort_fused_adam_optimizer)

-    for pt_param, ort_param in zip(pt_model.parameters(), ort_model.parameters()):
-        _test_helpers.assert_values_are_close(pt_param, ort_param)
-
+            for pt_param, ort_param in zip(pt_model.parameters(), ort_model.parameters()):
+                _test_helpers.assert_values_are_close(pt_param, ort_param, atol=1e-4, rtol=1e-5)

 def test_sigmoid_grad():
     class NeuralNetSigmoid(torch.nn.Module):
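The reworked test pins both models to the same gradients: the PyTorch gradients are deep-copied into the ORTModule parameters, and assert_gradients_match_and_reset_gradient is called with reset_gradient=False, so after each step() any parameter divergence can only come from the optimizers themselves. A stripped-down sketch of that harness pattern (CPU-only; torch.optim.AdamW stands in for both optimizers here just to show the shape, whereas the real test pairs transformers' AdamW against FusedAdam on CUDA):

import copy
import torch

torch.manual_seed(0)
model_a = torch.nn.Linear(4, 2)
model_b = copy.deepcopy(model_a)
opt_a = torch.optim.AdamW(model_a.parameters(), lr=1e-3)
opt_b = torch.optim.AdamW(model_b.parameters(), lr=1e-3)

for _ in range(100):
    x = torch.randn(8, 4)
    model_a(x).sum().backward()
    model_b(x).sum().backward()

    # Sync gradients so any divergence must come from the optimizers, not
    # from accumulated numerical noise in the forward/backward passes.
    for pa, pb in zip(model_a.parameters(), model_b.parameters()):
        pb.grad = copy.deepcopy(pa.grad)

    opt_a.step()
    opt_b.step()
    opt_a.zero_grad()
    opt_b.zero_grad()

    for pa, pb in zip(model_a.parameters(), model_b.parameters()):
        assert torch.allclose(pa, pb, rtol=1e-5, atol=1e-4)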