Mirror of https://github.com/saymrwulf/stable-baselines3.git, synced 2026-05-23 22:20:18 +00:00
* Added a `device` keyword argument to `BaseAlgorithm.load()`. Edited the save/load test to also exercise the `load` method with all possible devices. Added the changes to the changelog.
* Improved the load test to ensure that the model loads onto the correct device.
* Improved the test's correctness: if the `get_device` policy changes, the test will not break.
* Update tests/test_save_load.py (@araffin's suggestion during the PR process)
  Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
* Update tests/test_save_load.py
  Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
* Bug fixes: when comparing devices, compare only the device type, since `get_device()` doesn't provide a device index. The code now loads all of the model parameters from the saved state dict directly onto the required device (fixed `load_from_zip_file`).
* PR fixes: an unrelated test failed when running on GPU; its assertion now compares only device types. Also corrected a related bug in the `get_device()` method.
* Update changelog.rst
  Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
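The "compare only the device type" fix can be illustrated with a minimal pure-Python sketch. It assumes devices are written as PyTorch-style strings such as `"cuda:0"` or `"cpu"`. The helper names `device_type` and `same_device_type` are hypothetical and are not part of Stable-Baselines3. In PyTorch itself, the equivalent check compares `torch.device(...).type`.

```python
def device_type(device: str) -> str:
    """Return the type part of a PyTorch-style device string.

    Hypothetical helper: "cuda:0" -> "cuda", "cpu" -> "cpu".
    The optional ":<index>" suffix is dropped.
    """
    return device.split(":", 1)[0].strip()


def same_device_type(a: str, b: str) -> bool:
    """Compare two devices by type only, ignoring any device index.

    This mirrors the idea that a device reported without an index
    (e.g. "cuda") should still match one that carries an index
    (e.g. "cuda:0").
    """
    return device_type(a) == device_type(b)


if __name__ == "__main__":
    print(same_device_type("cuda:0", "cuda"))  # True
    print(same_device_type("cuda:0", "cpu"))   # False
```

Comparing full strings (or full `torch.device` objects) would treat `"cuda"` and `"cuda:0"` as different. That is the mismatch the commit avoids by comparing types only.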
| Name |
|---|
| sb2_compat |
| vec_env |
| __init__.py |
| atari_wrappers.py |
| base_class.py |
| bit_flipping_env.py |
| buffers.py |
| callbacks.py |
| cmd_util.py |
| distributions.py |
| env_checker.py |
| evaluation.py |
| identity_env.py |
| logger.py |
| monitor.py |
| noise.py |
| off_policy_algorithm.py |
| on_policy_algorithm.py |
| policies.py |
| preprocessing.py |
| results_plotter.py |
| running_mean_std.py |
| save_util.py |
| torch_layers.py |
| type_aliases.py |
| utils.py |