Add `allow_empty_param_list` to functional optimizers (#62522)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62522

Addresses https://github.com/pytorch/pytorch/issues/62481

Test Plan: Imported from OSS

Reviewed By: zou3519

Differential Revision: D30072074

Pulled By: andwgu

fbshipit-source-id: 1a5da21f9636b8d74a6b00c0f029427f0edff0e3
Author: Andrew Gu, 2021-08-09 11:15:35 -07:00 (committed by Facebook GitHub Bot)
Parent: 710c419f11
Commit: 1b1f1e36b4
11 changed files with 22 additions and 16 deletions
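
The change itself is mechanical: each functional optimizer touched here gains a private `_allow_empty_param_list` constructor flag (default `False`) that suppresses the "optimizer got an empty parameter list" error; `_FunctionalAdam` and `_FunctionalSGD` already had the flag under the public name `allow_empty_param_list`, and it is renamed for consistency. The motivating caller is the DDP optimizer-overlap path (`_OptimizerHookState` below), which builds the functional optimizer with no parameters up front. A minimal sketch of what the flag permits; the module path and the idea of constructing the optimizer directly are illustrative assumptions, not part of this commit:

    from torch.distributed.optim.functional_adadelta import _FunctionalAdadelta  # assumed module path

    # Without the flag this would raise "optimizer got an empty parameter list";
    # with it, the optimizer is created empty and parameters are supplied later.
    opt = _FunctionalAdadelta([], _allow_empty_param_list=True)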

@@ -956,7 +956,7 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
     lr=SGD_LR,
     momentum=SGD_MOMENTUM,
     weight_decay=SGD_WEIGHT_DECAY,
-    allow_empty_param_list=True
+    _allow_empty_param_list=True
 )
 ddp_model_overlap.register_comm_hook(
     None,

@@ -41,7 +41,7 @@ class TestFunctionalOptimParity(TestCase):
 if not functional_optim_cls:
     raise ValueError(f"Functional optimizer not implemented for {optim_cls}")
 optim_functional = functional_optim_cls(
-    [], *args, **kwargs, allow_empty_param_list=True
+    [], *args, **kwargs, _allow_empty_param_list=True
 )
 if not hasattr(optim_functional, "step_param"):
     raise ValueError(

@@ -85,7 +85,7 @@ class _OptimizerHookState(object):
     [],
     *functional_optim_args,
     **functional_optim_kwargs,
-    allow_empty_param_list=True,
+    _allow_empty_param_list=True,
 )
 if not hasattr(self.functional_optimizer, "step_param"):
     raise ValueError(

@@ -22,6 +22,7 @@ class _FunctionalAdadelta(object):
         rho: float = 0.9,
         eps: float = 1e-6,
         weight_decay: float = 0.0,
+        _allow_empty_param_list: bool = False,
     ):
         self.defaults = {
             "lr": lr,
@@ -30,7 +31,7 @@
             "weight_decay": weight_decay,
         }
-        if len(params) == 0:
+        if len(params) == 0 and not _allow_empty_param_list:
            raise ValueError("optimizer got an empty parameter list")
         # NOTE: we only have one param_group and don't allow user to add additional

@@ -26,6 +26,7 @@ class _FunctionalAdagrad(object):
         warmup_num_iters: float = 0.0,
         eps: float = 1e-10,
         coalesce_grad: bool = True,
+        _allow_empty_param_list: bool = False,
     ):
         self.defaults = {
             "lr": lr,
@@ -39,7 +40,7 @@
         self.coalesce_grad = coalesce_grad
         self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
-        if len(params) == 0:
+        if len(params) == 0 and not _allow_empty_param_list:
            raise ValueError("optimizer got an empty parameter list")
         # NOTE: we only have one param_group and don't allow user to add additional

@@ -23,7 +23,7 @@ class _FunctionalAdam(object):
         eps: float = 1e-8,
         weight_decay: float = 0.0,
         amsgrad: bool = False,
-        allow_empty_param_list: bool = False
+        _allow_empty_param_list: bool = False,
     ):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
@@ -46,7 +46,7 @@
         self.amsgrad = amsgrad
         self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
-        if len(params) == 0 and not allow_empty_param_list:
+        if len(params) == 0 and not _allow_empty_param_list:
            raise ValueError("optimizer got an empty parameter list")
         # NOTE: we only have one param_group and don't allow user to add additional
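
For `_FunctionalAdam` here, and `_FunctionalSGD` at the end of the diff, the option is not new: it previously existed as the public keyword `allow_empty_param_list`, and this commit only renames it to the private spelling. An illustrative before/after of a call site (module path assumed, hyperparameters arbitrary):

    from torch.distributed.optim.functional_adam import _FunctionalAdam  # assumed module path

    # old spelling: _FunctionalAdam([], lr=1e-3, allow_empty_param_list=True)
    opt = _FunctionalAdam([], lr=1e-3, _allow_empty_param_list=True)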

@@ -22,6 +22,7 @@ class _FunctionalAdamax(object):
         betas: Tuple[float, float] = (0.9, 0.999),
         eps: float = 1e-8,
         weight_decay: float = 0.0,
+        _allow_empty_param_list: bool = False,
     ):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
@@ -43,7 +44,7 @@
         }
         self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
-        if len(params) == 0:
+        if len(params) == 0 and not _allow_empty_param_list:
            raise ValueError("optimizer got an empty parameter list")
         # NOTE: we only have one param_group and don't allow user to add additional

@@ -22,7 +22,8 @@ class _FunctionalAdamW(object):
         betas: Tuple[float, float] = (0.9, 0.999),
         eps: float = 1e-8,
         weight_decay: float = 1e-2,
-        amsgrad: bool = False
+        amsgrad: bool = False,
+        _allow_empty_param_list: bool = False,
     ):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
@@ -45,7 +46,7 @@
         self.amsgrad = amsgrad
         self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
-        if len(params) == 0:
+        if len(params) == 0 and not _allow_empty_param_list:
            raise ValueError("optimizer got an empty parameter list")
         # NOTE: we only have one param_group and don't allow user to add additional

@@ -23,7 +23,8 @@ class _FunctionalRMSprop(object):
         eps: float = 1e-8,
         weight_decay: float = 0.0,
         momentum: float = 0.0,
-        centered: bool = False
+        centered: bool = False,
+        _allow_empty_param_list: bool = False,
     ):
         self.defaults = {
             "lr": lr,
@@ -34,7 +35,7 @@
         }
         self.centered = centered
-        if len(params) == 0:
+        if len(params) == 0 and not _allow_empty_param_list:
            raise ValueError("optimizer got an empty parameter list")
         # NOTE: we only have one param_group and don't allow user to add additional

@@ -20,7 +20,8 @@ class _FunctionalRprop(object):
         params: List[Tensor],
         lr: float = 1e-2,
         etas: Tuple[float, float] = (0.5, 1.2),
-        step_sizes: Tuple[float, float] = (1e-6, 50)
+        step_sizes: Tuple[float, float] = (1e-6, 50),
+        _allow_empty_param_list: bool = False,
     ):
         self.defaults = {
             "lr": lr,
@@ -28,7 +29,7 @@
         self.etas = etas
         self.step_sizes = step_sizes
-        if len(params) == 0:
+        if len(params) == 0 and not _allow_empty_param_list:
            raise ValueError("optimizer got an empty parameter list")
         # NOTE: we only have one param_group and don't allow user to add additional

@@ -23,7 +23,7 @@ class _FunctionalSGD(object):
         dampening: float = 0.0,
         weight_decay: float = 0.0,
         nesterov: bool = False,
-        allow_empty_param_list: bool = False
+        _allow_empty_param_list: bool = False
     ):
         self.defaults = {
             "lr": lr,
@@ -34,7 +34,7 @@
         self.nesterov = nesterov
         self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
-        if len(params) == 0 and not allow_empty_param_list:
+        if len(params) == 0 and not _allow_empty_param_list:
            raise ValueError("optimizer got an empty parameter list")
         # NOTE: we only have one param_group and don't allow user to add additional
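
An empty parameter list is workable because callers like `_OptimizerHookState` never iterate the stored param group; they apply updates one tensor at a time through `step_param`, which the test hunks above also check for. A rough end-to-end sketch, assuming the `step_param(param, grad)` signature of the existing `_FunctionalSGD` implementation and an assumed module path:

    import torch
    from torch.distributed.optim.functional_sgd import _FunctionalSGD  # assumed module path

    # Build the optimizer with no parameters, opting in via the new private flag.
    opt = _FunctionalSGD([], lr=0.1, _allow_empty_param_list=True)

    param = torch.randn(4, requires_grad=True)
    loss = (param * 2.0).sum()
    loss.backward()

    # Update this single parameter in place with a functional SGD step.
    opt.step_param(param, param.grad)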