stable-baselines3/stable_baselines3/common
liorcohen5 f5104a5efc
Allow to set a device when loading a model (#154)
* Added a 'device' keyword argument to BaseAlgorithm.load().
Edited the save and load test to also test the load method with all possible devices.
Added the changes to the changelog

* Improved the load test to ensure that the model loads to the correct device.

* Made the test more robust: it no longer breaks if the device-selection policy of get_device() changes.

* Update tests/test_save_load.py

Suggested by @araffin during PR review.

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>

* Update tests/test_save_load.py

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>

* Bug fixes: devices are now compared by type only, since get_device() does not provide a device index.
Model parameters are now loaded from the saved state dict directly onto the requested device (fixed load_from_zip_file).

* PR fixes: an unrelated test failed when running on GPU; updated its assertion to compare device types only. Also fixed a related bug in get_device().

* Update changelog.rst

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2020-09-20 19:13:18 +02:00
sb2_compat Match performance with stable-baselines (discrete case) (#110) 2020-08-03 22:22:51 +02:00
vec_env Fix double reset and improve typing coverage (#136) 2020-08-05 13:12:02 +03:00
__init__.py Auto-formatting with black and isort (#97) 2020-07-16 16:12:16 +02:00
atari_wrappers.py Auto-formatting with black and isort (#97) 2020-07-16 16:12:16 +02:00
base_class.py Allow to set a device when loading a model (#154) 2020-09-20 19:13:18 +02:00
bit_flipping_env.py Auto-formatting with black and isort (#97) 2020-07-16 16:12:16 +02:00
buffers.py Update black version + update docker image (#151) 2020-08-27 23:02:59 +02:00
callbacks.py Fix f-string in max episodes callback (#152) 2020-08-29 20:04:19 +02:00
cmd_util.py Auto-formatting with black and isort (#97) 2020-07-16 16:12:16 +02:00
distributions.py Auto-formatting with black and isort (#97) 2020-07-16 16:12:16 +02:00
env_checker.py Fix double reset and improve typing coverage (#136) 2020-08-05 13:12:02 +03:00
evaluation.py Fix double reset and improve typing coverage (#136) 2020-08-05 13:12:02 +03:00
identity_env.py Auto-formatting with black and isort (#97) 2020-07-16 16:12:16 +02:00
logger.py Auto-formatting with black and isort (#97) 2020-07-16 16:12:16 +02:00
monitor.py Auto-formatting with black and isort (#97) 2020-07-16 16:12:16 +02:00
noise.py Auto-formatting with black and isort (#97) 2020-07-16 16:12:16 +02:00
off_policy_algorithm.py Add StopTrainingOnMaxEpisodes to callback collection (#147) 2020-08-28 11:36:33 +02:00
on_policy_algorithm.py Add StopTrainingOnMaxEpisodes to callback collection (#147) 2020-08-28 11:36:33 +02:00
policies.py Update black version + update docker image (#151) 2020-08-27 23:02:59 +02:00
preprocessing.py Auto-formatting with black and isort (#97) 2020-07-16 16:12:16 +02:00
results_plotter.py Auto-formatting with black and isort (#97) 2020-07-16 16:12:16 +02:00
running_mean_std.py Auto-formatting with black and isort (#97) 2020-07-16 16:12:16 +02:00
save_util.py Allow to set a device when loading a model (#154) 2020-09-20 19:13:18 +02:00
torch_layers.py Match performance with stable-baselines (discrete case) (#110) 2020-08-03 22:22:51 +02:00
type_aliases.py Auto-formatting with black and isort (#97) 2020-07-16 16:12:16 +02:00
utils.py Allow to set a device when loading a model (#154) 2020-09-20 19:13:18 +02:00