Ensure SWA boundary conditions w.r.t. definition (#133773)

According to the documentation, decay is a number in [0,1] range,[ i.e.](https://pytorch.org/docs/stable/optim.html)
```
Decay is a parameter between 0 and 1 that controls how fast the averaged parameters are decayed. If not provided to get_ema_multi_avg_fn, the default is 0.999.
```
An inspection of `swa_utils.py`  indicates there are no checks for invalid values of `decay`. Adding asserts as suggested in this PR ensures valid compute range (one way to enforce correct behavior, there are perhaps more suitable ones). Papers `torch` cites for reference idea/implementation also consider exclusively this range (e.g., https://arxiv.org/pdf/2310.04415).

Fixes https://github.com/pytorch/pytorch/issues/133772

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133773
Approved by: https://github.com/janeyx99
This commit is contained in:
bskrlj 2024-10-31 18:24:05 +00:00 committed by PyTorch MergeBot
parent 547d921462
commit 8e27833e30
2 changed files with 11 additions and 1 deletions

View file

@ -520,7 +520,7 @@ EMA models are constructed by specifying the ``multi_avg_fn`` argument as follow
>>> decay = 0.999
>>> averaged_model = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(decay))
Decay is a parameter between 0 and 1 that controls how fast the averaged parameters are decayed. If not provided to :func:`torch.optim.swa_utils.get_ema_multi_avg_fn`, the default is 0.999.
Decay is a parameter between 0 and 1 that controls how fast the averaged parameters are decayed. If not provided to :func:`torch.optim.swa_utils.get_ema_multi_avg_fn`, the default is 0.999. Decay value should be close to 1.0, as smaller values can cause optimization convergence issues.
:func:`torch.optim.swa_utils.get_ema_multi_avg_fn` returns a function that applies the following EMA equation to the weights:

View file

@ -34,6 +34,11 @@ PARAM_LIST = Union[Tuple[Tensor, ...], List[Tensor]]
def get_ema_multi_avg_fn(decay=0.999):
"""Get the function applying exponential moving average (EMA) across multiple params."""
if decay < 0.0 or decay > 1.0:
raise ValueError(
f"Invalid decay value {decay} provided. Please provide a value in [0,1] range."
)
@torch.no_grad()
def ema_update(ema_param_list: PARAM_LIST, current_param_list: PARAM_LIST, _):
# foreach lerp only handles float and complex
@ -83,6 +88,11 @@ def get_swa_multi_avg_fn():
def get_ema_avg_fn(decay=0.999):
"""Get the function applying exponential moving average (EMA) across a single param."""
if decay < 0.0 or decay > 1.0:
raise ValueError(
f"Invalid decay value {decay} provided. Please provide a value in [0,1] range."
)
@torch.no_grad()
def ema_update(ema_param: Tensor, current_param: Tensor, num_averaged):
return decay * ema_param + (1 - decay) * current_param