env_id consistency in tests (#1224)

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
Quentin Gallouédec 2022-12-20 16:01:26 +01:00 committed by GitHub
parent 7fb8336f40
commit 96b1a7cf01
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 19 additions and 19 deletions

View file

@ -22,7 +22,7 @@ The full documentation is available here: https://docs.wandb.ai/guides/integrati
config = {
"policy_type": "MlpPolicy",
"total_timesteps": 25000,
"env_name": "CartPole-v1",
"env_id": "CartPole-v1",
}
run = wandb.init(
project="sb3",
@ -32,7 +32,7 @@ The full documentation is available here: https://docs.wandb.ai/guides/integrati
# save_code=True, # optional
)
model = PPO(config["policy_type"], config["env_name"], verbose=1, tensorboard_log=f"runs/{run.id}")
model = PPO(config["policy_type"], config["env_id"], verbose=1, tensorboard_log=f"runs/{run.id}")
model.learn(
total_timesteps=config["total_timesteps"],
callback=WandbCallback(

View file

@ -20,19 +20,26 @@ from stable_baselines3.common.envs import BitFlippingEnv, IdentityEnv
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
def select_env(model_class) -> str:
    """Return the environment id used to test ``model_class``.

    :param model_class: the algorithm class under test
    :return: ``"CartPole-v1"`` when ``model_class`` is ``DQN``,
        ``"Pendulum-v1"`` for every other class
    """
    # Identity check on the class object: only DQN itself gets CartPole.
    return "CartPole-v1" if model_class is DQN else "Pendulum-v1"
@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN, DDPG])
def test_callbacks(tmp_path, model_class):
log_folder = tmp_path / "logs/callbacks/"
# DQN only supports discrete actions
env_name = select_env(model_class)
env_id = select_env(model_class)
# Create RL model
# Small network for fast test
model = model_class("MlpPolicy", env_name, policy_kwargs=dict(net_arch=[32]))
model = model_class("MlpPolicy", env_id, policy_kwargs=dict(net_arch=[32]))
checkpoint_callback = CheckpointCallback(save_freq=1000, save_path=log_folder)
eval_env = gym.make(env_name)
eval_env = gym.make(env_id)
# Stop training if the performance is good enough
callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-1200, verbose=1)
@ -82,7 +89,7 @@ def test_callbacks(tmp_path, model_class):
n_envs = 2
# Pendulum-v1 has a timelimit of 200 timesteps
max_episode_length = 200
envs = make_vec_env(env_name, n_envs=n_envs, seed=0)
envs = make_vec_env(env_id, n_envs=n_envs, seed=0)
model = model_class("MlpPolicy", envs, policy_kwargs=dict(net_arch=[32]))
@ -100,13 +107,6 @@ def test_callbacks(tmp_path, model_class):
shutil.rmtree(log_folder)
def select_env(model_class) -> str:
    """Pick the environment id to pair with ``model_class`` in the tests.

    :param model_class: the algorithm class under test
    :return: ``"CartPole-v1"`` when ``model_class`` is ``DQN``,
        ``"Pendulum-v1"`` otherwise
    """
    # DQN is singled out by identity; all other classes share Pendulum.
    return "CartPole-v1" if model_class is DQN else "Pendulum-v1"
def test_eval_callback_vec_env():
# tests that eval callback does not crash when given a vector
n_eval_envs = 3
@ -153,17 +153,17 @@ def test_eval_callback_logs_are_written_with_the_correct_timestep(tmp_path):
pytest.importorskip("tensorboard")
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
env_name = select_env(DQN)
env_id = select_env(DQN)
model = DQN(
"MlpPolicy",
env_name,
env_id,
policy_kwargs=dict(net_arch=[32]),
tensorboard_log=tmp_path,
verbose=1,
seed=1,
)
eval_env = gym.make(env_name)
eval_env = gym.make(env_id)
eval_freq = 101
eval_callback = EvalCallback(eval_env, eval_freq=eval_freq, warn=False)
model.learn(500, callback=eval_callback)

View file

@ -40,10 +40,10 @@ def test_auto_wrap(model_class):
"""Test auto wrapping of env into a VecEnv."""
# Use different environment for DQN
if model_class is DQN:
env_name = "CartPole-v1"
env_id = "CartPole-v1"
else:
env_name = "Pendulum-v1"
env = gym.make(env_name)
env_id = "Pendulum-v1"
env = gym.make(env_id)
model = model_class("MlpPolicy", env)
model.learn(100)