mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-08 00:23:03 +00:00
Improve LRScheduler tests (#4885)
* LRScheduler tests added to the Transformer model * Refactored LRScheduler tests for the BERT Toy onnx example * Removed dead code
This commit is contained in:
parent
e00ad83f2b
commit
acbf6d15c6
3 changed files with 175 additions and 102 deletions
45
orttraining/orttraining/test/python/_test_commons.py
Normal file
45
orttraining/orttraining/test/python/_test_commons.py
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
import math
|
||||
|
||||
|
||||
def legacy_constant_lr_scheduler(global_step, initial_lr, total_steps, warmup):
    """Reference constant-with-warmup learning rate schedule.

    For steps below ``warmup * total_steps`` the rate is ``initial_lr`` scaled
    by ``global_step / max(1, warmup * total_steps)``; from then on it is
    ``initial_lr``.

    Args:
        global_step: current training step.
        initial_lr: learning rate reached once warmup is over.
        total_steps: total number of training steps.
        warmup: fraction of ``total_steps`` spent warming up.

    Returns:
        The learning rate for ``global_step``.
    """
    warmup_steps = warmup * total_steps
    if global_step >= warmup_steps:
        return initial_lr
    # max(1, ...) keeps the divisor from dropping below one.
    ramp_divisor = float(max(1, warmup_steps))
    return initial_lr * float(global_step) / ramp_divisor
|
||||
|
||||
|
||||
def legacy_cosine_lr_scheduler(global_step, initial_lr, total_steps, warmup, cycles):
    """Reference cosine-with-warmup learning rate schedule.

    For steps below ``warmup * total_steps`` the rate is ``initial_lr`` scaled
    by ``global_step / max(1, warmup * total_steps)``. Afterwards it is
    ``initial_lr * max(0, 0.5 * (1 + cos(pi * cycles * 2 * progress)))``,
    where ``progress`` is the number of steps past warmup divided by
    ``max(1, total_steps - warmup * total_steps)``.

    Args:
        global_step: current training step.
        initial_lr: learning rate reached at the end of warmup.
        total_steps: total number of training steps.
        warmup: fraction of ``total_steps`` spent warming up.
        cycles: multiplier applied inside the cosine argument.

    Returns:
        The learning rate for ``global_step``.
    """
    warmup_steps = warmup * total_steps
    if global_step < warmup_steps:
        ramp_divisor = float(max(1, warmup_steps))
        return initial_lr * float(global_step) / ramp_divisor

    decay_span = float(max(1, total_steps - warmup_steps))
    progress = float(global_step - warmup_steps) / decay_span
    cosine_factor = 0.5 * (1.0 + math.cos(math.pi * float(cycles) * 2.0 * progress))
    # Clamp at zero so the returned rate is never negative for positive initial_lr.
    return initial_lr * max(0.0, cosine_factor)
|
||||
|
||||
|
||||
|
||||
def legacy_linear_lr_scheduler(global_step, initial_lr, total_steps, warmup):
    """Reference linear-with-warmup learning rate schedule.

    For steps below ``warmup * total_steps`` the rate is ``initial_lr`` scaled
    by ``global_step / max(1, warmup * total_steps)``. Afterwards it is
    ``initial_lr`` scaled by the remaining steps divided by
    ``max(1, total_steps - warmup * total_steps)``, clamped at zero.

    Args:
        global_step: current training step.
        initial_lr: learning rate reached at the end of warmup.
        total_steps: total number of training steps.
        warmup: fraction of ``total_steps`` spent warming up.

    Returns:
        The learning rate for ``global_step``.
    """
    warmup_steps = warmup * total_steps
    if global_step < warmup_steps:
        ramp_divisor = float(max(1, warmup_steps))
        return initial_lr * float(global_step) / ramp_divisor

    steps_left = float(total_steps - global_step)
    decay_span = float(max(1, total_steps - warmup_steps))
    # Steps beyond total_steps make the ratio negative; clamp it to zero.
    return initial_lr * max(0.0, steps_left / decay_span)
|
||||
|
||||
|
||||
def legacy_poly_lr_scheduler(global_step, initial_lr, total_steps, warmup, power, lr_end):
    """Reference polynomial-decay-with-warmup learning rate schedule.

    For steps below ``warmup * total_steps`` the rate is ``initial_lr`` scaled
    by ``global_step / max(1, warmup * total_steps)``. For steps beyond
    ``total_steps`` it is ``lr_end``. In between it is
    ``(initial_lr - lr_end) * pct_remaining ** power + lr_end``, where
    ``pct_remaining`` is one minus the fraction of the decay span consumed.

    Args:
        global_step: current training step.
        initial_lr: learning rate reached at the end of warmup.
        total_steps: total number of training steps.
        warmup: fraction of ``total_steps`` spent warming up.
        power: exponent applied to the remaining fraction.
        lr_end: learning rate returned once ``global_step > total_steps``.

    Returns:
        The learning rate for ``global_step``.
    """
    num_warmup_steps = warmup * total_steps
    if global_step < num_warmup_steps:
        return initial_lr * float(global_step) / float(max(1, num_warmup_steps))
    if global_step > total_steps:
        return lr_end

    lr_range = initial_lr - lr_end
    decay_steps = total_steps - num_warmup_steps
    if decay_steps == 0:
        # Warmup spans the whole run (global_step == total_steps here), so
        # there is no decay span to divide by. Treat progress as zero instead
        # of raising ZeroDivisionError.
        progress = 0.0
    else:
        progress = (global_step - num_warmup_steps) / decay_steps
    pct_remaining = 1 - progress
    return lr_range * pct_remaining ** power + lr_end
|
||||
|
|
@ -1,14 +1,14 @@
|
|||
import copy
|
||||
from functools import partial
|
||||
import inspect
|
||||
import math
|
||||
from numpy.testing import assert_allclose
|
||||
import onnx
|
||||
import os
|
||||
import math
|
||||
import pytest
|
||||
import copy
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from numpy.testing import assert_allclose
|
||||
|
||||
import onnxruntime
|
||||
from onnxruntime.capi.ort_trainer import IODescription as Legacy_IODescription,\
|
||||
ModelDescription as Legacy_ModelDescription,\
|
||||
|
|
@ -18,7 +18,7 @@ from onnxruntime.experimental import _utils, amp, checkpoint, optim, orttrainer,
|
|||
model_desc_validation as md_val,\
|
||||
orttrainer_options as orttrainer_options
|
||||
|
||||
import _test_helpers
|
||||
import _test_commons, _test_helpers
|
||||
|
||||
|
||||
###############################################################################
|
||||
|
|
@ -147,94 +147,6 @@ def legacy_bert_model_description():
|
|||
next_sentence_labels_desc], [loss_desc])
|
||||
|
||||
|
||||
def legacy_constant_lr_scheduler_1(global_step):
    """Constant-with-warmup legacy schedule with the initial rate fixed at 1.0."""
    return legacy_constant_lr_scheduler(global_step, initial_lr=1.0)
