stable-baselines3/tests/test_callbacks.py
Antonin RAFFIN 40e0b9d2c8
Add Gymnasium support (#1327)
* Fix failing set_env test

* Fix test failing due to deprecation of env.seed

* Adjust mean reward threshold in failing test

* Fix her test failing due to rng

* Change seed and revert reward threshold to 90

* Pin gym version

* Make VecEnv compatible with gym seeding change

* Revert change to VecEnv reset signature

* Change subprocenv seed cmd to call reset instead

* Fix type check

* Add backward compat

* Add `compat_gym_seed` helper

* Add goal env checks in env_checker

* Add docs on HER requirements for envs

* Capture user warning in test with inverted box space

* Update ale-py version

* Fix randint

* Allow noop_max to be zero

* Update changelog

* Update docker image

* Update doc conda env and dockerfile

* Custom envs should not have any warnings

* Fix test for numpy >= 1.21

* Add check for vectorized compute reward

* Bump to gym 0.24

* Fix gym default step docstring

* Test downgrading gym

* Revert "Test downgrading gym"

This reverts commit 0072b77156c006ada8a1d6e26ce347ed85a83eeb.

* Fix protobuf error

* Fix in dependencies

* Fix protobuf dep

* Use newest version of cartpole

* Update gym

* Fix warning

* Loosen required scipy version

* Scipy no longer needed

* Try gym 0.25

* Silence warnings from gym

* Filter warnings during tests

* Update doc

* Update requirements

* Add gym 0.26 compat in vec env

* Fixes in envs and tests for gym 0.26+

* Enforce gym 0.26 api

* format

* Fix formatting

* Fix dependencies

* Fix syntax

* Cleanup doc and warnings

* Faster tests

* Higher budget for HER perf test (revert prev change)

* Fixes and update doc

* Fix doc build

* Fix breaking change

* Fixes for rendering

* Rename variables in monitor

* update render method for gym 0.26 API

Backwards compatible (the mode argument is still allowed) while using the gym 0.26 API (the render mode is determined at environment creation)

* update tests and docs to new gym render API

* undo removal of render modes metadata check

* set rgb_array as default render mode for gym.make

* undo changes & raise warning if not 'rgb_array'

* Fix type check

* Remove recursion and fix type checking

* Remove hacks for protobuf and gym 0.24

* Fix type annotations

* reuse existing render_mode attribute

* return tiled images for 'human' render mode

* Allow to use opencv for human render, fix typos

* Add warning when using non-zero start with Discrete (fixes #1197)

* Fix type checking

* Bug fixes and handle more cases

* Throw proper warnings

* Update test

* Fix new metadata name

* Ignore numpy warnings

* Fixes in vec recorder

* Global ignore

* Filter local warning too

* Monkey patch not needed for gym 0.26

* Add doc of VecEnv vs Gym API

* Add render test

* Fix return type

* Update VecEnv vs Gym API doc

* Fix for custom render mode

* Fix return type

* Fix type checking

* check test env test_buffer

* skip render check

* check env test_dict_env

* test_env test_gae

* check envs in remaining tests

* Update tests

* Add warning for Discrete action space with non-zero start (#1295)

* Fix atari annotation

* ignore get_action_meanings [attr-defined]

* Fix mypy issues

* Add patch for gym/gymnasium transition

* Switch to gymnasium

* Rely on signature instead of version

* More patches

* Type ignore because of https://github.com/Farama-Foundation/Gymnasium/pull/39

* Fix doc build

* Fix pytype errors

* Fix atari requirement

* Update env checker due to change in dtype for Discrete

* Fix type hint

* Convert spaces for saved models

* Ignore pytype

* Remove gitlab CI

* Disable pytype for convert space

* Fix undefined info

* Fix undefined info

* Upgrade shimmy

* Fix wrappers type annotation (need PR from Gymnasium)

* Fix gymnasium dependency

* Fix dependency declaration

* Cap pygame version for python 3.7

* Point to master branch (v0.28.0)

* Fix: use main not master branch

* Rename done to terminated

* Fix pygame dependency for python 3.7

* Rename gym to gymnasium

* Update Gymnasium

* Fix test

* Fix tests

* Forks don't have access to private variables

* Fix linter warnings

* Update read the doc env

* Fix env checker for GoalEnv

* Fix import

* Update env checker (more info) and fix dtype

* Use micromamba for Docker

* Update dependencies

* Clarify VecEnv doc

* Fix Gymnasium version

* Copy file only after mamba install

* [ci skip] Update docker doc

* Polish code

* Reformat

* Remove deprecated features

* Ignore warning

* Update doc

* Update examples and changelog

* Fix type annotation bundle (SAC, TD3, A2C, PPO, base class) (#1436)

* Fix SAC type hints, improve DQN ones

* Fix A2C and TD3 type hints

* Fix PPO type hints

* Fix on-policy type hints

* Fix base class type annotation, do not use defaults

* Update version

* Disable mypy for python 3.7

* Rename Gym26StepReturn

* Update continuous critic type annotation

* Fix pytype complain

---------

Co-authored-by: Carlos Luis <carlos.luisgonc@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Thomas Lips <37955681+tlpss@users.noreply.github.com>
Co-authored-by: tlips <thomas.lips@ugent.be>
Co-authored-by: tlpss <thomas17.lips@gmail.com>
Co-authored-by: Quentin GALLOUÉDEC <gallouedec.quentin@gmail.com>
2023-04-14 13:13:59 +02:00


import os
import shutil

import gymnasium as gym
import numpy as np
import pytest

from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3, HerReplayBuffer
from stable_baselines3.common.callbacks import (
    CallbackList,
    CheckpointCallback,
    EvalCallback,
    EveryNTimesteps,
    StopTrainingOnMaxEpisodes,
    StopTrainingOnNoModelImprovement,
    StopTrainingOnRewardThreshold,
)
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.envs import BitFlippingEnv, IdentityEnv
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
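

# Tests for the callback classes shipped with stable-baselines3:
# checkpointing, evaluation, event-driven callbacks and early-stopping conditions.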
def select_env(model_class) -> str:
    """Select an env compatible with the given model class (DQN only supports discrete actions)."""
    if model_class is DQN:
        return "CartPole-v1"
    else:
        return "Pendulum-v1"
@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN, DDPG])
def test_callbacks(tmp_path, model_class):
    log_folder = tmp_path / "logs/callbacks/"
    # DQN only supports discrete actions
    env_id = select_env(model_class)
    # Create RL model
    # Small network for fast test
    model = model_class("MlpPolicy", env_id, policy_kwargs=dict(net_arch=[32]))

    checkpoint_callback = CheckpointCallback(save_freq=1000, save_path=log_folder)
    eval_env = gym.make(env_id)
    # Stop training if the performance is good enough
    callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-1200, verbose=1)
    # Stop training if there is no model improvement after 2 evaluations
    callback_no_model_improvement = StopTrainingOnNoModelImprovement(max_no_improvement_evals=2, min_evals=1, verbose=1)
    eval_callback = EvalCallback(
        eval_env,
        callback_on_new_best=callback_on_best,
        callback_after_eval=callback_no_model_improvement,
        best_model_save_path=log_folder,
        log_path=log_folder,
        eval_freq=100,
        warn=False,
    )
    # Equivalent to the `checkpoint_callback`,
    # but here in an event-driven manner
    checkpoint_on_event = CheckpointCallback(save_freq=1, save_path=log_folder, name_prefix="event")
    event_callback = EveryNTimesteps(n_steps=500, callback=checkpoint_on_event)
    # Stop training if max number of episodes is reached
    callback_max_episodes = StopTrainingOnMaxEpisodes(max_episodes=100, verbose=1)

    callback = CallbackList([checkpoint_callback, eval_callback, event_callback, callback_max_episodes])
    model.learn(500, callback=callback)

    # Check access to local variables
    assert model.env.observation_space.contains(callback.locals["new_obs"][0])
    # Check that the child callbacks were called
    assert checkpoint_callback.locals["new_obs"] is callback.locals["new_obs"]
    assert event_callback.locals["new_obs"] is callback.locals["new_obs"]
    assert checkpoint_on_event.locals["new_obs"] is callback.locals["new_obs"]
    # Check that internal callback counters match the model's counters
    assert event_callback.num_timesteps == model.num_timesteps
    assert event_callback.n_calls == model.num_timesteps

    model.learn(500, callback=None)
    # Transform callback into a callback list automatically and use progress bar
    model.learn(500, callback=[checkpoint_callback, eval_callback], progress_bar=True)
    # Automatic wrapping, old way of doing callbacks
    model.learn(500, callback=lambda _locals, _globals: True)

    # Testing models that support multiple envs
    if model_class in [A2C, PPO]:
        max_episodes = 1
        n_envs = 2
        # Pendulum-v1 has a timelimit of 200 timesteps
        max_episode_length = 200
        envs = make_vec_env(env_id, n_envs=n_envs, seed=0)
        model = model_class("MlpPolicy", envs, policy_kwargs=dict(net_arch=[32]))
        callback_max_episodes = StopTrainingOnMaxEpisodes(max_episodes=max_episodes, verbose=1)
        callback = CallbackList([callback_max_episodes])
        model.learn(1000, callback=callback)

        # Check that the actual number of episodes and timesteps per env matches the expected one
        episodes_per_env = callback_max_episodes.n_episodes // n_envs
        assert episodes_per_env == max_episodes
        timesteps_per_env = model.num_timesteps // n_envs
        assert timesteps_per_env == max_episode_length

    if os.path.exists(log_folder):
        shutil.rmtree(log_folder)


def test_eval_callback_vec_env():
    # Check that the eval callback does not crash when given a vectorized eval env
    n_eval_envs = 3
    train_env = IdentityEnv()
    eval_env = DummyVecEnv([lambda: IdentityEnv()] * n_eval_envs)
    model = A2C("MlpPolicy", train_env, seed=0)
    eval_callback = EvalCallback(
        eval_env,
        eval_freq=100,
        warn=False,
    )
    model.learn(300, callback=eval_callback)
    assert eval_callback.last_mean_reward == 100.0
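

# Check that the success rate (read from the `is_success` key of the info dict)
# is logged when evaluating a HER agent on BitFlippingEnv.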
def test_eval_success_logging(tmp_path):
    n_bits = 2
    n_envs = 2
    env = BitFlippingEnv(n_bits=n_bits)
    eval_env = DummyVecEnv([lambda: BitFlippingEnv(n_bits=n_bits)] * n_envs)
    eval_callback = EvalCallback(
        eval_env,
        eval_freq=250,
        log_path=tmp_path,
        warn=False,
    )
    model = DQN(
        "MultiInputPolicy",
        env,
        replay_buffer_class=HerReplayBuffer,
        learning_starts=100,
        seed=0,
    )
    model.learn(500, callback=eval_callback)
    assert len(eval_callback._is_success_buffer) > 0
    # More than 50% success rate
    assert np.mean(eval_callback._is_success_buffer) > 0.5
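

# Check that the scalars written to tensorboard by the eval callback
# are logged at the correct timestep (a multiple of eval_freq).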
def test_eval_callback_logs_are_written_with_the_correct_timestep(tmp_path):
    # Skip if tensorboard is not installed
    pytest.importorskip("tensorboard")
    from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

    env_id = select_env(DQN)
    model = DQN(
        "MlpPolicy",
        env_id,
        policy_kwargs=dict(net_arch=[32]),
        tensorboard_log=tmp_path,
        verbose=1,
        seed=1,
    )
    eval_env = gym.make(env_id)
    eval_freq = 101
    eval_callback = EvalCallback(eval_env, eval_freq=eval_freq, warn=False)
    model.learn(500, callback=eval_callback)

    acc = EventAccumulator(str(tmp_path / "DQN_1"))
    acc.Reload()
    for event in acc.scalars.Items("eval/mean_reward"):
        assert event.step % eval_freq == 0


def test_eval_friendly_error():
    # Check that a helpful error is raised when the training env and the eval env
    # are not wrapped the same way (here, the eval env must also be wrapped in VecNormalize)
    train_env = VecNormalize(DummyVecEnv([lambda: gym.make("CartPole-v1")]))
    eval_env = DummyVecEnv([lambda: gym.make("CartPole-v1")])
    eval_env = VecNormalize(eval_env, training=False, norm_reward=False)
    _ = train_env.reset()
    original_obs = train_env.get_original_obs()
    model = A2C("MlpPolicy", train_env, n_steps=50, seed=0)

    eval_callback = EvalCallback(
        eval_env,
        eval_freq=100,
        warn=False,
    )
    model.learn(100, callback=eval_callback)

    # Check that the normalization statistics were synchronized between train and eval envs
    assert np.allclose(train_env.normalize_obs(original_obs), eval_env.normalize_obs(original_obs))

    # An eval env missing the VecNormalize wrapper should trigger a warning and a friendly error
    wrong_eval_env = gym.make("CartPole-v1")
    eval_callback = EvalCallback(
        wrong_eval_env,
        eval_freq=100,
        warn=False,
    )
    with pytest.warns(Warning):
        with pytest.raises(AssertionError):
            model.learn(100, callback=eval_callback)


def test_checkpoint_additional_info(tmp_path):
    # Check that the replay buffer and the VecNormalize stats are saved with every checkpoint
    dummy_vec_env = DummyVecEnv([lambda: gym.make("CartPole-v1")])
    env = VecNormalize(dummy_vec_env)
    checkpoint_dir = tmp_path / "checkpoints"
    checkpoint_callback = CheckpointCallback(
        save_freq=200,
        save_path=checkpoint_dir,
        save_replay_buffer=True,
        save_vecnormalize=True,
        verbose=2,
    )
    model = DQN("MlpPolicy", env, learning_starts=100, buffer_size=500, seed=0)
    model.learn(200, callback=checkpoint_callback)

    # Filenames follow CheckpointCallback's default name_prefix ("rl_model")
    assert os.path.exists(checkpoint_dir / "rl_model_200_steps.zip")
    assert os.path.exists(checkpoint_dir / "rl_model_replay_buffer_200_steps.pkl")
    assert os.path.exists(checkpoint_dir / "rl_model_vecnormalize_200_steps.pkl")

    # Check that checkpoints can be properly loaded
    model = DQN.load(checkpoint_dir / "rl_model_200_steps.zip")
    model.load_replay_buffer(checkpoint_dir / "rl_model_replay_buffer_200_steps.pkl")
    VecNormalize.load(checkpoint_dir / "rl_model_vecnormalize_200_steps.pkl", dummy_vec_env)
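

# Optional convenience entry point: the suite is normally invoked with
# `pytest tests/test_callbacks.py`, but this allows running the file directly.
if __name__ == "__main__":
    pytest.main([__file__])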