Comment unused code

This commit is contained in:
Antonin Raffin 2020-02-14 14:15:55 +01:00
parent e31b139c47
commit 4392759057
2 changed files with 17 additions and 30 deletions

View file

@ -22,11 +22,15 @@ def test_auto_wrap(model_class):
@pytest.mark.parametrize("model_class", MODEL_LIST)
def test_predict(model_class):
@pytest.mark.parametrize("env_id", ['Pendulum-v0', 'CartPole-v1'])
def test_predict(model_class, env_id):
if env_id == 'CartPole-v1' and model_class not in [PPO, A2C]:
return
# test detection of different shapes by the predict method
model = model_class('MlpPolicy', 'Pendulum-v0')
env = gym.make('Pendulum-v0')
vec_env = DummyVecEnv([lambda: gym.make('Pendulum-v0'), lambda: gym.make('Pendulum-v0')])
model = model_class('MlpPolicy', env_id)
env = gym.make(env_id)
vec_env = DummyVecEnv([lambda: gym.make(env_id), lambda: gym.make(env_id)])
obs = env.reset()
action = model.predict(obs)
@ -36,20 +40,3 @@ def test_predict(model_class):
vec_env_obs = vec_env.reset()
action = model.predict(vec_env_obs)
assert action.shape[0] == vec_env_obs.shape[0]
@pytest.mark.parametrize("model_class", [A2C, PPO])
def test_predict_discrete(model_class):
# test detection of different shapes by the predict method
model = model_class('MlpPolicy', 'CartPole-v1')
env = gym.make('CartPole-v1')
vec_env = DummyVecEnv([lambda: gym.make('CartPole-v1'), lambda: gym.make('CartPole-v1')])
obs = env.reset()
action = model.predict(obs)
assert action.shape == ()
assert env.action_space.contains(action)
vec_env_obs = vec_env.reset()
action = model.predict(vec_env_obs)
assert action.shape[0] == vec_env_obs.shape[0]

View file

@ -327,15 +327,15 @@ class BaseRLModel(ABC):
"Box environment, please use {} ".format(observation_space.shape) +
"or (n_env, {}) for the observation shape."
.format(", ".join(map(str, observation_space.shape))))
elif isinstance(observation_space, gym.spaces.Discrete):
if observation.shape == (): # A numpy array of a number, has shape empty tuple '()'
return False
elif len(observation.shape) == 1:
return True
else:
raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape) +
"Discrete environment, please use (1,) or (n_env, 1) for the observation shape.")
# TODO: add support for MultiDiscrete and MultiBinary action spaces
# TODO: add support for Discrete, MultiDiscrete and MultiBinary observation spaces
# elif isinstance(observation_space, gym.spaces.Discrete):
# if observation.shape == (): # A numpy array of a number, has shape empty tuple '()'
# return False
# elif len(observation.shape) == 1:
# return True
# else:
# raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape) +
# "Discrete environment, please use (1,) or (n_env, 1) for the observation shape.")
# elif isinstance(observation_space, gym.spaces.MultiDiscrete):
# if observation.shape == (len(observation_space.nvec),):
# return False