mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-18 21:30:19 +00:00
Doc fix and improve error messages (#598)
* Fix custom env doc * Catch common mistake * Improve `EvalCallback` error message * Lint test * Update docs/guide/custom_env.rst Co-authored-by: Adam Gleave <adam@gleave.me> Co-authored-by: Adam Gleave <adam@gleave.me>
This commit is contained in:
parent
740d61ada3
commit
1881d904a0
6 changed files with 73 additions and 5 deletions
|
|
@ -8,7 +8,7 @@ That is to say, your environment must implement the following methods (and inher
|
|||
|
||||
|
||||
.. note::
|
||||
If you are using images as input, the input values must be in [0, 255] and np.uint8 as the observation
|
||||
If you are using images as input, the observation must be of type ``np.uint8`` and be contained in [0, 255], as it
|
||||
is normalized (dividing by 255 to have values in [0, 1]) when using CNN policies. Images can be either
|
||||
channel-first or channel-last.
|
||||
|
||||
|
|
@ -76,12 +76,27 @@ To check that your environment follows the gym interface, please use:
|
|||
We have created a `colab notebook <https://colab.research.google.com/github/araffin/rl-tutorial-jnrr19/blob/master/5_custom_gym_env.ipynb>`_ for
|
||||
a concrete example of creating a custom environment.
|
||||
|
||||
You can also find a `complete guide online <https://github.com/openai/gym/blob/master/docs/creating-environments.md>`_
|
||||
You can also find a `complete guide online <https://github.com/openai/gym/blob/master/docs/creating_environments.md>`_
|
||||
on creating a custom Gym environment.
|
||||
|
||||
|
||||
Optionally, you can also register the environment with gym,
|
||||
that will allow you to create the RL agent in one line (and use ``gym.make()`` to instantiate the env).
|
||||
which will allow you to create the RL agent in one line (and use ``gym.make()`` to instantiate the env):
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from gym.envs.registration import register
|
||||
# Example for the CartPole environment
|
||||
register(
|
||||
# unique identifier for the env `name-version`
|
||||
id="CartPole-v1",
|
||||
# path to the class for creating the env
|
||||
# Note: entry_point also accepts a class as input (and not only a string)
|
||||
entry_point="gym.envs.classic_control:CartPoleEnv",
|
||||
# Max number of steps per episode, enforced by gym's `TimeLimit` wrapper
|
||||
max_episode_steps=500,
|
||||
)
|
||||
|
||||
|
||||
|
||||
In the project, for testing purposes, we use a custom environment named ``IdentityEnv``
|
||||
|
|
|
|||
|
|
@ -32,6 +32,8 @@ Deprecations:
|
|||
Others:
|
||||
^^^^^^^
|
||||
- Cap gym max version to 0.19 to avoid issues with atari-py and other breaking changes
|
||||
- Improved error message when using dict observation with the wrong policy
|
||||
- Improved error message when using ``EvalCallback`` with two envs not wrapped the same way.
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
|
|
@ -41,6 +43,7 @@ Documentation:
|
|||
- Added ONNX export instructions (@batu)
|
||||
- Updated Read the Docs env (fixed ``docutils`` issue)
|
||||
- Fix PPO environment name (@IljaAvadiev)
|
||||
- Fix custom env doc and add env registration example
|
||||
|
||||
|
||||
Release 1.2.0 (2021-09-03)
|
||||
|
|
|
|||
|
|
@ -175,6 +175,10 @@ class BaseAlgorithm(ABC):
|
|||
"Error: the model does not support multiple envs; it requires " "a single vectorized environment."
|
||||
)
|
||||
|
||||
# Catch common mistake: using MlpPolicy/CnnPolicy instead of MultiInputPolicy
|
||||
if policy in ["MlpPolicy", "CnnPolicy"] and isinstance(self.observation_space, gym.spaces.Dict):
|
||||
raise ValueError(f"You must use `MultiInputPolicy` when working with dict observation space, not {policy}")
|
||||
|
||||
if self.use_sde and not isinstance(self.action_space, gym.spaces.Box):
|
||||
raise ValueError("generalized State-Dependent Exploration (gSDE) can only be used with continuous actions.")
|
||||
|
||||
|
|
|
|||
|
|
@ -362,7 +362,15 @@ class EvalCallback(EventCallback):
|
|||
|
||||
if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
|
||||
# Sync training and eval env if there is VecNormalize
|
||||
sync_envs_normalization(self.training_env, self.eval_env)
|
||||
if self.model.get_vec_normalize_env() is not None:
|
||||
try:
|
||||
sync_envs_normalization(self.training_env, self.eval_env)
|
||||
except AttributeError:
|
||||
raise AssertionError(
|
||||
"Training and eval env are not wrapped the same way, "
|
||||
"see https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html#evalcallback "
|
||||
"and warning above."
|
||||
)
|
||||
|
||||
# Reset success rate buffer
|
||||
self._is_success_buffer = []
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from stable_baselines3.common.callbacks import (
|
|||
)
|
||||
from stable_baselines3.common.env_util import make_vec_env
|
||||
from stable_baselines3.common.envs import BitFlippingEnv, IdentityEnv
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN, DDPG])
|
||||
|
|
@ -167,3 +167,34 @@ def test_eval_callback_logs_are_written_with_the_correct_timestep(tmp_path):
|
|||
acc.Reload()
|
||||
for event in acc.scalars.Items("eval/mean_reward"):
|
||||
assert event.step % eval_freq == 0
|
||||
|
||||
|
||||
def test_eval_friendly_error():
|
||||
# tests that a helpful error is raised when the training and eval envs are not wrapped the same way
|
||||
train_env = VecNormalize(DummyVecEnv([lambda: gym.make("CartPole-v1")]))
|
||||
eval_env = DummyVecEnv([lambda: gym.make("CartPole-v1")])
|
||||
eval_env = VecNormalize(eval_env, training=False, norm_reward=False)
|
||||
_ = train_env.reset()
|
||||
original_obs = train_env.get_original_obs()
|
||||
model = A2C("MlpPolicy", train_env, n_steps=50, seed=0)
|
||||
|
||||
eval_callback = EvalCallback(
|
||||
eval_env,
|
||||
eval_freq=100,
|
||||
warn=False,
|
||||
)
|
||||
model.learn(100, callback=eval_callback)
|
||||
|
||||
# Check synchronization
|
||||
assert np.allclose(train_env.normalize_obs(original_obs), eval_env.normalize_obs(original_obs))
|
||||
|
||||
wrong_eval_env = gym.make("CartPole-v1")
|
||||
eval_callback = EvalCallback(
|
||||
wrong_eval_env,
|
||||
eval_freq=100,
|
||||
warn=False,
|
||||
)
|
||||
|
||||
with pytest.warns(Warning):
|
||||
with pytest.raises(AssertionError):
|
||||
model.learn(100, callback=eval_callback)
|
||||
|
|
|
|||
|
|
@ -79,6 +79,13 @@ class DummyDictEnv(gym.Env):
|
|||
pass
|
||||
|
||||
|
||||
@pytest.mark.parametrize("policy", ["MlpPolicy", "CnnPolicy"])
|
||||
def test_policy_hint(policy):
|
||||
# Common mistake: using the wrong policy
|
||||
with pytest.raises(ValueError):
|
||||
PPO(policy, BitFlippingEnv(n_bits=4))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [PPO, A2C])
|
||||
def test_goal_env(model_class):
|
||||
env = BitFlippingEnv(n_bits=4)
|
||||
|
|
|
|||
Loading…
Reference in a new issue