stable-baselines3/stable_baselines3
liorcohen5 f5104a5efc
Allow to set a device when loading a model (#154)
* Added a 'device' keyword argument to BaseAlgorithm.load().
Edited the save and load test to also test the load method with all possible devices.
Added the changes to the changelog

* improved the load test to ensure that the model loads to the correct device.

* improved the test's correctness: a change in the get_device policy would no longer break the test.

* Update tests/test_save_load.py

@araffin's suggestion during the PR process

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>

* Update tests/test_save_load.py

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>

* Bug fixes: when comparing devices, compare only the device type, since get_device() doesn't provide a device index.
Now the code loads all of the model parameters from the saved state dict straight onto the required device (fixed load_from_zip_file).

* PR fixes: bug fix - an unrelated test failed when running on GPU. Updated the assertion to consider only device types. Also corrected a related bug in the 'get_device()' method.

* Update changelog.rst

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2020-09-20 19:13:18 +02:00
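The device-type fix described in the commit message can be illustrated with a small sketch. In PyTorch, `torch.device("cuda")` carries no index and does not compare equal to `torch.device("cuda:0")`, which is where model parameters actually live on a GPU, so a direct equality check fails. The snippet mimics the fix on plain device strings. `device_type` and `same_device_type` are hypothetical helpers for illustration, not stable-baselines3 functions. With real `torch.device` objects one would compare their `.type` attributes instead.

```python
def device_type(device: str) -> str:
    """Return the type part of a device string, e.g. "cuda" for "cuda:0"."""
    # A device string is either "<type>" or "<type>:<index>"; keep only the type.
    return device.split(":", 1)[0]


def same_device_type(a: str, b: str) -> bool:
    """Hypothetical helper: compare two devices by type only, ignoring any index."""
    return device_type(a) == device_type(b)


# A plain equality check fails because one side has no index...
print("cuda" == "cuda:0")                   # False
# ...while comparing only the type succeeds.
print(same_device_type("cuda", "cuda:0"))  # True
print(same_device_type("cpu", "cuda:0"))   # False
```

With the `device` keyword added to `BaseAlgorithm.load()`, a saved model can be loaded straight onto a chosen device, for example `PPO.load(path, device="cpu")`, where `path` stands for the saved model's location.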
a2c Match performance with stable-baselines (discrete case) (#110) 2020-08-03 22:22:51 +02:00
common Allow to set a device when loading a model (#154) 2020-09-20 19:13:18 +02:00
ddpg Auto-formatting with black and isort (#97) 2020-07-16 16:12:16 +02:00
dqn Update black version + update docker image (#151) 2020-08-27 23:02:59 +02:00
ppo Fix PPO logging of clip_fractions (#150) 2020-09-01 09:52:31 +02:00
sac Fix typos in SAC and TD3 (#145) 2020-08-23 17:44:35 +02:00
td3 Fix typos in SAC and TD3 (#145) 2020-08-23 17:44:35 +02:00
__init__.py Auto-formatting with black and isort (#97) 2020-07-16 16:12:16 +02:00
py.typed Rename to stable-baselines3 2020-05-05 15:02:35 +02:00
version.txt Remove "device" argument from policies (#141) 2020-08-23 13:27:52 +02:00