|
||||
|
||||
|
||||
def legacy_constant_lr_scheduler_5(global_step):
    """Constant-with-warmup legacy schedule with the initial rate fixed at 0.5."""
    return legacy_constant_lr_scheduler(global_step, initial_lr=0.5)
|
||||
|
||||
|
||||
def legacy_constant_lr_scheduler(global_step, initial_lr, total_steps=10, warmup=0.5):
    """Reference constant-with-warmup learning rate schedule.

    For steps below ``warmup * total_steps`` the rate is ``initial_lr`` scaled
    by ``global_step / max(1, warmup * total_steps)``; from then on it is
    ``initial_lr``.

    Args:
        global_step: current training step.
        initial_lr: learning rate reached once warmup is over.
        total_steps: total number of training steps. Defaults to 10, the
            value this helper previously hard-coded.
        warmup: fraction of ``total_steps`` spent warming up. Defaults to
            0.5, the value this helper previously hard-coded.

    Returns:
        The learning rate for ``global_step``.
    """
    num_warmup_steps = warmup * total_steps
    if global_step < num_warmup_steps:
        # max(1, ...) keeps the divisor from dropping below one.
        return initial_lr * float(global_step) / float(max(1, num_warmup_steps))
    return initial_lr
|
||||
|
||||
|
||||
def legacy_cosine_lr_scheduler(global_step, initial_lr=1.0, total_steps=10, warmup=0.5, cycles=0.5):
    """Reference cosine-with-warmup learning rate schedule.

    For steps below ``warmup * total_steps`` the rate is ``initial_lr`` scaled
    by ``global_step / max(1, warmup * total_steps)``. Afterwards it is
    ``initial_lr * max(0, 0.5 * (1 + cos(pi * cycles * 2 * progress)))``,
    where ``progress`` is the number of steps past warmup divided by
    ``max(1, total_steps - warmup * total_steps)``.

    The defaults reproduce the values this helper previously hard-coded, so
    ``legacy_cosine_lr_scheduler(step)`` behaves as before.

    Args:
        global_step: current training step.
        initial_lr: learning rate reached at the end of warmup.
        total_steps: total number of training steps.
        warmup: fraction of ``total_steps`` spent warming up.
        cycles: multiplier applied inside the cosine argument.

    Returns:
        The learning rate for ``global_step``.
    """
    num_warmup_steps = warmup * total_steps
    if global_step < num_warmup_steps:
        return initial_lr * float(global_step) / float(max(1, num_warmup_steps))

    progress = float(global_step - num_warmup_steps) / float(max(1, total_steps - num_warmup_steps))
    # Clamp at zero so the returned rate is never negative for positive initial_lr.
    return initial_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(cycles) * 2.0 * progress)))
