From f5104a5efca7d1aefbbbb86a41a19c156f6abb2d Mon Sep 17 00:00:00 2001 From: liorcohen5 Date: Sun, 20 Sep 2020 20:13:18 +0300 Subject: [PATCH] Allow to set a device when loading a model (#154) * Added a 'device' keyword argument to BaseAlgorithm.load(). Edited the save and load test to also test the load method with all possible devices. Added the changes to the changelog. * Improved the load test to ensure that the model loads to the correct device. * Improved the test's correctness: it no longer breaks if the get_device policy changes. * Update tests/test_save_load.py @araffin's suggestion during the PR process Co-authored-by: Antonin RAFFIN * Update tests/test_save_load.py Co-authored-by: Antonin RAFFIN * Bug fixes: when comparing devices, compare only the device type, since get_device() doesn't provide the device index. The code now loads all of the model parameters from the saved state dict straight onto the required device (fixed load_from_zip_file). * PR fixes: an unrelated test failed when running on GPU; updated the assertion to compare only device types. Also corrected a related bug in the 'get_device()' function. 
* Update changelog.rst Co-authored-by: Antonin RAFFIN --- docs/misc/changelog.rst | 4 ++- stable_baselines3/common/base_class.py | 9 ++++--- stable_baselines3/common/save_util.py | 4 ++- stable_baselines3/common/utils.py | 2 +- tests/test_predict.py | 2 +- tests/test_save_load.py | 35 ++++++++++++++++++-------- 6 files changed, 38 insertions(+), 18 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index fbc4900..cf95624 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -14,12 +14,14 @@ New Features: ^^^^^^^^^^^^^ - Added ``unwrap_vec_wrapper()`` to ``common.vec_env`` to extract ``VecEnvWrapper`` if needed - Added ``StopTrainingOnMaxEpisodes`` to callback collection (@xicocaio) +- Added ``device`` keyword argument to ``BaseAlgorithm.load()`` (@liorcohen5) - Callbacks have access to rollout collection locals as in SB2. (@PartiallyTyped) Bug Fixes: ^^^^^^^^^^ - Fixed a bug where the environment was reset twice when using ``evaluate_policy`` - Fix logging of ``clip_fraction`` in PPO (@diditforlulz273) +- Fixed a bug where cuda support was wrongly checked when passing the GPU index, e.g., ``device="cuda:0"`` (@liorcohen5) Deprecations: ^^^^^^^^^^^^^ @@ -399,4 +401,4 @@ And all the contributors: @MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching @flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3 @tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag @xicocaio -@diditforlulz273 +@diditforlulz273 @liorcohen5 diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index a3de8ce..5c87bca 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -316,16 +316,19 @@ class BaseAlgorithm(ABC): return self.policy.predict(observation, state, mask, deterministic) @classmethod - def load(cls, load_path: str, env: 
Optional[GymEnv] = None, **kwargs) -> "BaseAlgorithm": + def load( + cls, load_path: str, env: Optional[GymEnv] = None, device: Union[th.device, str] = "auto", **kwargs + ) -> "BaseAlgorithm": """ Load the model from a zip-file :param load_path: the location of the saved data :param env: the new environment to run the loaded model on (can be None if you only need prediction from a trained model) has priority over any saved environment + :param device: (Union[th.device, str]) Device on which the code should run. :param kwargs: extra arguments to change the model when loading """ - data, params, tensors = load_from_zip_file(load_path) + data, params, tensors = load_from_zip_file(load_path, device=device) if "policy_kwargs" in data: for arg_to_remove in ["device"]: @@ -352,7 +355,7 @@ class BaseAlgorithm(ABC): model = cls( policy=data["policy_class"], env=env, - device="auto", + device=device, _init_setup_model=False, # pytype: disable=not-instantiable,wrong-keyword-args ) diff --git a/stable_baselines3/common/save_util.py b/stable_baselines3/common/save_util.py index 326db1e..0decdc9 100644 --- a/stable_baselines3/common/save_util.py +++ b/stable_baselines3/common/save_util.py @@ -352,6 +352,7 @@ def load_from_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], verbose=0) def load_from_zip_file( load_path: Union[str, pathlib.Path, io.BufferedIOBase], load_data: bool = True, + device: Union[th.device, str] = "auto", verbose=0, ) -> (Tuple[Optional[Dict[str, Any]], Optional[TensorDict], Optional[TensorDict]]): """ @@ -360,13 +361,14 @@ def load_from_zip_file( :param load_path: (str, pathlib.Path, io.BufferedIOBase) Where to load the model from :param load_data: Whether we should load and return data (class parameters). Mainly used by 'load_parameters' to only load model parameters (weights) + :param device: (Union[th.device, str]) Device on which the code should run. 
:return: (dict),(dict),(dict) Class parameters, model state_dicts (dict of state_dict) and dict of extra tensors """ load_path = open_path(load_path, "r", verbose=verbose, suffix="zip") # set device to cpu if cuda is not available - device = get_device() + device = get_device(device=device) # Open the zip archive and load data try: diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py index 1c1b850..aeab67b 100644 --- a/stable_baselines3/common/utils.py +++ b/stable_baselines3/common/utils.py @@ -145,7 +145,7 @@ def get_device(device: Union[th.device, str] = "auto") -> th.device: device = th.device(device) # Cuda not available - if device == th.device("cuda") and not th.cuda.is_available(): + if device.type == th.device("cuda").type and not th.cuda.is_available(): return th.device("cpu") return device diff --git a/tests/test_predict.py b/tests/test_predict.py index 288ab0d..d7bb0e4 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -46,7 +46,7 @@ def test_predict(model_class, env_id, device): # Test detection of different shapes by the predict method model = model_class("MlpPolicy", env_id, device=device) # Check that the policy is on the right device - assert get_device(device) == model.policy.device + assert get_device(device).type == model.policy.device.type env = gym.make(env_id) vec_env = DummyVecEnv([lambda: gym.make(env_id), lambda: gym.make(env_id)]) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index d56c40d..8ebeabd 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -13,6 +13,7 @@ from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3 from stable_baselines3.common.base_class import BaseAlgorithm from stable_baselines3.common.identity_env import FakeImageEnv, IdentityEnv, IdentityEnvBox from stable_baselines3.common.save_util import load_from_pkl, open_path, save_to_pkl +from stable_baselines3.common.utils import get_device from stable_baselines3.common.vec_env import 
DummyVecEnv MODEL_LIST = [PPO, A2C, TD3, SAC, DQN, DDPG] @@ -70,21 +71,33 @@ def test_save_load(tmp_path, model_class): # Check model.save(tmp_path / "test_save.zip") del model - model = model_class.load(str(tmp_path / "test_save.zip"), env=env) - # check if params are still the same after load - new_params = model.policy.state_dict() + # Check if the model loads as expected for every possible choice of device: + for device in ["auto", "cpu", "cuda"]: + model = model_class.load(str(tmp_path / "test_save.zip"), env=env, device=device) - # Check that all params are the same as before save load procedure now - for key in params: - assert th.allclose(params[key], new_params[key]), "Model parameters not the same after save and load." + # check if the model was loaded to the correct device + assert model.device.type == get_device(device).type + assert model.policy.device.type == get_device(device).type - # check if model still selects the same actions - new_selected_actions, _ = model.predict(observations, deterministic=True) - assert np.allclose(selected_actions, new_selected_actions, 1e-4) + # check if params are still the same after load + new_params = model.policy.state_dict() - # check if learn still works - model.learn(total_timesteps=1000, eval_freq=500) + # Check that all params are the same as before save load procedure now + for key in params: + assert new_params[key].device.type == get_device(device).type + assert th.allclose( + params[key].to("cpu"), new_params[key].to("cpu") + ), "Model parameters not the same after save and load." + + # check if model still selects the same actions + new_selected_actions, _ = model.predict(observations, deterministic=True) + assert np.allclose(selected_actions, new_selected_actions, 1e-4) + + # check if learn still works + model.learn(total_timesteps=1000, eval_freq=500) + + del model # clear file from os os.remove(tmp_path / "test_save.zip")