From 1881d904a03884fcdeea545b9d71f7cd71dd26fb Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 8 Oct 2021 18:08:31 +0200 Subject: [PATCH] Doc fix and improve error messages (#598) * Fix custom env doc * Catch common mistake * Improve `EvalCallback` error message * Lint test * Update docs/guide/custom_env.rst Co-authored-by: Adam Gleave Co-authored-by: Adam Gleave --- docs/guide/custom_env.rst | 21 +++++++++++++--- docs/misc/changelog.rst | 3 +++ stable_baselines3/common/base_class.py | 4 ++++ stable_baselines3/common/callbacks.py | 10 +++++++- tests/test_callbacks.py | 33 +++++++++++++++++++++++++- tests/test_dict_env.py | 7 ++++++ 6 files changed, 73 insertions(+), 5 deletions(-) diff --git a/docs/guide/custom_env.rst b/docs/guide/custom_env.rst index 2b1e4b9..83f387c 100644 --- a/docs/guide/custom_env.rst +++ b/docs/guide/custom_env.rst @@ -8,7 +8,7 @@ That is to say, your environment must implement the following methods (and inher .. note:: - If you are using images as input, the input values must be in [0, 255] and np.uint8 as the observation + If you are using images as input, the observation must be of type ``np.uint8`` and be contained in [0, 255] as the observation is normalized (dividing by 255 to have values in [0, 1]) when using CNN policies. Images can be either channel-first or channel-last. @@ -76,12 +76,27 @@ To check that your environment follows the gym interface, please use: We have created a `colab notebook `_ for a concrete example of creating a custom environment. -You can also find a `complete guide online `_ +You can also find a `complete guide online `_ on creating a custom Gym environment. Optionally, you can also register the environment with gym, -that will allow you to create the RL agent in one line (and use ``gym.make()`` to instantiate the env). +that will allow you to create the RL agent in one line (and use ``gym.make()`` to instantiate the env): + +.. 
code-block:: python + + from gym.envs.registration import register + # Example for the CartPole environment + register( + # unique identifier for the env `name-version` + id="CartPole-v1", + # path to the class for creating the env + # Note: entry_point also accepts a class as input (and not only a string) + entry_point="gym.envs.classic_control:CartPoleEnv", + # Max number of steps per episode, using a `TimeLimit` wrapper + max_episode_steps=500, + ) + In the project, for testing purposes, we use a custom environment named ``IdentityEnv`` diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index a8bf9b7..70eff4c 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -32,6 +32,8 @@ Deprecations: Others: ^^^^^^^ - Cap gym max version to 0.19 to avoid issues with atari-py and other breaking changes +- Improved error message when using dict observation with the wrong policy +- Improved error message when using ``EvalCallback`` with two envs not wrapped the same way. Documentation: ^^^^^^^^^^^^^^ @@ -41,6 +43,7 @@ Documentation: - Added ONNX export instructions (@batu) - Update read the doc env (fixed ``docutils`` issue) - Fix PPO environment name (@IljaAvadiev) +- Fix custom env doc and add env registration example Release 1.2.0 (2021-09-03) diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index 6b0df32..23f14e3 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -175,6 +175,10 @@ class BaseAlgorithm(ABC): "Error: the model does not support multiple envs; it requires " "a single vectorized environment." 
) + # Catch common mistake: using MlpPolicy/CnnPolicy instead of MultiInputPolicy + if policy in ["MlpPolicy", "CnnPolicy"] and isinstance(self.observation_space, gym.spaces.Dict): + raise ValueError(f"You must use `MultiInputPolicy` when working with dict observation space, not {policy}") + if self.use_sde and not isinstance(self.action_space, gym.spaces.Box): raise ValueError("generalized State-Dependent Exploration (gSDE) can only be used with continuous actions.") diff --git a/stable_baselines3/common/callbacks.py b/stable_baselines3/common/callbacks.py index a45fbf6..9825347 100644 --- a/stable_baselines3/common/callbacks.py +++ b/stable_baselines3/common/callbacks.py @@ -362,7 +362,15 @@ class EvalCallback(EventCallback): if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0: # Sync training and eval env if there is VecNormalize - sync_envs_normalization(self.training_env, self.eval_env) + if self.model.get_vec_normalize_env() is not None: + try: + sync_envs_normalization(self.training_env, self.eval_env) + except AttributeError: + raise AssertionError( + "Training and eval env are not wrapped the same way, " + "see https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html#evalcallback " + "and warning above." 
+ ) # Reset success rate buffer self._is_success_buffer = [] diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index c94bb8f..56fc141 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -16,7 +16,7 @@ from stable_baselines3.common.callbacks import ( ) from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.envs import BitFlippingEnv, IdentityEnv -from stable_baselines3.common.vec_env import DummyVecEnv +from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize @pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN, DDPG]) @@ -167,3 +167,34 @@ def test_eval_callback_logs_are_written_with_the_correct_timestep(tmp_path): acc.Reload() for event in acc.scalars.Items("eval/mean_reward"): assert event.step % eval_freq == 0 + + +def test_eval_friendly_error(): + # tests that eval callback does not crash when given a vector + train_env = VecNormalize(DummyVecEnv([lambda: gym.make("CartPole-v1")])) + eval_env = DummyVecEnv([lambda: gym.make("CartPole-v1")]) + eval_env = VecNormalize(eval_env, training=False, norm_reward=False) + _ = train_env.reset() + original_obs = train_env.get_original_obs() + model = A2C("MlpPolicy", train_env, n_steps=50, seed=0) + + eval_callback = EvalCallback( + eval_env, + eval_freq=100, + warn=False, + ) + model.learn(100, callback=eval_callback) + + # Check synchronization + assert np.allclose(train_env.normalize_obs(original_obs), eval_env.normalize_obs(original_obs)) + + wrong_eval_env = gym.make("CartPole-v1") + eval_callback = EvalCallback( + wrong_eval_env, + eval_freq=100, + warn=False, + ) + + with pytest.warns(Warning): + with pytest.raises(AssertionError): + model.learn(100, callback=eval_callback) diff --git a/tests/test_dict_env.py b/tests/test_dict_env.py index b165180..7d936c5 100644 --- a/tests/test_dict_env.py +++ b/tests/test_dict_env.py @@ -79,6 +79,13 @@ class DummyDictEnv(gym.Env): pass +@pytest.mark.parametrize("policy", ["MlpPolicy", 
"CnnPolicy"]) +def test_policy_hint(policy): + # Common mistake: using the wrong policy + with pytest.raises(ValueError): + PPO(policy, BitFlippingEnv(n_bits=4)) + + @pytest.mark.parametrize("model_class", [PPO, A2C]) def test_goal_env(model_class): env = BitFlippingEnv(n_bits=4)