mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-25 22:26:24 +00:00
Update LambConfig defaults to match backend (#4826)
This commit is contained in:
parent
6a360bad6b
commit
f933910ea3
2 changed files with 8 additions and 8 deletions
|
|
@ -192,7 +192,7 @@ class LambConfig(_OptimizerConfig):
|
|||
ratio_min (float, default is -inf): Lower bound on confidence ratio.
|
||||
ratio_max (float, default is inf): Upper bound on confidence ratio.
|
||||
epsilon (float, default is 1e-6): Small scalar to avoid dividing by zero.
|
||||
do_bias_correction (bool, default is True): Compute unbiased 1st and 2nd momentums.
|
||||
do_bias_correction (bool, default is False): Compute unbiased 1st and 2nd momentums.
|
||||
|
||||
NOTE: To prevent model parameters to be trained, refer to :py:attr:`.ORTTrainerOptions.utils.frozen_weights`.
|
||||
|
||||
|
|
@ -208,7 +208,7 @@ class LambConfig(_OptimizerConfig):
|
|||
"""
|
||||
|
||||
def __init__(self, params=[], lr=0.001, alpha=0.9, beta=0.999, lambda_coef=0.0,
|
||||
ratio_min=float('-inf'), ratio_max=float('inf'), epsilon=1e-6, do_bias_correction=True):
|
||||
ratio_min=float('-inf'), ratio_max=float('inf'), epsilon=1e-6, do_bias_correction=False):
|
||||
assert lr >= 0, "'lr' must be a positive number"
|
||||
assert alpha >= 0, "'alpha' must be a positive number"
|
||||
assert beta >= 0, "'beta' must be a positive number"
|
||||
|
|
|
|||
|
|
@ -403,7 +403,7 @@ def testOptimizerConfigInvalidInputs(optim_name, defaults, params):
|
|||
name=optim_name, params=params, defaults=defaults)
|
||||
|
||||
|
||||
def testSGDConfig():
|
||||
def testOptimizerConfigSGD():
|
||||
'''Test initialization of SGD'''
|
||||
cfg = optim.SGDConfig()
|
||||
assert cfg.name == 'SGDOptimizer'
|
||||
|
|
@ -422,7 +422,7 @@ def testSGDConfig():
|
|||
assert str(e.value) == "'params' must be an empty list for SGD optimizer"
|
||||
|
||||
|
||||
def testAdamConfig():
|
||||
def testOptimizerConfigAdam():
|
||||
'''Test initialization of Adam'''
|
||||
cfg = optim.AdamConfig()
|
||||
assert cfg.name == 'AdamOptimizer'
|
||||
|
|
@ -438,7 +438,7 @@ def testAdamConfig():
|
|||
assert cfg.weight_decay_mode == optim.AdamConfig.DecayMode.BEFORE_WEIGHT_UPDATE, "weight_decay_mode mismatch"
|
||||
|
||||
|
||||
def testLambConfig():
|
||||
def testOptimizerConfigLamb():
|
||||
'''Test initialization of Lamb'''
|
||||
cfg = optim.LambConfig()
|
||||
assert cfg.name == 'LambOptimizer'
|
||||
|
|
@ -451,14 +451,14 @@ def testLambConfig():
|
|||
assert cfg.ratio_min == float('-inf'), "ratio_min mismatch"
|
||||
assert cfg.ratio_max == float('inf'), "ratio_max mismatch"
|
||||
assert_allclose(1e-6, cfg.epsilon, rtol=rtol, err_msg="epsilon mismatch")
|
||||
assert cfg.do_bias_correction == True, "lambda_coef mismatch"
|
||||
assert cfg.do_bias_correction == False, "do_bias_correction mismatch"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("optim_name", [
|
||||
('Adam'),
|
||||
('Lamb')
|
||||
])
|
||||
def testParamparams(optim_name):
|
||||
def testOptimizerConfigParams(optim_name):
|
||||
rtol = 1e-5
|
||||
params = [{'params': ['layer1.weight'], 'alpha': 0.1}]
|
||||
if optim_name == 'Adam':
|
||||
|
|
@ -476,7 +476,7 @@ def testParamparams(optim_name):
|
|||
('Adam'),
|
||||
('Lamb')
|
||||
])
|
||||
def testInvalidParamparams(optim_name):
|
||||
def testOptimizerConfigInvalidParams(optim_name):
|
||||
# lr is not supported within params
|
||||
with pytest.raises(AssertionError) as e:
|
||||
params = [{'params': ['layer1.weight'], 'lr': 0.1}]
|
||||
|
|
|
|||
Loading…
Reference in a new issue