stable-baselines3/tests/test_predict.py

44 lines
1.2 KiB
Python
Raw Normal View History

2020-02-14 13:03:41 +00:00
import gym
import pytest
from torchy_baselines import A2C, CEMRL, PPO, SAC, TD3
from torchy_baselines.common.vec_env import DummyVecEnv
# Algorithm classes that the tests below are parametrized over.
MODEL_LIST = [CEMRL, PPO, A2C, TD3, SAC]
2020-03-12 10:12:10 +00:00
2020-02-14 13:03:41 +00:00
@pytest.mark.parametrize("model_class", MODEL_LIST)
def test_auto_wrap(model_class):
    """Train briefly on raw gym environments to exercise VecEnv auto-wrapping.

    Both the training env and the evaluation env are handed over unwrapped;
    the test passes if construction and a 100-step ``learn`` call succeed.
    """
    env_name = 'Pendulum-v0'
    # Created in the same order as before: training env first, then eval env.
    train_env, evaluation_env = gym.make(env_name), gym.make(env_name)
    agent = model_class('MlpPolicy', train_env)
    agent.learn(100, eval_env=evaluation_env)
@pytest.mark.parametrize("model_class", MODEL_LIST)
@pytest.mark.parametrize("env_id", ['Pendulum-v0', 'CartPole-v1'])
def test_predict(model_class, env_id):
    """Check that ``predict`` adapts its output to the observation's shape.

    A single observation from a plain gym env must produce an action with the
    env's action-space shape that lies inside that space. A batch of
    observations from a two-env ``DummyVecEnv`` must produce one action per
    observation (matching leading dimension).
    """
    # Only PPO and A2C are run on CartPole-v1 (presumably the other algorithms
    # here do not support its action space). Report the remaining combinations
    # as skipped instead of letting them silently count as passed.
    if env_id == 'CartPole-v1' and model_class not in [PPO, A2C]:
        pytest.skip("{} is not tested on {}".format(model_class.__name__, env_id))

    # test detection of different shapes by the predict method
    model = model_class('MlpPolicy', env_id)
    env = gym.make(env_id)
    vec_env = DummyVecEnv([lambda: gym.make(env_id), lambda: gym.make(env_id)])
    try:
        # Single, unbatched observation -> one action in the action space.
        obs = env.reset()
        action = model.predict(obs)
        assert action.shape == env.action_space.shape
        assert env.action_space.contains(action)

        # Batched observations from the vectorized env -> one action each.
        vec_env_obs = vec_env.reset()
        action = model.predict(vec_env_obs)
        assert action.shape[0] == vec_env_obs.shape[0]
    finally:
        # Release the environments even when an assertion above fails.
        env.close()
        vec_env.close()