|
||||
|
||||
|
||||
|
||||
def legacy_linear_lr_scheduler(global_step, initial_lr=1.0, total_steps=10, warmup=0.5):
    """Reference linear-with-warmup learning rate schedule.

    For steps below ``warmup * total_steps`` the rate is ``initial_lr`` scaled
    by ``global_step / max(1, warmup * total_steps)``. Afterwards it is
    ``initial_lr`` scaled by the remaining steps divided by
    ``max(1, total_steps - warmup * total_steps)``, clamped at zero.

    The defaults reproduce the values this helper previously hard-coded, so
    ``legacy_linear_lr_scheduler(step)`` behaves as before.

    Args:
        global_step: current training step.
        initial_lr: learning rate reached at the end of warmup.
        total_steps: total number of training steps.
        warmup: fraction of ``total_steps`` spent warming up.

    Returns:
        The learning rate for ``global_step``.
    """
    num_warmup_steps = warmup * total_steps
    if global_step < num_warmup_steps:
        return initial_lr * float(global_step) / float(max(1, num_warmup_steps))

    # The decay branch is scaled by initial_lr like the warmup branch; with
    # the default initial_lr of 1.0 the result is numerically unchanged.
    remaining = float(total_steps - global_step) / float(max(1, total_steps - num_warmup_steps))
    return initial_lr * max(0.0, remaining)
|
||||
|
||||
|
||||
def legacy_poly_lr_scheduler(global_step, initial_lr=1.0, total_steps=10, warmup=0.5, power=1.0, lr_end=1e-7):
    """Reference polynomial-decay-with-warmup learning rate schedule.

    For steps below ``warmup * total_steps`` the rate is ``initial_lr`` scaled
    by ``global_step / max(1, warmup * total_steps)``. For steps beyond
    ``total_steps`` it is ``lr_end``. In between it is
    ``(initial_lr - lr_end) * pct_remaining ** power + lr_end``, where
    ``pct_remaining`` is one minus the fraction of the decay span consumed.

    Every branch returns an absolute learning rate. The defaults reproduce
    the values this helper previously hard-coded; with ``initial_lr == 1.0``
    the results equal those of the earlier ``/ initial_lr`` form.

    Args:
        global_step: current training step.
        initial_lr: learning rate reached at the end of warmup.
        total_steps: total number of training steps.
        warmup: fraction of ``total_steps`` spent warming up.
        power: exponent applied to the remaining fraction.
        lr_end: learning rate returned once ``global_step > total_steps``.

    Returns:
        The learning rate for ``global_step``.
    """
    num_warmup_steps = warmup * total_steps
    if global_step < num_warmup_steps:
        return initial_lr * float(global_step) / float(max(1, num_warmup_steps))
    if global_step > total_steps:
        return lr_end

    lr_range = initial_lr - lr_end
    decay_steps = total_steps - num_warmup_steps
    if decay_steps == 0:
        # No decay span (warmup covers the whole run): avoid dividing by zero.
        progress = 0.0
    else:
        progress = (global_step - num_warmup_steps) / decay_steps
    pct_remaining = 1 - progress
    return lr_range * pct_remaining ** power + lr_end
|
||||
|
||||
|
||||
def legacy_optim_params_a(name):
    """Return one fixed hyperparameter dict regardless of ``name``.

    ``name`` is accepted but never consulted.
    """
    hyperparameters = {"alpha": 0.9, "beta": 0.999, "lambda": 0.01, "epsilon": 1e-6}
    return hyperparameters
|
||||
|
||||
|
||||
def legacy_optim_params_b(name):
    """Return hyperparameters for ``name``, using ``lambda`` 0.0 for two names.

    The two embedding LayerNorm entries listed below get ``"lambda": 0.0``;
    every other name gets ``"lambda": 0.01``. The remaining keys are the same
    in both cases.
    """
    zero_lambda_names = ['bert.embeddings.LayerNorm.bias', 'bert.embeddings.LayerNorm.weight']
    lambda_value = 0.0 if name in zero_lambda_names else 0.01
    return {"alpha": 0.9, "beta": 0.999, "lambda": lambda_value, "epsilon": 1e-6}
|
||||
|
||||
|
||||
def legacy_optim_params_c(name):
    """Return hyperparameters for ``name`` based on the first parameter group.

    Names found in the ``'params'`` entry of the first group returned by
    ``optimizer_parameters(load_bert_onnx_model())`` get ``"lambda": 0.0``;
    every other name gets ``"lambda": 0.01``.
    """
    # NOTE(review): the model is loaded and grouped again on every call.
    first_group_names = optimizer_parameters(load_bert_onnx_model())[0]['params']
    lambda_value = 0.0 if name in first_group_names else 0.01
    return {"alpha": 0.9, "beta": 0.999, "lambda": lambda_value, "epsilon": 1e-6}
