mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Added check and test for betas parameter in Adam optimizer (#5147)
* Added check and test for betas parameter in Adam optimizer

* Simplified test
This commit is contained in:
parent
6dc41f9e63
commit
a061000250
2 changed files with 6 additions and 0 deletions
|
|
@@ -292,6 +292,8 @@ class TestOptim(TestCase):
|
|||
self._build_params_dict(weight, bias, lr=1e-2),
|
||||
lr=1e-3, amsgrad=True)
|
||||
)
|
||||
with self.assertRaisesRegexp(ValueError, "Invalid beta parameter at index 0: 1.0"):
|
||||
optim.Adam(None, lr=1e-2, betas=(1.0, 0.0))
|
||||
|
||||
def test_sparse_adam(self):
|
||||
self._test_rosenbrock_sparse(
|
||||
|
|
|
|||
|
|
@@ -28,6 +28,10 @@ class Adam(Optimizer):
|
|||
|
||||
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
             weight_decay=0, amsgrad=False):
    """Initialize the Adam optimizer.

    Args:
        params: iterable of parameters to optimize (forwarded to the
            ``Optimizer`` base class together with ``defaults``).
        lr (float): learning rate (default: 1e-3). Must be >= 0.
        betas (Tuple[float, float]): coefficients stored per-group;
            each must lie in [0, 1) (default: (0.9, 0.999)).
        eps (float): term added for numerical stability (default: 1e-8).
            Must be >= 0.
        weight_decay (float): weight decay factor (default: 0).
        amsgrad (bool): whether to use the AMSGrad variant (default: False).

    Raises:
        ValueError: if ``lr`` or ``eps`` is negative, or either beta is
            outside [0, 1).
    """
    # Validate hyperparameters up front so a bad configuration fails
    # fast with a clear message instead of silently corrupting the
    # optimization.  The original code checked only betas; lr and eps
    # are validated here for consistency (same "not 0.0 <= x" form,
    # which also rejects NaN).
    if not 0.0 <= lr:
        raise ValueError("Invalid learning rate: {}".format(lr))
    if not 0.0 <= eps:
        raise ValueError("Invalid epsilon value: {}".format(eps))
    if not 0.0 <= betas[0] < 1.0:
        raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
    if not 0.0 <= betas[1] < 1.0:
        raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
    defaults = dict(lr=lr, betas=betas, eps=eps,
                    weight_decay=weight_decay, amsgrad=amsgrad)
    super(Adam, self).__init__(params, defaults)
|
||||
|
|
|
|||
Loading…
Reference in a new issue