2023-04-14 11:13:59 +00:00
|
|
|
import gymnasium as gym
|
2021-09-28 19:57:49 +00:00
|
|
|
import numpy as np
|
2020-02-14 13:03:41 +00:00
|
|
|
import pytest
|
2020-08-23 11:27:52 +00:00
|
|
|
import torch as th
|
2023-04-14 11:13:59 +00:00
|
|
|
from gymnasium import spaces
|
2020-02-14 13:03:41 +00:00
|
|
|
|
2020-07-16 14:12:16 +00:00
|
|
|
from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
|
2023-04-14 11:13:59 +00:00
|
|
|
from stable_baselines3.common.env_checker import check_env
|
2021-09-28 19:57:49 +00:00
|
|
|
from stable_baselines3.common.envs import IdentityEnv
|
2020-08-23 11:27:52 +00:00
|
|
|
from stable_baselines3.common.utils import get_device
|
2020-05-05 13:02:35 +00:00
|
|
|
from stable_baselines3.common.vec_env import DummyVecEnv
|
2020-02-14 13:03:41 +00:00
|
|
|
|
|
|
|
|
# Algorithm classes used to parametrize the tests in this module.
MODEL_LIST = [PPO, A2C, TD3, SAC, DQN]
|
|
|
|
|
|
2020-03-12 10:12:10 +00:00
|
|
|
|
2023-01-02 13:51:11 +00:00
|
|
|
class SubClassedBox(spaces.Box):
    """Thin ``spaces.Box`` subclass that adds no behavior of its own.

    ``CustomSubClassedSpaceEnv`` uses it for both its observation space and
    its action space.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Pass every positional and keyword argument through to ``spaces.Box``."""
        super().__init__(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CustomSubClassedSpaceEnv(gym.Env):
    """Minimal env whose observation and action spaces are ``SubClassedBox``.

    Both spaces are 2-dimensional float32 boxes bounded by [-1, 1].
    Observations are sampled from the observation space, the reward is always
    ``0.0`` and the action passed to ``step`` is ignored.
    """

    def __init__(self):
        super().__init__()
        self.observation_space = SubClassedBox(-1, 1, shape=(2,), dtype=np.float32)
        self.action_space = SubClassedBox(-1, 1, shape=(2,), dtype=np.float32)

    def reset(self, seed=None, options=None):
        """Start a new episode.

        :param seed: Optional seed, forwarded to ``gym.Env.reset``.
        :param options: Accepted so that callers following the Gymnasium
            ``reset(seed=..., options=...)`` signature do not raise a
            ``TypeError``; the value itself is not used.
        :return: A sampled observation and an empty info dict.
        """
        # Gymnasium asks custom envs to forward the seed to the base class.
        super().reset(seed=seed)
        return self.observation_space.sample(), {}

    def step(self, action):
        """Advance one step; ``action`` is ignored.

        :param action: Unused.
        :return: ``(observation, reward, terminated, truncated, info)`` with a
            sampled observation, a reward of ``0.0``, ``terminated`` drawn as
            ``np.random.rand() > 0.5``, ``truncated`` always ``False`` and an
            empty info dict.
        """
        observation = self.observation_space.sample()
        terminated = np.random.rand() > 0.5
        return observation, 0.0, terminated, False, {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("env_cls", [CustomSubClassedSpaceEnv])
def test_env(env_cls):
    """Run the env checker (render check skipped) on the env used for testing."""
    env_under_test = env_cls()
    check_env(env_under_test, skip_render_check=True)
|
2021-09-28 19:57:49 +00:00
|
|
|
|
|
|
|
|
|
2020-02-14 13:03:41 +00:00
|
|
|
@pytest.mark.parametrize("model_class", MODEL_LIST)
def test_auto_wrap(model_class):
    """Check that a plain ``gym.make`` env is auto-wrapped into a VecEnv.

    The env is handed directly to the algorithm constructor, then the model
    is trained for 100 timesteps.
    """
    # DQN gets its own environment; every other algorithm uses Pendulum.
    env_id = "CartPole-v1" if model_class is DQN else "Pendulum-v1"
    raw_env = gym.make(env_id)
    model = model_class("MlpPolicy", raw_env)
    model.learn(100)
|
2020-02-14 13:03:41 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("model_class", MODEL_LIST)
@pytest.mark.parametrize("env_id", ["Pendulum-v1", "CartPole-v1"])
@pytest.mark.parametrize("device", ["cpu", "cuda", "auto"])
def test_predict(model_class, env_id, device):
    """Check ``predict`` on single and vectorized observations.

    For every algorithm/env/device combination that is exercised, verify
    that the policy lives on the requested device and that ``predict``
    returns a numpy array of the expected shape for both a single env
    observation and a batch of observations from a two-env ``DummyVecEnv``.

    :param model_class: Algorithm class to instantiate.
    :param env_id: Gym environment id to build the model and envs from.
    :param device: Device string passed to the algorithm constructor.
    """
    if device == "cuda" and not th.cuda.is_available():
        pytest.skip("CUDA not available")

    # Report untested algorithm/env pairs as skipped instead of returning
    # early, which would make them show up as passed.
    if env_id == "CartPole-v1":
        if model_class in [SAC, TD3]:
            pytest.skip(f"{model_class.__name__} is not tested on {env_id}")
    elif model_class in [DQN]:
        pytest.skip(f"{model_class.__name__} is not tested on {env_id}")

    # Test detection of different shapes by the predict method
    model = model_class("MlpPolicy", env_id, device=device)
    # Check that the policy is on the right device
    assert get_device(device).type == model.policy.device.type

    env = gym.make(env_id)
    vec_env = DummyVecEnv([lambda: gym.make(env_id), lambda: gym.make(env_id)])

    # Single (non-vectorized) observation
    obs, _ = env.reset()
    action, _ = model.predict(obs)
    assert isinstance(action, np.ndarray)
    assert action.shape == env.action_space.shape
    assert env.action_space.contains(action)

    # Batch of observations: one action per sub-environment
    vec_env_obs = vec_env.reset()
    action, _ = model.predict(vec_env_obs)
    assert isinstance(action, np.ndarray)
    assert action.shape[0] == vec_env_obs.shape[0]

    # Special case for DQN to check the epsilon greedy exploration
    if model_class is DQN:
        model.exploration_rate = 1.0
        action, _ = model.predict(obs, deterministic=False)
        assert action.shape == env.action_space.shape
        assert env.action_space.contains(action)

        action, _ = model.predict(vec_env_obs, deterministic=False)
        assert action.shape[0] == vec_env_obs.shape[0]
|
2021-09-28 19:57:49 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_dqn_epsilon_greedy():
    """Check non-deterministic DQN prediction with ``exploration_rate`` at 1.0.

    Uses ``IdentityEnv(2)`` and asserts that the predicted action belongs to
    the env's action space.
    """
    identity_env = IdentityEnv(2)
    dqn = DQN("MlpPolicy", identity_env)
    dqn.exploration_rate = 1.0

    first_obs, _ = identity_env.reset()
    # The "is vectorized" check should not crash with discrete obs
    sampled_action, _ = dqn.predict(first_obs, deterministic=False)
    assert identity_env.action_space.contains(sampled_action)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("model_class", [A2C, SAC, PPO, TD3])
def test_subclassed_space_env(model_class):
    """Train and predict on an env whose spaces are subclasses of ``spaces.Box``.

    :param model_class: Algorithm class to instantiate.
    """
    env = CustomSubClassedSpaceEnv()
    model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[32]))
    model.learn(300)

    obs, _ = env.reset()
    # ``predict`` returns a pair; only its first element is the action, so
    # unpack it instead of handing the whole tuple to ``step``.
    action, _ = model.predict(obs)
    env.step(action)
|
2023-09-25 10:39:22 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_mixing_gym_vecenv_api():
    """Check that ``predict`` rejects the raw return value of a Gym ``reset()``.

    Passing the un-unpacked result of ``env.reset()`` to ``predict`` must
    raise a ``ValueError`` whose message matches "mixing Gym API".
    """
    cartpole = gym.make("CartPole-v1")
    ppo = PPO("MlpPolicy", cartpole)

    # ``reset()`` returns an ``(obs, info)`` tuple; it is deliberately left
    # packed here to trigger the error.
    reset_output = cartpole.reset()
    with pytest.raises(ValueError, match="mixing Gym API"):
        ppo.predict(reset_output)
|