Adding maximize to rprop (#81864)
Adds the maximize flag (#68052) to the Rprop optimizer and updates the respective tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81864
Approved by: https://github.com/albanD
Parent a8941aa996, commit ff75562cff.
3 changed files with 30 additions and 11 deletions.
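For context (not part of the diff): with this flag, Rprop can ascend an objective directly instead of the caller having to negate the loss. A minimal usage sketch, assuming a PyTorch build that includes this change; the tensor and objective are made up for illustration:

    import torch
    from torch import optim

    # Toy objective to maximize: a concave quadratic with its peak at w == 2.
    w = torch.zeros(3, requires_grad=True)
    opt = optim.Rprop([w], lr=1e-2, maximize=True)

    for _ in range(100):
        opt.zero_grad()
        objective = -(w - 2.0).pow(2).sum()
        objective.backward()
        opt.step()            # with maximize=True the gradient sign is flipped internally

    print(w)                  # should approach tensor([2., 2., 2.])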
test/test_optim.py

@@ -21,7 +21,7 @@ from torch.optim.lr_scheduler import LambdaLR, MultiplicativeLR, SequentialLR, S
     EPOCH_DEPRECATION_WARNING
 from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
 from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_UBSAN, load_tests, \
-    parametrize, instantiate_parametrized_tests, gradcheck
+    parametrize, instantiate_parametrized_tests, gradcheck, skipIfRocm
 # load_tests from common_utils is used to automatically filter tests for
 # sharding on sandcastle. This line silences flake warnings
 load_tests = load_tests

@@ -895,15 +895,18 @@ class TestOptim(TestCase):
         with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -0.5"):
             optimizer(None, lr=1e-2, weight_decay=-0.5)

+    @skipIfRocm
     def test_rprop(self):
         for optimizer in [optim.Rprop, optim_mt.Rprop]:
             self._test_basic_cases(
-                lambda weight, bias: optimizer([weight, bias], lr=1e-3)
+                lambda weight, bias, maximize: optimizer([weight, bias], lr=2e-4, maximize=maximize),
+                constructor_accepts_maximize=True
             )
             self._test_basic_cases(
-                lambda weight, bias: optimizer(
+                lambda weight, bias, maximize: optimizer(
                     self._build_params_dict(weight, bias, lr=1e-2),
-                    lr=1e-3)
+                    lr=2e-4, maximize=maximize),
+                constructor_accepts_maximize=True
             )
             with self.assertRaisesRegex(ValueError, "Invalid eta values: 1.0, 0.5"):
                 optimizer(None, lr=1e-2, etas=(1.0, 0.5))
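The constructor_accepts_maximize=True argument tells the shared test helper to call the constructor lambda with both maximize=False and maximize=True and verify that the objective moves in the corresponding direction. A rough standalone sketch of that kind of check (not the actual _test_basic_cases helper); the tensors and objective are made up:

    import torch
    from torch import optim

    def objective(weight, bias):
        x = torch.ones(2)
        return (weight.mv(x) + bias).pow(2).sum()

    for maximize in (False, True):
        weight = torch.randn(3, 2, requires_grad=True)
        bias = torch.randn(3, requires_grad=True)
        opt = optim.Rprop([weight, bias], lr=2e-4, maximize=maximize)

        initial = objective(weight, bias).item()
        for _ in range(20):
            opt.zero_grad()
            objective(weight, bias).backward()
            opt.step()
        final = objective(weight, bias).item()

        # Minimizing should decrease the value; maximizing should increase it.
        assert (final > initial) == maximize, (maximize, initial, final)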
torch/distributed/optim/functional_rprop.py

@@ -24,6 +24,7 @@ class _FunctionalRprop(object):
         etas: Tuple[float, float] = (0.5, 1.2),
         step_sizes: Tuple[float, float] = (1e-6, 50),
         foreach: bool = False,
+        maximize: bool = False,
         _allow_empty_param_list: bool = False,
     ):
         self.defaults = {

@@ -32,6 +33,7 @@ class _FunctionalRprop(object):
         self.etas = etas
         self.step_sizes = step_sizes
         self.foreach = foreach
+        self.maximize = maximize

         if len(params) == 0 and not _allow_empty_param_list:
             raise ValueError("optimizer got an empty parameter list")

@@ -86,4 +88,5 @@ class _FunctionalRprop(object):
                 step_size_max=step_size_max,
                 etaminus=etaminus,
                 etaplus=etaplus,
-                foreach=self.foreach)
+                foreach=self.foreach,
+                maximize=self.maximize)
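_FunctionalRprop is the TorchScript-friendly wrapper used by torch.distributed.optim; it is handed explicit gradients rather than reading .grad. A minimal sketch of driving it with the new flag; the step(gradients) call shape is an assumption here, not documented API:

    import torch
    from torch.distributed.optim.functional_rprop import _FunctionalRprop

    w = torch.randn(4)
    opt = _FunctionalRprop([w], lr=1e-2, maximize=True)   # stored as opt.maximize per the diff above

    g = torch.ones_like(w)    # pretend this gradient arrived from a remote worker
    opt.step([g])             # assumed: one gradient (or None) per parameter, in order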
torch/optim/rprop.py

@@ -52,22 +52,25 @@ class Rprop(Optimizer):
             maximal allowed step sizes (default: (1e-6, 50))
         foreach (bool, optional): whether foreach implementation of optimizer
             is used (default: None)
+        maximize (bool, optional): maximize the params based on the objective, instead of
+            minimizing (default: False)
     """

     def __init__(self, params, lr=1e-2, etas=(0.5, 1.2), step_sizes=(1e-6, 50),
-                 foreach: Optional[bool] = None):
+                 foreach: Optional[bool] = None, maximize: bool = False):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 < etas[0] < 1.0 < etas[1]:
             raise ValueError("Invalid eta values: {}, {}".format(etas[0], etas[1]))

-        defaults = dict(lr=lr, etas=etas, step_sizes=step_sizes, foreach=foreach)
+        defaults = dict(lr=lr, etas=etas, step_sizes=step_sizes, foreach=foreach, maximize=maximize)
         super(Rprop, self).__init__(params, defaults)

     def __setstate__(self, state):
         super().__setstate__(state)
         for group in self.param_groups:
             group.setdefault('foreach', None)
+            group.setdefault('maximize', False)

     @torch.no_grad()
     def step(self, closure=None):
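Aside (not part of the diff): the group.setdefault('maximize', False) line in __setstate__ is what keeps optimizer state saved before this change loadable afterwards, since Optimizer.load_state_dict ends by calling __setstate__. A small round-trip sketch, assuming a PyTorch build that includes this change; the tensor is made up:

    import torch
    from torch import optim

    w = torch.randn(3, requires_grad=True)
    state = optim.Rprop([w], lr=1e-2).state_dict()
    state['param_groups'][0].pop('maximize', None)   # simulate a checkpoint written before this commit

    opt = optim.Rprop([w], lr=1e-2)
    opt.load_state_dict(state)                       # __setstate__ restores the missing key
    print(opt.param_groups[0]['maximize'])           # False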
@@ -90,6 +93,7 @@ class Rprop(Optimizer):
             etaminus, etaplus = group['etas']
             step_size_min, step_size_max = group['step_sizes']
             foreach = group['foreach']
+            maximize = group['maximize']

             for p in group['params']:
                 if p.grad is None:

@@ -121,7 +125,8 @@ class Rprop(Optimizer):
                   step_size_max=step_size_max,
                   etaminus=etaminus,
                   etaplus=etaplus,
-                  foreach=foreach)
+                  foreach=foreach,
+                  maximize=maximize)

             return loss
@@ -133,6 +138,7 @@ def rprop(params: List[Tensor],
           # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
           # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
           foreach: bool = None,
+          maximize: bool = False,
           *,
           step_size_min: float,
           step_size_max: float,

@@ -162,7 +168,8 @@ def rprop(params: List[Tensor],
          step_size_min=step_size_min,
          step_size_max=step_size_max,
          etaminus=etaminus,
-         etaplus=etaplus)
+         etaplus=etaplus,
+         maximize=maximize)


 def _single_tensor_rprop(params: List[Tensor],
@@ -173,10 +180,12 @@ def _single_tensor_rprop(params: List[Tensor],
                          step_size_min: float,
                          step_size_max: float,
                          etaminus: float,
-                         etaplus: float):
+                         etaplus: float,
+                         maximize: bool):

     for i, param in enumerate(params):
         grad = grads[i]
+        grad = grad if not maximize else -grad
         prev = prevs[i]
         step_size = step_sizes[i]

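The per-tensor change above is just a sign flip on the incoming gradient, so maximizing an objective with Rprop should trace the same trajectory as minimizing its negation. A quick equivalence check, assuming a PyTorch build that includes this change; the quadratic objective is made up:

    import torch
    from torch import optim

    torch.manual_seed(0)
    init = torch.randn(4)

    w_max = init.clone().requires_grad_(True)          # maximize f(w) = -(w - 3)^2
    w_min = init.clone().requires_grad_(True)          # minimize -f(w) = (w - 3)^2
    opt_max = optim.Rprop([w_max], lr=1e-2, maximize=True)
    opt_min = optim.Rprop([w_min], lr=1e-2)

    for _ in range(50):
        opt_max.zero_grad()
        (-(w_max - 3.0) ** 2).sum().backward()
        opt_max.step()

        opt_min.zero_grad()
        ((w_min - 3.0) ** 2).sum().backward()
        opt_min.step()

    print(torch.allclose(w_max, w_min))                # expected: True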
@@ -207,11 +216,15 @@ def _multi_tensor_rprop(params: List[Tensor],
                         step_size_min: float,
                         step_size_max: float,
                         etaminus: float,
-                        etaplus: float):
+                        etaplus: float,
+                        maximize: bool):

     if len(params) == 0:
         return

+    if maximize:
+        torch._foreach_neg_(grads)
+
     signs = torch._foreach_mul(grads, prevs)
     signs = [s.sign() for s in signs]
     for sign in signs:
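In the foreach path the same sign flip is applied once to the whole gradient list via torch._foreach_neg_, which negates every tensor in the list in place (a private API, so subject to change). A tiny illustration independent of the optimizer:

    import torch

    grads = [torch.tensor([1.0, -2.0]), torch.tensor([0.5])]
    torch._foreach_neg_(grads)   # in place: each tensor becomes its negation
    print(grads)                 # [tensor([-1., 2.]), tensor([-0.5000])]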