diff --git a/docs/guide/integrations.rst b/docs/guide/integrations.rst index 3f7619d..49bbdb2 100644 --- a/docs/guide/integrations.rst +++ b/docs/guide/integrations.rst @@ -22,7 +22,7 @@ The full documentation is available here: https://docs.wandb.ai/guides/integrati config = { "policy_type": "MlpPolicy", "total_timesteps": 25000, - "env_name": "CartPole-v1", + "env_id": "CartPole-v1", } run = wandb.init( project="sb3", @@ -32,7 +32,7 @@ The full documentation is available here: https://docs.wandb.ai/guides/integrati # save_code=True, # optional ) - model = PPO(config["policy_type"], config["env_name"], verbose=1, tensorboard_log=f"runs/{run.id}") + model = PPO(config["policy_type"], config["env_id"], verbose=1, tensorboard_log=f"runs/{run.id}") model.learn( total_timesteps=config["total_timesteps"], callback=WandbCallback( diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index a0e20b7..420a16a 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -20,19 +20,26 @@ from stable_baselines3.common.envs import BitFlippingEnv, IdentityEnv from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize +def select_env(model_class) -> str: + if model_class is DQN: + return "CartPole-v1" + else: + return "Pendulum-v1" + + @pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN, DDPG]) def test_callbacks(tmp_path, model_class): log_folder = tmp_path / "logs/callbacks/" # DQN only support discrete actions - env_name = select_env(model_class) + env_id = select_env(model_class) # Create RL model # Small network for fast test - model = model_class("MlpPolicy", env_name, policy_kwargs=dict(net_arch=[32])) + model = model_class("MlpPolicy", env_id, policy_kwargs=dict(net_arch=[32])) checkpoint_callback = CheckpointCallback(save_freq=1000, save_path=log_folder) - eval_env = gym.make(env_name) + eval_env = gym.make(env_id) # Stop training if the performance is good enough callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-1200, verbose=1) @@ -82,7 +89,7 @@ def test_callbacks(tmp_path, model_class): n_envs = 2 # Pendulum-v1 has a timelimit of 200 timesteps max_episode_length = 200 - envs = make_vec_env(env_name, n_envs=n_envs, seed=0) + envs = make_vec_env(env_id, n_envs=n_envs, seed=0) model = model_class("MlpPolicy", envs, policy_kwargs=dict(net_arch=[32])) @@ -100,13 +107,6 @@ def test_callbacks(tmp_path, model_class): shutil.rmtree(log_folder) -def select_env(model_class) -> str: - if model_class is DQN: - return "CartPole-v1" - else: - return "Pendulum-v1" - - def test_eval_callback_vec_env(): # tests that eval callback does not crash when given a vector n_eval_envs = 3 @@ -153,17 +153,17 @@ def test_eval_callback_logs_are_written_with_the_correct_timestep(tmp_path): pytest.importorskip("tensorboard") from tensorboard.backend.event_processing.event_accumulator import EventAccumulator - env_name = select_env(DQN) + env_id = select_env(DQN) model = DQN( "MlpPolicy", - env_name, + env_id, policy_kwargs=dict(net_arch=[32]), tensorboard_log=tmp_path, verbose=1, seed=1, ) - eval_env = gym.make(env_name) + eval_env = gym.make(env_id) eval_freq = 101 eval_callback = EvalCallback(eval_env, eval_freq=eval_freq, warn=False) model.learn(500, callback=eval_callback) diff --git a/tests/test_predict.py b/tests/test_predict.py index 22ff4fd..6343e2f 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -40,10 +40,10 @@ def test_auto_wrap(model_class): """Test auto wrapping of env into a VecEnv.""" # Use different environment for DQN if model_class is DQN: - env_name = "CartPole-v1" + env_id = "CartPole-v1" else: - env_name = "Pendulum-v1" - env = gym.make(env_name) + env_id = "Pendulum-v1" + env = gym.make(env_id) model = model_class("MlpPolicy", env) model.learn(100)