From 4392759057958520e504bca407dc249dd5db8f33 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 14 Feb 2020 14:15:55 +0100 Subject: [PATCH] Comment unused code --- tests/test_predict.py | 29 ++++++++------------------- torchy_baselines/common/base_class.py | 18 ++++++++--------- 2 files changed, 17 insertions(+), 30 deletions(-) diff --git a/tests/test_predict.py b/tests/test_predict.py index e75deec..6f2245c 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -22,11 +22,15 @@ def test_auto_wrap(model_class): @pytest.mark.parametrize("model_class", MODEL_LIST) -def test_predict(model_class): +@pytest.mark.parametrize("env_id", ['Pendulum-v0', 'CartPole-v1']) +def test_predict(model_class, env_id): + if env_id == 'CartPole-v1' and model_class not in [PPO, A2C]: + return + # test detection of different shapes by the predict method - model = model_class('MlpPolicy', 'Pendulum-v0') - env = gym.make('Pendulum-v0') - vec_env = DummyVecEnv([lambda: gym.make('Pendulum-v0'), lambda: gym.make('Pendulum-v0')]) + model = model_class('MlpPolicy', env_id) + env = gym.make(env_id) + vec_env = DummyVecEnv([lambda: gym.make(env_id), lambda: gym.make(env_id)]) obs = env.reset() action = model.predict(obs) @@ -36,20 +40,3 @@ def test_predict(model_class): vec_env_obs = vec_env.reset() action = model.predict(vec_env_obs) assert action.shape[0] == vec_env_obs.shape[0] - - -@pytest.mark.parametrize("model_class", [A2C, PPO]) -def test_predict_discrete(model_class): - # test detection of different shapes by the predict method - model = model_class('MlpPolicy', 'CartPole-v1') - env = gym.make('CartPole-v1') - vec_env = DummyVecEnv([lambda: gym.make('CartPole-v1'), lambda: gym.make('CartPole-v1')]) - - obs = env.reset() - action = model.predict(obs) - assert action.shape == () - assert env.action_space.contains(action) - - vec_env_obs = vec_env.reset() - action = model.predict(vec_env_obs) - assert action.shape[0] == vec_env_obs.shape[0] diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index 15da3bb..6b3fb6a 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -327,15 +327,15 @@ class BaseRLModel(ABC): "Box environment, please use {} ".format(observation_space.shape) + "or (n_env, {}) for the observation shape." .format(", ".join(map(str, observation_space.shape)))) - elif isinstance(observation_space, gym.spaces.Discrete): - if observation.shape == (): # A numpy array of a number, has shape empty tuple '()' - return False - elif len(observation.shape) == 1: - return True - else: - raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape) + - "Discrete environment, please use (1,) or (n_env, 1) for the observation shape.") - # TODO: add support for MultiDiscrete and MultiBinary action spaces + # TODO: add support for Discrete, MultiDiscrete and MultiBinary observation spaces + # elif isinstance(observation_space, gym.spaces.Discrete): + # if observation.shape == (): # A numpy array of a number, has shape empty tuple '()' + # return False + # elif len(observation.shape) == 1: + # return True + # else: + # raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape) + + # "Discrete environment, please use (1,) or (n_env, 1) for the observation shape.") # elif isinstance(observation_space, gym.spaces.MultiDiscrete): # if observation.shape == (len(observation_space.nvec),): # return False