|
||||
|
||||
|
||||
###############################################################################
|
||||
# Testing starts here #########################################################
|
||||
###############################################################################
|
||||
|
|
@ -292,7 +204,7 @@ def testToyBERTDeterministicCheck(expected_losses):
|
|||
experimental_losses.append(trainer.train_step(*sample_input).cpu().item())
|
||||
|
||||
# Check output
|
||||
_test_helpers.assert_model_outputs(experimental_losses, expected_losses)
|
||||
_test_helpers.assert_model_outputs(experimental_losses, expected_losses, rtol=1e-6)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("initial_lr, lr_scheduler, expected_learning_rates, expected_losses", [
|
||||
|
|
@ -349,7 +261,7 @@ def testToyBERTModelLRScheduler(initial_lr, lr_scheduler, expected_learning_rate
|
|||
learning_rates.append(trainer.options.lr_scheduler.get_last_lr()[0])
|
||||
|
||||
# Check output
|
||||
_test_helpers.assert_model_outputs(learning_rates, expected_learning_rates)
|
||||
_test_helpers.assert_model_outputs(learning_rates, expected_learning_rates, rtol=1e-6)
|
||||
_test_helpers.assert_model_outputs(losses, expected_losses, rtol=1e-6)
|
||||
|
||||
|
||||
|
|
@ -743,11 +655,11 @@ def testToyBERTModelLegacyExperimentalBasicTraining():
|
|||
|
||||
|
||||
@pytest.mark.parametrize("initial_lr, lr_scheduler, legacy_lr_scheduler", [
|
||||
(1.0, optim.lr_scheduler.ConstantWarmupLRScheduler, legacy_constant_lr_scheduler_1),
|
||||
(0.5, optim.lr_scheduler.ConstantWarmupLRScheduler, legacy_constant_lr_scheduler_5),
|
||||
(1.0, optim.lr_scheduler.CosineWarmupLRScheduler, legacy_cosine_lr_scheduler),
|
||||
(1.0, optim.lr_scheduler.LinearWarmupLRScheduler, legacy_linear_lr_scheduler),
|
||||
(1.0, optim.lr_scheduler.PolyWarmupLRScheduler, legacy_poly_lr_scheduler),
|
||||
(1.0, optim.lr_scheduler.ConstantWarmupLRScheduler, _test_commons.legacy_constant_lr_scheduler),
|
||||
(0.5, optim.lr_scheduler.ConstantWarmupLRScheduler, _test_commons.legacy_constant_lr_scheduler),
|
||||
(1.0, optim.lr_scheduler.CosineWarmupLRScheduler, _test_commons.legacy_cosine_lr_scheduler),
|
||||
(1.0, optim.lr_scheduler.LinearWarmupLRScheduler, _test_commons.legacy_linear_lr_scheduler),
|
||||
(1.0, optim.lr_scheduler.PolyWarmupLRScheduler, _test_commons.legacy_poly_lr_scheduler),
|
||||
])
|
||||
def testToyBERTModelLegacyExperimentalLRScheduler(initial_lr, lr_scheduler, legacy_lr_scheduler):
|
||||
############################################################################
|
||||
|
|
@ -758,6 +670,29 @@ def testToyBERTModelLegacyExperimentalLRScheduler(initial_lr, lr_scheduler, lega
|
|||
total_steps = 10
|
||||
device = 'cuda'
|
||||
seed = 1
|
||||
warmup = 0.05
|
||||
cycles = 0.5
|
||||
power = 1.
|
||||
lr_end = 1e-7
|
||||
|
||||
# Setup both Experimental and Legacy LR Schedulers before the experimental loop
|
||||
if legacy_lr_scheduler == _test_commons.legacy_constant_lr_scheduler or legacy_lr_scheduler == _test_commons.legacy_linear_lr_scheduler:
|
||||
legacy_lr_scheduler = partial(legacy_lr_scheduler, initial_lr=initial_lr, total_steps=total_steps, warmup=warmup)
|
||||
elif legacy_lr_scheduler == _test_commons.legacy_cosine_lr_scheduler:
|
||||
legacy_lr_scheduler = partial(legacy_lr_scheduler, initial_lr=initial_lr, total_steps=total_steps, warmup=warmup, cycles=cycles)
|
||||
elif legacy_lr_scheduler == _test_commons.legacy_poly_lr_scheduler:
|
||||
legacy_lr_scheduler = partial(legacy_lr_scheduler, initial_lr=initial_lr, total_steps=total_steps, warmup=warmup, power=power, lr_end=lr_end)
|
||||
else:
|
||||
raise RuntimeError("Invalid legacy_lr_scheduler")
|
||||
if lr_scheduler == optim.lr_scheduler.ConstantWarmupLRScheduler or lr_scheduler == optim.lr_scheduler.LinearWarmupLRScheduler:
|
||||
lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup)
|
||||
elif lr_scheduler == optim.lr_scheduler.CosineWarmupLRScheduler:
|
||||
lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, cycles=cycles)
|
||||
elif lr_scheduler == optim.lr_scheduler.PolyWarmupLRScheduler:
|
||||
lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, power=power, lr_end=lr_end)
|
||||
else:
|
||||
raise RuntimeError("Invalid lr_scheduler")
|
||||
|
||||
|
||||
# EXPERIMENTAL API
|
||||
model_desc = bert_model_description()
|
||||
|
|
@ -772,7 +707,7 @@ def testToyBERTModelLegacyExperimentalLRScheduler(initial_lr, lr_scheduler, lega
|
|||
'device': {
|
||||
'id': device,
|
||||
},
|
||||
'lr_scheduler' : lr_scheduler(total_steps=total_steps, warmup=0.5)
|
||||
'lr_scheduler' : lr_scheduler
|
||||
})
|
||||
trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
|
||||
experimental_losses = []
|
||||
|
|
@ -785,6 +720,7 @@ def testToyBERTModelLegacyExperimentalLRScheduler(initial_lr, lr_scheduler, lega
|
|||
torch.manual_seed(seed)
|
||||
onnxruntime.set_seed(seed)
|
||||
device = torch.device(device)
|
||||
|
||||
legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(initial_lr)
|
||||
legacy_trainer = Legacy_ORTTrainer(model, None, legacy_model_desc, "AdamOptimizer",
|
||||
None,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,7 @@
|
|||
from functools import partial
|
||||
import inspect
|
||||
import math
|
||||
|
||||
import onnx
|
||||
import os
|
||||
import pytest
|
||||
|
|
@ -14,7 +17,7 @@ from onnxruntime.capi.ort_trainer import IODescription as Legacy_IODescription,\
|
|||
from onnxruntime.experimental import _utils, amp, optim, orttrainer, TrainStepInfo,\
|
||||
model_desc_validation as md_val,\
|
||||
orttrainer_options as orttrainer_options
|
||||
import _test_helpers
|
||||
import _test_commons,_test_helpers
|
||||
|
||||
|
||||
###############################################################################
|
||||
|
|
@ -964,3 +967,92 @@ def testORTTrainerLegacyAndExperimentalGradientAccumulation(seed, device, gradie
|
|||
|
||||
# Compare legacy vs experimental APIs
|
||||
_test_helpers.assert_model_outputs(legacy_loss, experimental_loss, rtol=1e-6)
|
||||
|
||||
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed,device,optimizer_config,lr_scheduler, get_lr_this_step", [
    # Each (seed, scheduler, legacy function) triple runs with all three optimizers.
    (case_seed, 'cuda', config, scheduler, legacy_fn)
    for case_seed, scheduler, legacy_fn in [
        (0, optim.lr_scheduler.ConstantWarmupLRScheduler, _test_commons.legacy_constant_lr_scheduler),
        (42, optim.lr_scheduler.LinearWarmupLRScheduler, _test_commons.legacy_linear_lr_scheduler),
        (123, optim.lr_scheduler.CosineWarmupLRScheduler, _test_commons.legacy_cosine_lr_scheduler),
        (321, optim.lr_scheduler.PolyWarmupLRScheduler, _test_commons.legacy_poly_lr_scheduler),
    ]
    for config in (optim.AdamConfig, optim.LambConfig, optim.SGDConfig)
])
def testORTTrainerLegacyAndExperimentalLRScheduler(seed, device, optimizer_config, lr_scheduler, get_lr_this_step):
    """Train with the experimental and legacy trainers and compare their losses.

    Both runs share the seed, step count, learning rate and scheduler
    settings. The experimental run takes an ``optim.lr_scheduler`` instance,
    the legacy run takes the matching ``_test_commons`` function bound with
    ``functools.partial``.
    """
    # Settings shared by both runs
    total_steps = 10
    lr = 0.001
    warmup = 0.5
    cycles = 0.5
    power = 1.
    lr_end = 1e-7
    torch.set_printoptions(precision=10)

    # ---- Experimental API ----
    torch.manual_seed(seed)
    set_seed(seed)

    # Instantiate the scheduler class with the arguments its kind takes
    schedulers = optim.lr_scheduler
    if lr_scheduler in (schedulers.ConstantWarmupLRScheduler, schedulers.LinearWarmupLRScheduler):
        lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup)
    elif lr_scheduler == schedulers.CosineWarmupLRScheduler:
        lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, cycles=cycles)
    elif lr_scheduler == schedulers.PolyWarmupLRScheduler:
        lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, power=power, lr_end=lr_end)
    else:
        raise RuntimeError("Invalid lr_scheduler")

    options = orttrainer.ORTTrainerOptions({
        'device': {'id': device},
        'debug': {'deterministic_compute': True},
        'lr_scheduler': lr_scheduler,
    })
    model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _load_pytorch_transformer_model(device)
    optim_config = optimizer_config(lr=lr)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)

    # Experimental training loop
    experimental_loss = []
    for step in range(total_steps):
        data, targets = batcher_fn(train_data, step)
        step_loss, _ = trainer.train_step(data, targets)
        experimental_loss.append(step_loss.cpu())

    # ---- Legacy API ----
    # Re-seed so the legacy run starts from the same random state
    torch.manual_seed(seed)
    set_seed(seed)

    # Map the experimental optimizer config onto the legacy optimizer name
    if optimizer_config == optim.AdamConfig:
        legacy_optimizer_config = 'AdamOptimizer'
    elif optimizer_config == optim.LambConfig:
        legacy_optimizer_config = 'LambOptimizer'
    elif optimizer_config == optim.SGDConfig:
        legacy_optimizer_config = 'SGDOptimizer'
    else:
        raise RuntimeError("Invalid optimizer_config")

    # Bind everything except global_step, leaving that for the trainer to pass
    common_kwargs = {'initial_lr': lr, 'total_steps': total_steps, 'warmup': warmup}
    if get_lr_this_step in (_test_commons.legacy_constant_lr_scheduler, _test_commons.legacy_linear_lr_scheduler):
        get_lr_this_step = partial(get_lr_this_step, **common_kwargs)
    elif get_lr_this_step == _test_commons.legacy_cosine_lr_scheduler:
        get_lr_this_step = partial(get_lr_this_step, cycles=cycles, **common_kwargs)
    elif get_lr_this_step == _test_commons.legacy_poly_lr_scheduler:
        get_lr_this_step = partial(get_lr_this_step, power=power, lr_end=lr_end, **common_kwargs)
    else:
        raise RuntimeError("Invalid get_lr_this_step")

    model, (model_desc, lr_desc), _, _, _, _, _ = _load_pytorch_transformer_model(device, legacy_api=True)
    legacy_trainer = Legacy_ORTTrainer(model, my_loss, model_desc, legacy_optimizer_config,
                                       None, lr_desc, device=device,
                                       _use_deterministic_compute=True,
                                       get_lr_this_step=get_lr_this_step)

    # Legacy training loop (reuses batcher_fn and train_data from the first load)
    legacy_loss = []
    for step in range(total_steps):
        data, targets = batcher_fn(train_data, step)
        step_loss, _ = legacy_trainer.train_step(data, targets)
        legacy_loss.append(step_loss.cpu())

    # Compare legacy vs experimental APIs
    _test_helpers.assert_model_outputs(legacy_loss, experimental_loss)
|
||||
|
|
|
|||
Loading…
Reference in a new issue