Doc fix and improve error messages (#598)

* Fix custom env doc

* Catch common mistake

* Improve `EvalCallback` error message

* Lint test

* Update docs/guide/custom_env.rst

Co-authored-by: Adam Gleave <adam@gleave.me>

Co-authored-by: Adam Gleave <adam@gleave.me>
This commit is contained in:
Antonin RAFFIN 2021-10-08 18:08:31 +02:00 committed by GitHub
parent 740d61ada3
commit 1881d904a0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 73 additions and 5 deletions

View file

@ -8,7 +8,7 @@ That is to say, your environment must implement the following methods (and inher
.. note::
If you are using images as input, the input values must be in [0, 255] and np.uint8 as the observation
If you are using images as input, the observation must be of type ``np.uint8`` and be contained in [0, 255]
is normalized (dividing by 255 to have values in [0, 1]) when using CNN policies. Images can be either
channel-first or channel-last.
@ -76,12 +76,27 @@ To check that your environment follows the gym interface, please use:
We have created a `colab notebook <https://colab.research.google.com/github/araffin/rl-tutorial-jnrr19/blob/master/5_custom_gym_env.ipynb>`_ for
a concrete example of creating a custom environment.
You can also find a `complete guide online <https://github.com/openai/gym/blob/master/docs/creating-environments.md>`_
You can also find a `complete guide online <https://github.com/openai/gym/blob/master/docs/creating_environments.md>`_
on creating a custom Gym environment.
Optionally, you can also register the environment with gym,
that will allow you to create the RL agent in one line (and use ``gym.make()`` to instantiate the env).
that will allow you to create the RL agent in one line (and use ``gym.make()`` to instantiate the env):
.. code-block:: python
from gym.envs.registration import register
# Example for the CartPole environment
register(
# unique identifier for the env `name-version`
id="CartPole-v1",
# path to the class for creating the env
# Note: entry_point also accept a class as input (and not only a string)
entry_point="gym.envs.classic_control:CartPoleEnv",
# Max number of steps per episode, using a `TimeLimitWrapper`
max_episode_steps=500,
)
In the project, for testing purposes, we use a custom environment named ``IdentityEnv``

View file

@ -32,6 +32,8 @@ Deprecations:
Others:
^^^^^^^
- Cap gym max version to 0.19 to avoid issues with atari-py and other breaking changes
- Improved error message when using dict observation with the wrong policy
- Improved error message when using ``EvalCallback`` with two envs not wrapped the same way.
Documentation:
^^^^^^^^^^^^^^
@ -41,6 +43,7 @@ Documentation:
- Added ONNX export instructions (@batu)
- Update read the doc env (fixed ``docutils`` issue)
- Fix PPO environment name (@IljaAvadiev)
- Fix custom env doc and add env registration example
Release 1.2.0 (2021-09-03)

View file

@ -175,6 +175,10 @@ class BaseAlgorithm(ABC):
"Error: the model does not support multiple envs; it requires " "a single vectorized environment."
)
# Catch common mistake: using MlpPolicy/CnnPolicy instead of MultiInputPolicy
if policy in ["MlpPolicy", "CnnPolicy"] and isinstance(self.observation_space, gym.spaces.Dict):
raise ValueError(f"You must use `MultiInputPolicy` when working with dict observation space, not {policy}")
if self.use_sde and not isinstance(self.action_space, gym.spaces.Box):
raise ValueError("generalized State-Dependent Exploration (gSDE) can only be used with continuous actions.")

View file

@ -362,7 +362,15 @@ class EvalCallback(EventCallback):
if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
# Sync training and eval env if there is VecNormalize
sync_envs_normalization(self.training_env, self.eval_env)
if self.model.get_vec_normalize_env() is not None:
try:
sync_envs_normalization(self.training_env, self.eval_env)
except AttributeError:
raise AssertionError(
"Training and eval env are not wrapped the same way, "
"see https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html#evalcallback "
"and warning above."
)
# Reset success rate buffer
self._is_success_buffer = []

View file

@ -16,7 +16,7 @@ from stable_baselines3.common.callbacks import (
)
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.envs import BitFlippingEnv, IdentityEnv
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN, DDPG])
@ -167,3 +167,34 @@ def test_eval_callback_logs_are_written_with_the_correct_timestep(tmp_path):
acc.Reload()
for event in acc.scalars.Items("eval/mean_reward"):
assert event.step % eval_freq == 0
def test_eval_friendly_error():
# tests that eval callback does not crash when given a vector
train_env = VecNormalize(DummyVecEnv([lambda: gym.make("CartPole-v1")]))
eval_env = DummyVecEnv([lambda: gym.make("CartPole-v1")])
eval_env = VecNormalize(eval_env, training=False, norm_reward=False)
_ = train_env.reset()
original_obs = train_env.get_original_obs()
model = A2C("MlpPolicy", train_env, n_steps=50, seed=0)
eval_callback = EvalCallback(
eval_env,
eval_freq=100,
warn=False,
)
model.learn(100, callback=eval_callback)
# Check synchronization
assert np.allclose(train_env.normalize_obs(original_obs), eval_env.normalize_obs(original_obs))
wrong_eval_env = gym.make("CartPole-v1")
eval_callback = EvalCallback(
wrong_eval_env,
eval_freq=100,
warn=False,
)
with pytest.warns(Warning):
with pytest.raises(AssertionError):
model.learn(100, callback=eval_callback)

View file

@ -79,6 +79,13 @@ class DummyDictEnv(gym.Env):
pass
@pytest.mark.parametrize("policy", ["MlpPolicy", "CnnPolicy"])
def test_policy_hint(policy):
# Common mistake: using the wrong policy
with pytest.raises(ValueError):
PPO(policy, BitFlippingEnv(n_bits=4))
@pytest.mark.parametrize("model_class", [PPO, A2C])
def test_goal_env(model_class):
env = BitFlippingEnv(n_bits=4)