diff --git a/test/test_optim.py b/test/test_optim.py
index f05b219a42f..76ff0e6c405 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -292,6 +292,8 @@ class TestOptim(TestCase):
                 self._build_params_dict(weight, bias, lr=1e-2),
                 lr=1e-3, amsgrad=True)
         )
+        with self.assertRaisesRegexp(ValueError, "Invalid beta parameter at index 0: 1.0"):
+            optim.Adam(None, lr=1e-2, betas=(1.0, 0.0))
 
     def test_sparse_adam(self):
         self._test_rosenbrock_sparse(
diff --git a/torch/optim/adam.py b/torch/optim/adam.py
index 490fe8b31e3..6d401e52fe7 100644
--- a/torch/optim/adam.py
+++ b/torch/optim/adam.py
@@ -28,6 +28,10 @@ class Adam(Optimizer):
 
     def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                  weight_decay=0, amsgrad=False):
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
         defaults = dict(lr=lr, betas=betas, eps=eps,
                         weight_decay=weight_decay, amsgrad=amsgrad)
         super(Adam, self).__init__(params, defaults)
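
For reference, a minimal usage sketch of the behavior this patch introduces (hypothetical caller code, not part of the diff; assumes a torch build that includes this change):

import torch

# The range check runs in Adam.__init__ before the Optimizer base class
# ever touches `params`, so an out-of-range beta fails fast with a
# descriptive message rather than producing NaNs during training.
try:
    torch.optim.Adam([torch.zeros(1, requires_grad=True)],
                     lr=1e-2, betas=(1.0, 0.0))
except ValueError as e:
    print(e)  # Invalid beta parameter at index 0: 1.0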