Added check and test for betas parameter in Adam optimizer (#5147)

* Added check and test for betas parameter in Adam optimizer

* Simplified test
This commit is contained in:
lazypanda1 2018-02-11 19:24:43 -06:00 committed by Soumith Chintala
parent 6dc41f9e63
commit a061000250
2 changed files with 6 additions and 0 deletions

View file

@ -292,6 +292,8 @@ class TestOptim(TestCase):
self._build_params_dict(weight, bias, lr=1e-2),
lr=1e-3, amsgrad=True)
)
with self.assertRaisesRegexp(ValueError, "Invalid beta parameter at index 0: 1.0"):
optim.Adam(None, lr=1e-2, betas=(1.0, 0.0))
def test_sparse_adam(self):
self._test_rosenbrock_sparse(

View file

@ -28,6 +28,10 @@ class Adam(Optimizer):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
weight_decay=0, amsgrad=False):
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay, amsgrad=amsgrad)
super(Adam, self).__init__(params, defaults)