mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-26 03:01:19 +00:00
env_id consistency in tests (#1224)
Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
parent
7fb8336f40
commit
96b1a7cf01
3 changed files with 19 additions and 19 deletions
|
|
@ -22,7 +22,7 @@ The full documentation is available here: https://docs.wandb.ai/guides/integrati
|
|||
config = {
|
||||
"policy_type": "MlpPolicy",
|
||||
"total_timesteps": 25000,
|
||||
"env_name": "CartPole-v1",
|
||||
"env_id": "CartPole-v1",
|
||||
}
|
||||
run = wandb.init(
|
||||
project="sb3",
|
||||
|
|
@ -32,7 +32,7 @@ The full documentation is available here: https://docs.wandb.ai/guides/integrati
|
|||
# save_code=True, # optional
|
||||
)
|
||||
|
||||
model = PPO(config["policy_type"], config["env_name"], verbose=1, tensorboard_log=f"runs/{run.id}")
|
||||
model = PPO(config["policy_type"], config["env_id"], verbose=1, tensorboard_log=f"runs/{run.id}")
|
||||
model.learn(
|
||||
total_timesteps=config["total_timesteps"],
|
||||
callback=WandbCallback(
|
||||
|
|
|
|||
|
|
@ -20,19 +20,26 @@ from stable_baselines3.common.envs import BitFlippingEnv, IdentityEnv
|
|||
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
|
||||
|
||||
|
||||
def select_env(model_class) -> str:
|
||||
if model_class is DQN:
|
||||
return "CartPole-v1"
|
||||
else:
|
||||
return "Pendulum-v1"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN, DDPG])
|
||||
def test_callbacks(tmp_path, model_class):
|
||||
log_folder = tmp_path / "logs/callbacks/"
|
||||
|
||||
# DQN only support discrete actions
|
||||
env_name = select_env(model_class)
|
||||
env_id = select_env(model_class)
|
||||
# Create RL model
|
||||
# Small network for fast test
|
||||
model = model_class("MlpPolicy", env_name, policy_kwargs=dict(net_arch=[32]))
|
||||
model = model_class("MlpPolicy", env_id, policy_kwargs=dict(net_arch=[32]))
|
||||
|
||||
checkpoint_callback = CheckpointCallback(save_freq=1000, save_path=log_folder)
|
||||
|
||||
eval_env = gym.make(env_name)
|
||||
eval_env = gym.make(env_id)
|
||||
# Stop training if the performance is good enough
|
||||
callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-1200, verbose=1)
|
||||
|
||||
|
|
@ -82,7 +89,7 @@ def test_callbacks(tmp_path, model_class):
|
|||
n_envs = 2
|
||||
# Pendulum-v1 has a timelimit of 200 timesteps
|
||||
max_episode_length = 200
|
||||
envs = make_vec_env(env_name, n_envs=n_envs, seed=0)
|
||||
envs = make_vec_env(env_id, n_envs=n_envs, seed=0)
|
||||
|
||||
model = model_class("MlpPolicy", envs, policy_kwargs=dict(net_arch=[32]))
|
||||
|
||||
|
|
@ -100,13 +107,6 @@ def test_callbacks(tmp_path, model_class):
|
|||
shutil.rmtree(log_folder)
|
||||
|
||||
|
||||
def select_env(model_class) -> str:
|
||||
if model_class is DQN:
|
||||
return "CartPole-v1"
|
||||
else:
|
||||
return "Pendulum-v1"
|
||||
|
||||
|
||||
def test_eval_callback_vec_env():
|
||||
# tests that eval callback does not crash when given a vector
|
||||
n_eval_envs = 3
|
||||
|
|
@ -153,17 +153,17 @@ def test_eval_callback_logs_are_written_with_the_correct_timestep(tmp_path):
|
|||
pytest.importorskip("tensorboard")
|
||||
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
|
||||
|
||||
env_name = select_env(DQN)
|
||||
env_id = select_env(DQN)
|
||||
model = DQN(
|
||||
"MlpPolicy",
|
||||
env_name,
|
||||
env_id,
|
||||
policy_kwargs=dict(net_arch=[32]),
|
||||
tensorboard_log=tmp_path,
|
||||
verbose=1,
|
||||
seed=1,
|
||||
)
|
||||
|
||||
eval_env = gym.make(env_name)
|
||||
eval_env = gym.make(env_id)
|
||||
eval_freq = 101
|
||||
eval_callback = EvalCallback(eval_env, eval_freq=eval_freq, warn=False)
|
||||
model.learn(500, callback=eval_callback)
|
||||
|
|
|
|||
|
|
@ -40,10 +40,10 @@ def test_auto_wrap(model_class):
|
|||
"""Test auto wrapping of env into a VecEnv."""
|
||||
# Use different environment for DQN
|
||||
if model_class is DQN:
|
||||
env_name = "CartPole-v1"
|
||||
env_id = "CartPole-v1"
|
||||
else:
|
||||
env_name = "Pendulum-v1"
|
||||
env = gym.make(env_name)
|
||||
env_id = "Pendulum-v1"
|
||||
env = gym.make(env_id)
|
||||
model = model_class("MlpPolicy", env)
|
||||
model.learn(100)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue