diff --git a/orttraining/orttraining/python/training/optim/__init__.py b/orttraining/orttraining/python/training/optim/__init__.py index 7d35a84b40..291268307d 100644 --- a/orttraining/orttraining/python/training/optim/__init__.py +++ b/orttraining/orttraining/python/training/optim/__init__.py @@ -2,5 +2,5 @@ from .config import _OptimizerConfig, AdamConfig, LambConfig, SGDConfig from .lr_scheduler import _LRScheduler, ConstantWarmupLRScheduler, CosineWarmupLRScheduler,\ LinearWarmupLRScheduler, PolyWarmupLRScheduler -from .fused_adam import FusedAdam +from .fused_adam import FusedAdam, AdamWMode from .fp16_optimizer import FP16_Optimizer diff --git a/orttraining/orttraining/python/training/optim/fused_adam.py b/orttraining/orttraining/python/training/optim/fused_adam.py index a6c36752d2..e655468e41 100644 --- a/orttraining/orttraining/python/training/optim/fused_adam.py +++ b/orttraining/orttraining/python/training/optim/fused_adam.py @@ -12,13 +12,22 @@ This file is adapted from fused adam in NVIDIA/apex, commit a109f85 import torch from ._multi_tensor_apply import MultiTensorApply +from enum import IntEnum + + +class AdamWMode(IntEnum): + ADAM_L2_REGULARIZATION = 0 # Adam with L2 regularization + ADAMW_TRANSFORMERS = 1 # Adam with weight decay implemented to be equivalent to Transformers/AdamW + ADAMW_TORCH = 2 # Adam with weight decay implemented to be equivalent to torch/AdamW class FusedAdam(torch.optim.Optimizer): """Implements Adam algorithm. - The algorithmic implementation is mathematically equivalent to Transformers/AdamW - as defined here: https://github.com/huggingface/transformers/blob/61f64262692ac7dc90e2e0bdeb7e79d9cd607a66/src/transformers/optimization.py#L349-L370 + The algorithmic implementation is mathematically equivalent to + `Transformers/AdamW <https://github.com/huggingface/transformers/blob/61f64262692ac7dc90e2e0bdeb7e79d9cd607a66/src/transformers/optimization.py#L349-L370>`_ + when adam_w_mode = 1 and `torch/AdamW <https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html>`_ + when adam_w_mode = 2 Currently GPU-only. 
@@ -27,7 +36,7 @@ class FusedAdam(torch.optim.Optimizer): * Fusion of the Adam update's elementwise operations * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches. - Adam was been proposed in `Adam: A Method for Stochastic Optimization`_. + Adam was proposed in `Adam: A Method for Stochastic Optimization`_. Arguments: params (iterable): iterable of parameters to optimize or dicts defining @@ -38,11 +47,11 @@ class FusedAdam(torch.optim.Optimizer): eps (float, optional): term added to the denominator to improve numerical stability. (default: 1e-8) weight_decay (float, optional): weight decay (L2 penalty) (default: 0) - amsgrad (boolean, optional): whether to use the AMSGrad variant of this - algorithm from the paper `On the Convergence of Adam and Beyond`_ - (default: False) NOT SUPPORTED in FusedAdam! - adam_w_mode (boolean, optional): Apply L2 regularization or weight decay - True for decoupled weight decay(also known as AdamW) (default: True) + adam_w_mode (AdamWMode, optional): How weight decay is applied: L2 regularization + (AdamWMode.ADAM_L2_REGULARIZATION), decoupled weight decay with + transformers/AdamW mathematical implementation (AdamWMode.ADAMW_TRANSFORMERS) + or decoupled weight decay with torch/AdamW mathematical implementation + (AdamWMode.ADAMW_TORCH) (default: AdamWMode.ADAMW_TRANSFORMERS) set_grad_none (bool, optional): whether set grad to None when zero_grad() method is called. (default: True) @@ -58,23 +67,20 @@ class FusedAdam(torch.optim.Optimizer): betas=(0.9, 0.999), eps=1e-6, - adam_w_mode=True, + adam_w_mode=AdamWMode.ADAMW_TRANSFORMERS, weight_decay=0., - amsgrad=False, - set_grad_none=False): # The FusedAdam implementation is mathematically equivalent to # transformers AdamW. The input arguments also have the same defaults. 
- if amsgrad: - raise RuntimeError('FusedAdam does not support the AMSGrad variant.') defaults = dict(lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay) super(FusedAdam, self).__init__(params, defaults) - self._adam_w_mode = 1 if adam_w_mode else 0 + self._adam_w_mode = adam_w_mode self._set_grad_none = set_grad_none # Skip buffer diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/multi_tensor_adam.cu b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/multi_tensor_adam.cu index 65601b2e19..affb0afa91 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/multi_tensor_adam.cu +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/multi_tensor_adam.cu @@ -24,8 +24,9 @@ #define ILP 4 typedef enum { - ADAM_MODE_0 = 0, // L2 regularization mode - ADAM_MODE_1 = 1 // Decoupled weight decay mode(AdamW) + ADAM_MODE_0 = 0, // L2 regularization mode + ADAM_MODE_1 = 1, // Decoupled weight decay mode (AdamW) as implemented in transformers/AdamW + ADAM_MODE_2 = 2 // Decoupled weight decay mode (AdamW) as implemented in pytorch/AdamW } adamMode_t; using MATH_T = float; @@ -40,6 +41,8 @@ struct AdamFunctor { const float epsilon, const float lr, const float lr_corrected, + const float bias_correction1, + const float bias_correction2, adamMode_t mode, const float decay) { @@ -91,13 +94,20 @@ struct AdamFunctor { r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii]; MATH_T denom = sqrtf(r_v[ii]) + epsilon; r_p[ii] = r_p[ii] - (lr_corrected * r_m[ii] / denom); - } else { // weight decay + } else if (mode == ADAM_MODE_1) { // weight decay // Adapted to be mathematically equivalent to transformers AdamW r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii]; MATH_T denom = sqrtf(r_v[ii]) + epsilon; r_p[ii] = r_p[ii] - 
(lr_corrected * r_m[ii] / denom); r_p[ii] = r_p[ii] - (lr * decay * r_p[ii]); + } else if (mode == ADAM_MODE_2) { + // Adapted to be mathematically equivalent to torch AdamW + r_p[ii] = r_p[ii] - (r_p[ii] * lr * decay); + r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; + r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii]; + MATH_T denom = (sqrtf(r_v[ii]) / sqrtf(bias_correction2)) + epsilon; + r_p[ii] = r_p[ii] - (lr * r_m[ii]) / (bias_correction1 * denom); } } #pragma unroll @@ -128,7 +138,7 @@ void multi_tensor_adam_cuda(int chunk_size, using namespace at; // Handle bias correction mode - double bias_correction1 = 1.0, bias_correction2 = 1.0; + float bias_correction1 = 1.0, bias_correction2 = 1.0; float lr_corrected = lr; if (bias_correction == 1) { bias_correction1 = 1 - std::pow(beta1, step); @@ -150,6 +160,8 @@ void multi_tensor_adam_cuda(int chunk_size, epsilon, lr, lr_corrected, + bias_correction1, + bias_correction2, (adamMode_t)mode, weight_decay);) diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 0964e47f21..be24d7c091 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -30,7 +30,7 @@ from onnxruntime.training.ortmodule import (ORTModule, _graph_execution_manager) import onnxruntime.training.ortmodule as ortmodule_module -from onnxruntime.training.optim import FusedAdam +from onnxruntime.training.optim import FusedAdam, AdamWMode from transformers import AdamW import _test_helpers @@ -4428,6 +4428,57 @@ def test_ortmodule_fused_adam_optimizer_correctness(): for pt_param, ort_param in zip(pt_model.parameters(), ort_model.parameters()): _test_helpers.assert_values_are_close(pt_param, ort_param, atol=1e-4, rtol=1e-5) +def test_ortmodule_fused_adam_optimizer_correctness_torch(): + + torch.manual_seed(8888) + + device = 'cuda' + 
N, D_in, H, D_out = 4, 4, 8, 4 + + pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) + adamw_optimizer = torch.optim.AdamW(pt_model.parameters(), lr=1e-3) + + ort_model = ORTModule(copy.deepcopy(pt_model)) + ort_fused_adam_optimizer = FusedAdam(ort_model.parameters(), lr=1e-3, + adam_w_mode=AdamWMode.ADAMW_TORCH, + weight_decay=0.01, + eps=1e-8) + + def run_step(model, x): + prediction = model(x) + loss = prediction.sum() + loss.backward() + + return loss + + def run_optim_step(optimizer): + optimizer.step() + optimizer.zero_grad() + + ga_steps = 2 + pt_model.zero_grad() + ort_model.zero_grad() + + for step in range(1000): + x1 = torch.randn(N, D_in, device=device, dtype=torch.float32) + x2 = copy.deepcopy(x1) + + pt_loss = run_step(pt_model, x1) + ort_loss = run_step(ort_model, x2) + + for pt_param, ort_param in zip(pt_model.parameters(), ort_model.parameters()): + ort_param.grad = copy.deepcopy(pt_param.grad) + + _test_helpers.assert_values_are_close(pt_loss, ort_loss, atol=1e-4, rtol=1e-5) + _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model, atol=1e-4, rtol=1e-5, reset_gradient=False) + + if (step+1) % ga_steps == 0: + run_optim_step(adamw_optimizer) + run_optim_step(ort_fused_adam_optimizer) + + for pt_param, ort_param in zip(pt_model.parameters(), ort_model.parameters()): + _test_helpers.assert_values_are_close(pt_param, ort_param, atol=1e-4, rtol=1e-5) + def test_sigmoid_grad(): class NeuralNetSigmoid(torch.nn.Module): def __init__(self, input_size, hidden_size, num_classes):