stable-baselines3/tests/test_predict.py
liorcohen5 f5104a5efc
Allow to set a device when loading a model (#154)
* Added a 'device' keyword argument to BaseAlgorithm.load().
Edited the save and load test to also test the load method with all possible devices.
Added the changes to the changelog

* improved the load test to ensure that the model loads to the correct device.

* Improved the correctness of the test: if the get_device policy changes, it will no longer break the test.

* Update tests/test_save_load.py

@araffin's suggestion during the PR process

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>

* Update tests/test_save_load.py

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>

* Bug fixes: when comparing devices, comparing only device type since get_device() doesn't provide device index.
Now the code loads all of the model parameters from the saved state dict straight into the required device. (fixed load_from_zip_file).

* PR fixes: fixed an unrelated test that failed when running on GPU by updating its assertion to compare only device types. Also corrected a related bug in the 'get_device()' method.

* Update changelog.rst

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2020-09-20 19:13:18 +02:00

61 lines
1.7 KiB
Python

import gym
import pytest
import torch as th
from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
from stable_baselines3.common.utils import get_device
from stable_baselines3.common.vec_env import DummyVecEnv
# Algorithm classes used to parametrize the tests in this module.
MODEL_LIST = [PPO, A2C, TD3, SAC, DQN]
@pytest.mark.parametrize("model_class", MODEL_LIST)
def test_auto_wrap(model_class):
    """Check that a raw ``gym`` env is auto-wrapped into a VecEnv.

    The training env and the evaluation env are both plain ``gym.make``
    results handed straight to the algorithm; the test succeeds when the
    model can be built and ``learn`` runs for 100 timesteps without raising.

    :param model_class: algorithm class under test (one of ``MODEL_LIST``)
    """
    # DQN runs on CartPole, every other algorithm on Pendulum.
    env_id = "CartPole-v0" if model_class is DQN else "Pendulum-v0"
    train_env, evaluation_env = gym.make(env_id), gym.make(env_id)
    agent = model_class("MlpPolicy", train_env)
    agent.learn(100, eval_env=evaluation_env)
@pytest.mark.parametrize("model_class", MODEL_LIST)
@pytest.mark.parametrize("env_id", ["Pendulum-v0", "CartPole-v1"])
@pytest.mark.parametrize("device", ["cpu", "cuda", "auto"])
def test_predict(model_class, env_id, device):
    """Check ``predict`` on single and vectorized observations, per device.

    The test builds the model on the requested device, verifies the policy
    landed on that device, then checks the shape of the predicted action for
    an observation from a single env and for a batch of observations from a
    two-env ``DummyVecEnv``.

    Combinations that are not exercised (SAC/TD3 on CartPole-v1, DQN on
    Pendulum-v0, CUDA without a GPU) are reported as *skipped* rather than
    silently passing.

    :param model_class: algorithm class under test (one of ``MODEL_LIST``)
    :param env_id: gym environment id used to build the model and the envs
    :param device: device string passed to the model constructor
    """
    if device == "cuda" and not th.cuda.is_available():
        pytest.skip("CUDA not available")

    # Skip explicitly so these cases do not show up as (vacuous) passes.
    if env_id == "CartPole-v1":
        if model_class in [SAC, TD3]:
            pytest.skip(f"{model_class.__name__} is not tested on {env_id}")
    elif model_class is DQN:
        pytest.skip(f"{model_class.__name__} is not tested on {env_id}")

    # Test detection of different shapes by the predict method
    model = model_class("MlpPolicy", env_id, device=device)
    # Check that the policy is on the right device.
    # Only the device *type* is compared, not the device index.
    assert get_device(device).type == model.policy.device.type

    env = gym.make(env_id)
    vec_env = DummyVecEnv([lambda: gym.make(env_id), lambda: gym.make(env_id)])
    try:
        # Single (non-vectorized) observation: action must match the
        # action space exactly.
        obs = env.reset()
        action, _ = model.predict(obs)
        assert action.shape == env.action_space.shape
        assert env.action_space.contains(action)

        # Batched observation: one action per env in the VecEnv.
        vec_env_obs = vec_env.reset()
        action, _ = model.predict(vec_env_obs)
        assert action.shape[0] == vec_env_obs.shape[0]
    finally:
        # Release the environments even when an assertion fails.
        env.close()
        vec_env.close()