Mirror of https://github.com/saymrwulf/pytorch.git, synced 2026-05-14 20:57:59 +00:00
According to the [documentation](https://pytorch.org/docs/stable/optim.html), `decay` is a number in the range [0, 1]:

```
Decay is a parameter between 0 and 1 that controls how fast the averaged parameters are decayed. If not provided to get_ema_multi_avg_fn, the default is 0.999.
```

An inspection of `swa_utils.py` shows that there are no checks for invalid values of `decay`. Adding asserts, as this PR suggests, keeps `decay` within the valid range (this is one way to enforce correct behavior; there may be more suitable ones). The papers that `torch` cites for the idea and its implementation also consider only this range (e.g., https://arxiv.org/pdf/2310.04415).

Fixes https://github.com/pytorch/pytorch/issues/133772

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133773
Approved by: https://github.com/janeyx99
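The range check described in the commit message can be illustrated with a minimal, dependency-free sketch. It is not the code in `swa_utils.py`: the helper name `make_ema_update`, the use of plain Python lists instead of tensors, and the choice of `ValueError` over a bare `assert` are all assumptions made for illustration. It also assumes the usual EMA update, `decay * averaged + (1 - decay) * current`.

```python
from typing import Callable, List


def make_ema_update(
    decay: float = 0.999,
) -> Callable[[List[float], List[float]], List[float]]:
    """Hypothetical helper: build an EMA update that rejects decay outside [0, 1]."""
    # Validate once, when the update function is built, not on every step.
    # The chained comparison is False for NaN, so NaN is rejected as well.
    if not 0.0 <= decay <= 1.0:
        raise ValueError(f"Invalid decay value {decay}; expected a value in [0, 1].")

    def ema_update(averaged: List[float], current: List[float]) -> List[float]:
        # Keep a `decay` share of the old average and blend in the new values.
        return [decay * a + (1.0 - decay) * c for a, c in zip(averaged, current)]

    return ema_update


update = make_ema_update(decay=0.5)
print(update([1.0, 2.0], [3.0, 4.0]))  # [2.0, 3.0]
```

Raising an exception rather than using `assert` is one possible design choice here: assertions are skipped when Python runs with `-O`, so an explicit check keeps the validation active in optimized runs.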
| | | |
|---|---|---|
| _multi_tensor | ||
| __init__.py | ||
| _adafactor.py | ||
| _functional.py | ||
| adadelta.py | ||
| adagrad.py | ||
| adam.py | ||
| adamax.py | ||
| adamw.py | ||
| asgd.py | ||
| lbfgs.py | ||
| lr_scheduler.py | ||
| nadam.py | ||
| optimizer.py | ||
| radam.py | ||
| rmsprop.py | ||
| rprop.py | ||
| sgd.py | ||
| sparse_adam.py | ||
| swa_utils.py | ||