mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Summary: This PR intends to address apaszke's concerns in https://github.com/pytorch/pytorch/pull/14253#issuecomment-441740016. Preserving the RNG state is now controlled by a kwarg rather than by global state, hopefully in a Python 2.7-compatible way.

Additionally, the checkpointing function stashes and restores the RNG states of (1) the devices associated with all input tensor args to run_fn and (2) the current device. I could easily change this to save and restore only the RNG states associated with (1). This would simplify the logic that creates a [deduplicated, ordered](https://github.com/pytorch/pytorch/compare/master...mcarilli:checkpointing_rng_touchup?expand=1#diff-58da227fc9b1d56752b7dfad90428fe0R37) list of devices considered active.

I'm wondering if the [get_device_states](https://github.com/pytorch/pytorch/compare/master...mcarilli:checkpointing_rng_touchup?expand=1#diff-58da227fc9b1d56752b7dfad90428fe0R32) and [set_device_states](https://github.com/pytorch/pytorch/compare/master...mcarilli:checkpointing_rng_touchup?expand=1#diff-58da227fc9b1d56752b7dfad90428fe0R47) functions are general enough to reside elsewhere (presumably torch/random.py). I'm also wondering if the check on [torch.cuda._initialized](https://github.com/pytorch/pytorch/compare/master...mcarilli:checkpointing_rng_touchup?expand=1#diff-58da227fc9b1d56752b7dfad90428fe0R47) would be better placed within `get_device_states`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/14518
Differential Revision: D13356210
Pulled By: ezyang
fbshipit-source-id: afa4cc21ce7862142d5cb1dec3750018df222039
| | | |
|---|---|---|
| .. | | |
| caffe2 | | |
| cpp | | |
| source | | |
| libtorch.rst | | |
| make.bat | | |
| Makefile | | |
| requirements.txt | | |
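
The stash-and-restore behaviour described in the commit message can be sketched in a few lines. This is only an illustration under stated assumptions, not the code from the PR: Python's standard `random` module stands in for the PyTorch CPU and CUDA generators, and `run_with_rng_replay`, `noisy_scale`, and `unique_in_order` are hypothetical names introduced here. The `preserve_rng_state` flag is read from `**kwargs` because Python 2.7 has no keyword-only arguments.

```python
import random


def run_with_rng_replay(run_fn, *args, **kwargs):
    """Call run_fn twice, as checkpointing does (forward, then recompute).

    Hypothetical helper: preserve_rng_state is popped from **kwargs so the
    signature also works on Python 2.7.
    """
    preserve_rng_state = kwargs.pop("preserve_rng_state", True)
    if kwargs:
        raise ValueError("Unexpected keyword arguments: " + ", ".join(sorted(kwargs)))

    # Stash the RNG state before the first (forward) call.
    stashed_state = random.getstate() if preserve_rng_state else None
    forward_out = run_fn(*args)

    # Simulate unrelated work that consumes random numbers in between.
    random.random()

    if preserve_rng_state:
        # Restore the stashed state for the recomputation, then put the
        # caller's state back so the replay does not disturb later draws.
        outer_state = random.getstate()
        random.setstate(stashed_state)
        try:
            recomputed_out = run_fn(*args)
        finally:
            random.setstate(outer_state)
    else:
        recomputed_out = run_fn(*args)
    return forward_out, recomputed_out


def noisy_scale(values):
    # Stand-in for a layer with randomness, such as dropout.
    return [v * random.random() for v in values]


def unique_in_order(device_ids):
    # Deduplicated list of device ids that keeps first-seen order
    # (an assumption about what "deduplicated, ordered" means here).
    seen = set()
    ordered = []
    for device_id in device_ids:
        if device_id not in seen:
            seen.add(device_id)
            ordered.append(device_id)
    return ordered
```

With `preserve_rng_state` left at its default, the recomputed output matches the forward output even though random numbers were drawn in between; passing `preserve_rng_state=False` skips the stash and the two outputs generally differ.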