mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-18 21:30:19 +00:00
Allow to set a device when loading a model (#154)
* Added a 'device' keyword argument to BaseAlgorithm.load(). Edited the save and load test to also test the load method with all possible devices. Added the changes to the changelog * improved the load test to ensure that the model loads to the correct device. * improved the test: now the correctness is improved. If the get_device policy would change, it wouldn't break the test. * Update tests/test_save_load.py @araffin's suggestion during the PR process Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org> * Update tests/test_save_load.py Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org> * Bug fixes: when comparing devices, comparing only device type since get_device() doesn't provide device index. Now the code loads all of the model parameters from the saved state dict straight into the required device. (fixed load_from_zip_file). * PR fixes: bug fix - a non-related test failed when running on GPU. updated the assertion to consider only types of devices. Also corrected a related bug in 'get_device()' method. * Update changelog.rst Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
parent
583d4b8e41
commit
f5104a5efc
6 changed files with 38 additions and 18 deletions
|
|
@ -14,12 +14,14 @@ New Features:
|
|||
^^^^^^^^^^^^^
|
||||
- Added ``unwrap_vec_wrapper()`` to ``common.vec_env`` to extract ``VecEnvWrapper`` if needed
|
||||
- Added ``StopTrainingOnMaxEpisodes`` to callback collection (@xicocaio)
|
||||
- Added ``device`` keyword argument to ``BaseAlgorithm.load()`` (@liorcohen5)
|
||||
- Callbacks have access to rollout collection locals as in SB2. (@PartiallyTyped)
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- Fixed a bug where the environment was reset twice when using ``evaluate_policy``
|
||||
- Fix logging of ``clip_fraction`` in PPO (@diditforlulz273)
|
||||
- Fixed a bug where cuda support was wrongly checked when passing the GPU index, e.g., ``device="cuda:0"`` (@liorcohen5)
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
@ -399,4 +401,4 @@ And all the contributors:
|
|||
@MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching
|
||||
@flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3
|
||||
@tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag @xicocaio
|
||||
@diditforlulz273
|
||||
@diditforlulz273 @liorcohen5
|
||||
|
|
|
|||
|
|
@ -316,16 +316,19 @@ class BaseAlgorithm(ABC):
|
|||
return self.policy.predict(observation, state, mask, deterministic)
|
||||
|
||||
@classmethod
|
||||
def load(cls, load_path: str, env: Optional[GymEnv] = None, **kwargs) -> "BaseAlgorithm":
|
||||
def load(
|
||||
cls, load_path: str, env: Optional[GymEnv] = None, device: Union[th.device, str] = "auto", **kwargs
|
||||
) -> "BaseAlgorithm":
|
||||
"""
|
||||
Load the model from a zip-file
|
||||
|
||||
:param load_path: the location of the saved data
|
||||
:param env: the new environment to run the loaded model on
|
||||
(can be None if you only need prediction from a trained model) has priority over any saved environment
|
||||
:param device: (Union[th.device, str]) Device on which the code should run.
|
||||
:param kwargs: extra arguments to change the model when loading
|
||||
"""
|
||||
data, params, tensors = load_from_zip_file(load_path)
|
||||
data, params, tensors = load_from_zip_file(load_path, device=device)
|
||||
|
||||
if "policy_kwargs" in data:
|
||||
for arg_to_remove in ["device"]:
|
||||
|
|
@ -352,7 +355,7 @@ class BaseAlgorithm(ABC):
|
|||
model = cls(
|
||||
policy=data["policy_class"],
|
||||
env=env,
|
||||
device="auto",
|
||||
device=device,
|
||||
_init_setup_model=False, # pytype: disable=not-instantiable,wrong-keyword-args
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -352,6 +352,7 @@ def load_from_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], verbose=0)
|
|||
def load_from_zip_file(
|
||||
load_path: Union[str, pathlib.Path, io.BufferedIOBase],
|
||||
load_data: bool = True,
|
||||
device: Union[th.device, str] = "auto",
|
||||
verbose=0,
|
||||
) -> (Tuple[Optional[Dict[str, Any]], Optional[TensorDict], Optional[TensorDict]]):
|
||||
"""
|
||||
|
|
@ -360,13 +361,14 @@ def load_from_zip_file(
|
|||
:param load_path: (str, pathlib.Path, io.BufferedIOBase) Where to load the model from
|
||||
:param load_data: Whether we should load and return data
|
||||
(class parameters). Mainly used by 'load_parameters' to only load model parameters (weights)
|
||||
:param device: (Union[th.device, str]) Device on which the code should run.
|
||||
:return: (dict),(dict),(dict) Class parameters, model state_dicts (dict of state_dict)
|
||||
and dict of extra tensors
|
||||
"""
|
||||
load_path = open_path(load_path, "r", verbose=verbose, suffix="zip")
|
||||
|
||||
# set device to cpu if cuda is not available
|
||||
device = get_device()
|
||||
device = get_device(device=device)
|
||||
|
||||
# Open the zip archive and load data
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -145,7 +145,7 @@ def get_device(device: Union[th.device, str] = "auto") -> th.device:
|
|||
device = th.device(device)
|
||||
|
||||
# Cuda not available
|
||||
if device == th.device("cuda") and not th.cuda.is_available():
|
||||
if device.type == th.device("cuda").type and not th.cuda.is_available():
|
||||
return th.device("cpu")
|
||||
|
||||
return device
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ def test_predict(model_class, env_id, device):
|
|||
# Test detection of different shapes by the predict method
|
||||
model = model_class("MlpPolicy", env_id, device=device)
|
||||
# Check that the policy is on the right device
|
||||
assert get_device(device) == model.policy.device
|
||||
assert get_device(device).type == model.policy.device.type
|
||||
|
||||
env = gym.make(env_id)
|
||||
vec_env = DummyVecEnv([lambda: gym.make(env_id), lambda: gym.make(env_id)])
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
|
|||
from stable_baselines3.common.base_class import BaseAlgorithm
|
||||
from stable_baselines3.common.identity_env import FakeImageEnv, IdentityEnv, IdentityEnvBox
|
||||
from stable_baselines3.common.save_util import load_from_pkl, open_path, save_to_pkl
|
||||
from stable_baselines3.common.utils import get_device
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv
|
||||
|
||||
MODEL_LIST = [PPO, A2C, TD3, SAC, DQN, DDPG]
|
||||
|
|
@ -70,21 +71,33 @@ def test_save_load(tmp_path, model_class):
|
|||
# Check
|
||||
model.save(tmp_path / "test_save.zip")
|
||||
del model
|
||||
model = model_class.load(str(tmp_path / "test_save.zip"), env=env)
|
||||
|
||||
# check if params are still the same after load
|
||||
new_params = model.policy.state_dict()
|
||||
# Check if the model loads as expected for every possible choice of device:
|
||||
for device in ["auto", "cpu", "cuda"]:
|
||||
model = model_class.load(str(tmp_path / "test_save.zip"), env=env, device=device)
|
||||
|
||||
# Check that all params are the same as before save load procedure now
|
||||
for key in params:
|
||||
assert th.allclose(params[key], new_params[key]), "Model parameters not the same after save and load."
|
||||
# check if the model was loaded to the correct device
|
||||
assert model.device.type == get_device(device).type
|
||||
assert model.policy.device.type == get_device(device).type
|
||||
|
||||
# check if model still selects the same actions
|
||||
new_selected_actions, _ = model.predict(observations, deterministic=True)
|
||||
assert np.allclose(selected_actions, new_selected_actions, 1e-4)
|
||||
# check if params are still the same after load
|
||||
new_params = model.policy.state_dict()
|
||||
|
||||
# check if learn still works
|
||||
model.learn(total_timesteps=1000, eval_freq=500)
|
||||
# Check that all params are the same as before save load procedure now
|
||||
for key in params:
|
||||
assert new_params[key].device.type == get_device(device).type
|
||||
assert th.allclose(
|
||||
params[key].to("cpu"), new_params[key].to("cpu")
|
||||
), "Model parameters not the same after save and load."
|
||||
|
||||
# check if model still selects the same actions
|
||||
new_selected_actions, _ = model.predict(observations, deterministic=True)
|
||||
assert np.allclose(selected_actions, new_selected_actions, 1e-4)
|
||||
|
||||
# check if learn still works
|
||||
model.learn(total_timesteps=1000, eval_freq=500)
|
||||
|
||||
del model
|
||||
|
||||
# clear file from os
|
||||
os.remove(tmp_path / "test_save.zip")
|
||||
|
|
|
|||
Loading…
Reference in a new issue