From 8e27833e305539345c058d77b1d1b3419d402cd7 Mon Sep 17 00:00:00 2001 From: bskrlj Date: Thu, 31 Oct 2024 18:24:05 +0000 Subject: [PATCH] Ensure SWA boundary conditions w.r.t. definition (#133773) According to the documentation, decay is a number in [0,1] range, [i.e.](https://pytorch.org/docs/stable/optim.html) ``` Decay is a parameter between 0 and 1 that controls how fast the averaged parameters are decayed. If not provided to get_ema_multi_avg_fn, the default is 0.999. ``` An inspection of `swa_utils.py` indicates there are no checks for invalid values of `decay`. Adding validation checks as suggested in this PR ensures valid compute range (one way to enforce correct behavior, there are perhaps more suitable ones). Papers `torch` cites for reference idea/implementation also consider exclusively this range (e.g., https://arxiv.org/pdf/2310.04415). Fixes https://github.com/pytorch/pytorch/issues/133772 Pull Request resolved: https://github.com/pytorch/pytorch/pull/133773 Approved by: https://github.com/janeyx99 --- docs/source/optim.rst | 2 +- torch/optim/swa_utils.py | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/docs/source/optim.rst b/docs/source/optim.rst index 35c23dacc8e..a5ae21b8358 100644 --- a/docs/source/optim.rst +++ b/docs/source/optim.rst @@ -520,7 +520,7 @@ EMA models are constructed by specifying the ``multi_avg_fn`` argument as follow >>> decay = 0.999 >>> averaged_model = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(decay)) -Decay is a parameter between 0 and 1 that controls how fast the averaged parameters are decayed. If not provided to :func:`torch.optim.swa_utils.get_ema_multi_avg_fn`, the default is 0.999. +Decay is a parameter between 0 and 1 that controls how fast the averaged parameters are decayed. If not provided to :func:`torch.optim.swa_utils.get_ema_multi_avg_fn`, the default is 0.999. Decay value should be close to 1.0, as smaller values can cause optimization convergence issues. 
:func:`torch.optim.swa_utils.get_ema_multi_avg_fn` returns a function that applies the following EMA equation to the weights: diff --git a/torch/optim/swa_utils.py b/torch/optim/swa_utils.py index 7bea0d355be..541da8d477c 100644 --- a/torch/optim/swa_utils.py +++ b/torch/optim/swa_utils.py @@ -34,6 +34,11 @@ PARAM_LIST = Union[Tuple[Tensor, ...], List[Tensor]] def get_ema_multi_avg_fn(decay=0.999): """Get the function applying exponential moving average (EMA) across multiple params.""" + if decay < 0.0 or decay > 1.0: + raise ValueError( + f"Invalid decay value {decay} provided. Please provide a value in [0,1] range." + ) + @torch.no_grad() def ema_update(ema_param_list: PARAM_LIST, current_param_list: PARAM_LIST, _): # foreach lerp only handles float and complex @@ -83,6 +88,11 @@ def get_swa_multi_avg_fn(): def get_ema_avg_fn(decay=0.999): """Get the function applying exponential moving average (EMA) across a single param.""" + if decay < 0.0 or decay > 1.0: + raise ValueError( + f"Invalid decay value {decay} provided. Please provide a value in [0,1] range." + ) + @torch.no_grad() def ema_update(ema_param: Tensor, current_param: Tensor, num_averaged): return decay * ema_param + (1 - decay) * current_param