mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Ensure SWA boundary conditions w.r.t. definition (#133773)
According to the documentation (https://pytorch.org/docs/stable/optim.html), decay is a number in the [0,1] range:

```
Decay is a parameter between 0 and 1 that controls how fast the averaged parameters are decayed. If not provided to get_ema_multi_avg_fn, the default is 0.999.
```

An inspection of `swa_utils.py` shows there are no checks for invalid values of `decay`. This PR adds range checks to ensure a valid compute range (one way to enforce correct behavior; there may be more suitable ones). The papers `torch` cites for the reference idea/implementation also consider exclusively this range (e.g., https://arxiv.org/pdf/2310.04415).

Fixes https://github.com/pytorch/pytorch/issues/133772

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133773
Approved by: https://github.com/janeyx99
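The check the PR adds, together with the EMA equation it guards, can be sketched in plain Python (a minimal sketch mirroring `get_ema_avg_fn`; `get_ema_avg_fn_sketch` is a hypothetical name, and the real implementation operates on `torch.Tensor`s under `@torch.no_grad()`):

```python
def get_ema_avg_fn_sketch(decay=0.999):
    # Range check added by the PR: reject decay outside [0, 1].
    if decay < 0.0 or decay > 1.0:
        raise ValueError(
            f"Invalid decay value {decay} provided. Please provide a value in [0,1] range."
        )

    def ema_update(ema_param, current_param):
        # EMA equation: new_avg = decay * avg + (1 - decay) * current
        return decay * ema_param + (1 - decay) * current_param

    return ema_update
```

With `decay=0.5`, averaging `ema=1.0` with `current=0.0` yields `0.5`, while an out-of-range value such as `decay=1.5` now fails fast instead of silently producing a diverging average.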
This commit is contained in:
parent 547d921462
commit 8e27833e30
2 changed files with 11 additions and 1 deletion
@@ -520,7 +520,7 @@ EMA models are constructed by specifying the ``multi_avg_fn`` argument as follow

     >>> decay = 0.999
     >>> averaged_model = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(decay))

-Decay is a parameter between 0 and 1 that controls how fast the averaged parameters are decayed. If not provided to :func:`torch.optim.swa_utils.get_ema_multi_avg_fn`, the default is 0.999.
+Decay is a parameter between 0 and 1 that controls how fast the averaged parameters are decayed. If not provided to :func:`torch.optim.swa_utils.get_ema_multi_avg_fn`, the default is 0.999. Decay value should be close to 1.0, as smaller values can cause optimization convergence issues.

 :func:`torch.optim.swa_utils.get_ema_multi_avg_fn` returns a function that applies the following EMA equation to the weights:
@@ -34,6 +34,11 @@ PARAM_LIST = Union[Tuple[Tensor, ...], List[Tensor]]

 def get_ema_multi_avg_fn(decay=0.999):
     """Get the function applying exponential moving average (EMA) across multiple params."""
+    if decay < 0.0 or decay > 1.0:
+        raise ValueError(
+            f"Invalid decay value {decay} provided. Please provide a value in [0,1] range."
+        )
+
     @torch.no_grad()
     def ema_update(ema_param_list: PARAM_LIST, current_param_list: PARAM_LIST, _):
         # foreach lerp only handles float and complex
@@ -83,6 +88,11 @@ def get_swa_multi_avg_fn():

 def get_ema_avg_fn(decay=0.999):
     """Get the function applying exponential moving average (EMA) across a single param."""
+    if decay < 0.0 or decay > 1.0:
+        raise ValueError(
+            f"Invalid decay value {decay} provided. Please provide a value in [0,1] range."
+        )
+
     @torch.no_grad()
     def ema_update(ema_param: Tensor, current_param: Tensor, num_averaged):
         return decay * ema_param + (1 - decay) * current_param
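The boundary values now admitted by the check behave as the SWA definition requires; a pure-Python sketch of the EMA equation from the diff above (not the torch implementation itself, which operates on tensors) makes the two edge cases concrete:

```python
def ema_update(ema_param, current_param, decay):
    # EMA equation from get_ema_avg_fn: decay * avg + (1 - decay) * current
    return decay * ema_param + (1 - decay) * current_param

# decay = 1.0: the averaged parameter never changes.
frozen = ema_update(2.0, 5.0, 1.0)

# decay = 0.0: the averaged parameter tracks the current parameter exactly.
tracking = ema_update(2.0, 5.0, 0.0)
```

Here `frozen == 2.0` and `tracking == 5.0`; any `decay` outside [0,1] would push the average beyond the interval spanned by the two inputs, which is why the new check rejects it.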