mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-22 22:10:16 +00:00
Comment unused code
This commit is contained in:
parent
e31b139c47
commit
4392759057
2 changed files with 17 additions and 30 deletions
|
|
@ -22,11 +22,15 @@ def test_auto_wrap(model_class):
|
|||
|
||||
|
||||
@pytest.mark.parametrize("model_class", MODEL_LIST)
|
||||
def test_predict(model_class):
|
||||
@pytest.mark.parametrize("env_id", ['Pendulum-v0', 'CartPole-v1'])
|
||||
def test_predict(model_class, env_id):
|
||||
if env_id == 'CartPole-v1' and model_class not in [PPO, A2C]:
|
||||
return
|
||||
|
||||
# test detection of different shapes by the predict method
|
||||
model = model_class('MlpPolicy', 'Pendulum-v0')
|
||||
env = gym.make('Pendulum-v0')
|
||||
vec_env = DummyVecEnv([lambda: gym.make('Pendulum-v0'), lambda: gym.make('Pendulum-v0')])
|
||||
model = model_class('MlpPolicy', env_id)
|
||||
env = gym.make(env_id)
|
||||
vec_env = DummyVecEnv([lambda: gym.make(env_id), lambda: gym.make(env_id)])
|
||||
|
||||
obs = env.reset()
|
||||
action = model.predict(obs)
|
||||
|
|
@ -36,20 +40,3 @@ def test_predict(model_class):
|
|||
vec_env_obs = vec_env.reset()
|
||||
action = model.predict(vec_env_obs)
|
||||
assert action.shape[0] == vec_env_obs.shape[0]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [A2C, PPO])
|
||||
def test_predict_discrete(model_class):
|
||||
# test detection of different shapes by the predict method
|
||||
model = model_class('MlpPolicy', 'CartPole-v1')
|
||||
env = gym.make('CartPole-v1')
|
||||
vec_env = DummyVecEnv([lambda: gym.make('CartPole-v1'), lambda: gym.make('CartPole-v1')])
|
||||
|
||||
obs = env.reset()
|
||||
action = model.predict(obs)
|
||||
assert action.shape == ()
|
||||
assert env.action_space.contains(action)
|
||||
|
||||
vec_env_obs = vec_env.reset()
|
||||
action = model.predict(vec_env_obs)
|
||||
assert action.shape[0] == vec_env_obs.shape[0]
|
||||
|
|
|
|||
|
|
@ -327,15 +327,15 @@ class BaseRLModel(ABC):
|
|||
"Box environment, please use {} ".format(observation_space.shape) +
|
||||
"or (n_env, {}) for the observation shape."
|
||||
.format(", ".join(map(str, observation_space.shape))))
|
||||
elif isinstance(observation_space, gym.spaces.Discrete):
|
||||
if observation.shape == (): # A numpy array of a number, has shape empty tuple '()'
|
||||
return False
|
||||
elif len(observation.shape) == 1:
|
||||
return True
|
||||
else:
|
||||
raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape) +
|
||||
"Discrete environment, please use (1,) or (n_env, 1) for the observation shape.")
|
||||
# TODO: add support for MultiDiscrete and MultiBinary action spaces
|
||||
# TODO: add support for Discrete, MultiDiscrete and MultiBinary observation spaces
|
||||
# elif isinstance(observation_space, gym.spaces.Discrete):
|
||||
# if observation.shape == (): # A numpy array of a number, has shape empty tuple '()'
|
||||
# return False
|
||||
# elif len(observation.shape) == 1:
|
||||
# return True
|
||||
# else:
|
||||
# raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape) +
|
||||
# "Discrete environment, please use (1,) or (n_env, 1) for the observation shape.")
|
||||
# elif isinstance(observation_space, gym.spaces.MultiDiscrete):
|
||||
# if observation.shape == (len(observation_space.nvec),):
|
||||
# return False
|
||||
|
|
|
|||
Loading…
Reference in a new issue