stable-baselines3/stable_baselines3/common/env_checker.py
Antonin RAFFIN 40e0b9d2c8
Add Gymnasium support (#1327)
* Fix failing set_env test

* Fix test failing due to deprecation of env.seed

* Adjust mean reward threshold in failing test

* Fix her test failing due to rng

* Change seed and revert reward threshold to 90

* Pin gym version

* Make VecEnv compatible with gym seeding change

* Revert change to VecEnv reset signature

* Change subprocenv seed cmd to call reset instead

* Fix type check

* Add backward compat

* Add `compat_gym_seed` helper

* Add goal env checks in env_checker

* Add docs on HER requirements for envs

* Capture user warning in test with inverted box space

* Update ale-py version

* Fix randint

* Allow noop_max to be zero

* Update changelog

* Update docker image

* Update doc conda env and dockerfile

* Custom envs should not have any warnings

* Fix test for numpy >= 1.21

* Add check for vectorized compute reward

* Bump to gym 0.24

* Fix gym default step docstring

* Test downgrading gym

* Revert "Test downgrading gym"

This reverts commit 0072b77156c006ada8a1d6e26ce347ed85a83eeb.

* Fix protobuf error

* Fix in dependencies

* Fix protobuf dep

* Use newest version of cartpole

* Update gym

* Fix warning

* Loosen required scipy version

* Scipy no longer needed

* Try gym 0.25

* Silence warnings from gym

* Filter warnings during tests

* Update doc

* Update requirements

* Add gym 26 compat in vec env

* Fixes in envs and tests for gym 0.26+

* Enforce gym 0.26 api

* format

* Fix formatting

* Fix dependencies

* Fix syntax

* Cleanup doc and warnings

* Faster tests

* Higher budget for HER perf test (revert prev change)

* Fixes and update doc

* Fix doc build

* Fix breaking change

* Fixes for rendering

* Rename variables in monitor

* update render method for gym 0.26 API

backwards compatible (mode argument is allowed) while using the gym 0.26 API (render mode is determined at environment creation)

* update tests and docs to new gym render API

* undo removal of render modes metadata check

* set rgb_array as default render mode for gym.make

* undo changes & raise warning if not 'rgb_array'

* Fix type check

* Remove recursion and fix type checking

* Remove hacks for protobuf and gym 0.24

* Fix type annotations

* reuse existing render_mode attribute

* return tiled images for 'human' render mode

* Allow to use opencv for human render, fix typos

* Add warning when using non-zero start with Discrete (fixes #1197)

* Fix type checking

* Bug fixes and handle more cases

* Throw proper warnings

* Update test

* Fix new metadata name

* Ignore numpy warnings

* Fixes in vec recorder

* Global ignore

* Filter local warning too

* Monkey patch not needed for gym 26

* Add doc of VecEnv vs Gym API

* Add render test

* Fix return type

* Update VecEnv vs Gym API doc

* Fix for custom render mode

* Fix return type

* Fix type checking

* check test env test_buffer

* skip render check

* check env test_dict_env

* test_env test_gae

* check envs in remaining tests

* Update tests

* Add warning for Discrete action space with non-zero (#1295)

* Fix atari annotation

* ignore get_action_meanings [attr-defined]

* Fix mypy issues

* Add patch for gym/gymnasium transition

* Switch to gymnasium

* Rely on signature instead of version

* More patches

* Type ignore because of https://github.com/Farama-Foundation/Gymnasium/pull/39

* Fix doc build

* Fix pytype errors

* Fix atari requirement

* Update env checker due to change in dtype for Discrete

* Fix type hint

* Convert spaces for saved models

* Ignore pytype

* Remove gitlab CI

* Disable pytype for convert space

* Fix undefined info

* Fix undefined info

* Upgrade shimmy

* Fix wrappers type annotation (need PR from Gymnasium)

* Fix gymnasium dependency

* Fix dependency declaration

* Cap pygame version for python 3.7

* Point to master branch (v0.28.0)

* Fix: use main not master branch

* Rename done to terminated

* Fix pygame dependency for python 3.7

* Rename gym to gymnasium

* Update Gymnasium

* Fix test

* Fix tests

* Forks don't have access to private variables

* Fix linter warnings

* Update read the doc env

* Fix env checker for GoalEnv

* Fix import

* Update env checker (more info) and fix dtype

* Use micromamba for Docker

* Update dependencies

* Clarify VecEnv doc

* Fix Gymnasium version

* Copy file only after mamba install

* [ci skip] Update docker doc

* Polish code

* Reformat

* Remove deprecated features

* Ignore warning

* Update doc

* Update examples and changelog

* Fix type annotation bundle (SAC, TD3, A2C, PPO, base class) (#1436)

* Fix SAC type hints, improve DQN ones

* Fix A2C and TD3 type hints

* Fix PPO type hints

* Fix on-policy type hints

* Fix base class type annotation, do not use defaults

* Update version

* Disable mypy for python 3.7

* Rename Gym26StepReturn

* Update continuous critic type annotation

* Fix pytype complaint

---------

Co-authored-by: Carlos Luis <carlos.luisgonc@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Thomas Lips <37955681+tlpss@users.noreply.github.com>
Co-authored-by: tlips <thomas.lips@ugent.be>
Co-authored-by: tlpss <thomas17.lips@gmail.com>
Co-authored-by: Quentin GALLOUÉDEC <gallouedec.quentin@gmail.com>
2023-04-14 13:13:59 +02:00

445 lines
20 KiB
Python

import warnings
from typing import Any, Dict, Union

import gymnasium as gym
import numpy as np
from gymnasium import spaces

from stable_baselines3.common.preprocessing import check_for_nested_spaces, is_image_space_channels_first
from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan


def _is_numpy_array_space(space: spaces.Space) -> bool:
    """
    Returns False if provided space is not representable as a single numpy array
    (e.g. Dict and Tuple spaces return False)
    """
    return not isinstance(space, (spaces.Dict, spaces.Tuple))


def _check_image_input(observation_space: spaces.Box, key: str = "") -> None:
    """
    Check that the input will be compatible with Stable-Baselines
    when the observation is apparently an image.

    :param observation_space: Observation space
    :param key: When the observation space comes from a Dict space, we pass the
        corresponding key to have more precise warning messages. Defaults to "".
    """
    if observation_space.dtype != np.uint8:
        warnings.warn(
            f"It seems that your observation {key} is an image but its `dtype` "
            f"is ({observation_space.dtype}) whereas it has to be `np.uint8`. "
            "If your observation is not an image, we recommend you to flatten the observation "
            "to have only a 1D vector"
        )

    if np.any(observation_space.low != 0) or np.any(observation_space.high != 255):
        warnings.warn(
            f"It seems that your observation space {key} is an image but the "
            "upper and lower bounds are not in [0, 255]. "
            "Because the CNN policy automatically normalizes the observation, "
            "you may encounter issues if the values are not in that range."
        )

    non_channel_idx = 0
    # Check only if width/height of the image is big enough
    if is_image_space_channels_first(observation_space):
        non_channel_idx = -1

    if observation_space.shape[non_channel_idx] < 36 or observation_space.shape[1] < 36:
        warnings.warn(
            "The minimal resolution for an image is 36x36 for the default `CnnPolicy`. "
            "You might need to use a custom features extractor "
            "cf. https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html"
        )


def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, action_space: spaces.Space) -> None:
    """Emit warnings when the observation space or action space used is not supported by Stable-Baselines."""

    if isinstance(observation_space, spaces.Dict):
        nested_dict = False
        for key, space in observation_space.spaces.items():
            if isinstance(space, spaces.Dict):
                nested_dict = True
            if isinstance(space, spaces.Discrete) and space.start != 0:
                warnings.warn(
                    f"Discrete observation space (key '{key}') with a non-zero start is not supported by Stable-Baselines3. "
                    "You can use a wrapper or update your observation space."
                )

        if nested_dict:
            warnings.warn(
                "Nested observation spaces are not supported by Stable Baselines3 "
                "(Dict spaces inside Dict space). "
                "You should flatten it to have only one level of keys. "
                "For example, `dict(space1=dict(space2=Box(), space3=Box()), spaces4=Discrete())` "
                "is not supported but `dict(space2=Box(), spaces3=Box(), spaces4=Discrete())` is."
            )

    if isinstance(observation_space, spaces.Tuple):
        warnings.warn(
            "The observation space is a Tuple, "
            "this is currently not supported by Stable Baselines3. "
            "However, you can convert it to a Dict observation space "
            "(cf. https://github.com/openai/gym/blob/master/gym/spaces/dict.py), "
            "which is supported by SB3."
        )

    if isinstance(observation_space, spaces.Discrete) and observation_space.start != 0:
        warnings.warn(
            "Discrete observation space with a non-zero start is not supported by Stable-Baselines3. "
            "You can use a wrapper or update your observation space."
        )

    if isinstance(action_space, spaces.Discrete) and action_space.start != 0:
        warnings.warn(
            "Discrete action space with a non-zero start is not supported by Stable-Baselines3. "
            "You can use a wrapper or update your action space."
        )

    if not _is_numpy_array_space(action_space):
        warnings.warn(
            "The action space is not based off a numpy array. Typically this means it's either a Dict or Tuple space. "
            "This type of action space is currently not supported by Stable Baselines 3. You should try to flatten the "
            "action using a wrapper."
        )


def _check_nan(env: gym.Env) -> None:
    """Check for Inf and NaN using the VecWrapper."""
    vec_env = VecCheckNan(DummyVecEnv([lambda: env]))
    for _ in range(10):
        action = np.array([env.action_space.sample()])
        _, _, _, _ = vec_env.step(action)


def _is_goal_env(env: gym.Env) -> bool:
    """
    Check if the env uses the convention for goal-conditioned envs (previously, the gym.GoalEnv interface)
    """
    # We need to unwrap the env since gym.Wrapper has the compute_reward method
    return hasattr(env.unwrapped, "compute_reward")


def _check_goal_env_obs(obs: dict, observation_space: spaces.Dict, method_name: str) -> None:
    """
    Check that an environment implementing the `compute_reward()` method
    (previously known as GoalEnv in gym) contains three elements,
    namely `observation`, `desired_goal`, and `achieved_goal`.
    """
    assert len(observation_space.spaces) == 3, (
        "A goal conditioned env must contain 3 observation keys: `observation`, `desired_goal`, and `achieved_goal`. "
        f"The current observation contains {len(observation_space.spaces)} keys: {list(observation_space.spaces.keys())}"
    )

    for key in ["achieved_goal", "desired_goal"]:
        if key not in observation_space.spaces:
            raise AssertionError(
                f"The observation returned by the `{method_name}()` method of a goal-conditioned env requires the '{key}' "
                "key to be part of the observation dictionary. "
                f"Current keys are {list(observation_space.spaces.keys())}"
            )


def _check_goal_env_compute_reward(
    obs: Dict[str, Union[np.ndarray, int]],
    env: gym.Env,
    reward: float,
    info: Dict[str, Any],
) -> None:
    """
    Check that reward is computed with `compute_reward`
    and that the implementation is vectorized.
    """
    achieved_goal, desired_goal = obs["achieved_goal"], obs["desired_goal"]
    assert reward == env.compute_reward(  # type: ignore[attr-defined]
        achieved_goal, desired_goal, info
    ), "The reward was not computed with `compute_reward()`"

    achieved_goal, desired_goal = np.array(achieved_goal), np.array(desired_goal)
    batch_achieved_goals = np.array([achieved_goal, achieved_goal])
    batch_desired_goals = np.array([desired_goal, desired_goal])
    if isinstance(achieved_goal, int) or len(achieved_goal.shape) == 0:
        batch_achieved_goals = batch_achieved_goals.reshape(2, 1)
        batch_desired_goals = batch_desired_goals.reshape(2, 1)
    batch_infos = np.array([info, info])
    rewards = env.compute_reward(batch_achieved_goals, batch_desired_goals, batch_infos)  # type: ignore[attr-defined]
    assert rewards.shape == (2,), f"Unexpected shape for vectorized computation of reward: {rewards.shape} != (2,)"
    assert rewards[0] == reward, f"Vectorized computation of reward differs from single computation: {rewards[0]} != {reward}"


def _check_obs(obs: Union[tuple, dict, np.ndarray, int], observation_space: spaces.Space, method_name: str) -> None:
    """
    Check that the observation returned by the environment
    corresponds to the declared observation space.
    """
    if not isinstance(observation_space, spaces.Tuple):
        assert not isinstance(
            obs, tuple
        ), f"The observation returned by the `{method_name}()` method should be a single value, not a tuple"

    # The check for a GoalEnv is done by the base class
    if isinstance(observation_space, spaces.Discrete):
        # Since https://github.com/Farama-Foundation/Gymnasium/pull/141,
        # `sample()` will return a np.int64 instead of an int
        assert np.issubdtype(type(obs), np.integer), f"The observation returned by `{method_name}()` method must be an int"
    elif _is_numpy_array_space(observation_space):
        assert isinstance(obs, np.ndarray), f"The observation returned by `{method_name}()` method must be a numpy array"

    # Additional checks for numpy arrays, so the error message is clearer (see GH#1399)
    if isinstance(obs, np.ndarray):
        # check obs dimensions, dtype and bounds
        assert observation_space.shape == obs.shape, (
            f"The observation returned by the `{method_name}()` method does not match the shape "
            f"of the given observation space {observation_space}. "
            f"Expected: {observation_space.shape}, actual shape: {obs.shape}"
        )
        assert np.can_cast(obs.dtype, observation_space.dtype), (
            f"The observation returned by the `{method_name}()` method does not match the data type (cannot cast) "
            f"of the given observation space {observation_space}. "
            f"Expected: {observation_space.dtype}, actual dtype: {obs.dtype}"
        )
        if isinstance(observation_space, spaces.Box):
            assert np.all(obs >= observation_space.low), (
                f"The observation returned by the `{method_name}()` method does not match the lower bound "
                f"of the given observation space {observation_space}. "
                f"Expected: obs >= {np.min(observation_space.low)}, "
                f"actual min value: {np.min(obs)} at index {np.argmin(obs)}"
            )
            assert np.all(obs <= observation_space.high), (
                f"The observation returned by the `{method_name}()` method does not match the upper bound "
                f"of the given observation space {observation_space}. "
                f"Expected: obs <= {np.max(observation_space.high)}, "
                f"actual max value: {np.max(obs)} at index {np.argmax(obs)}"
            )

    assert observation_space.contains(obs), (
        f"The observation returned by the `{method_name}()` method "
        f"does not match the given observation space {observation_space}"
    )


def _check_box_obs(observation_space: spaces.Box, key: str = "") -> None:
    """
    Check that the observation space is correctly formatted
    when dealing with a ``Box()`` space. In particular, it checks:
    - that the dimensions are big enough when it is an image, and that the type matches
    - that the observation has an expected shape (warn the user if not)
    """
    # If image, check the low and high values, the type and the number of channels
    # and the shape (minimal value)
    if len(observation_space.shape) == 3:
        _check_image_input(observation_space, key)

    if len(observation_space.shape) not in [1, 3]:
        warnings.warn(
            f"Your observation {key} has an unconventional shape (neither an image, nor a 1D vector). "
            "We recommend you to flatten the observation "
            "to have only a 1D vector or use a custom policy to properly process the data."
        )


def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action_space: spaces.Space) -> None:
    """
    Check the returned values by the env when calling `.reset()` or `.step()` methods.
    """
    # because env inherits from gymnasium.Env, we assume that `reset()` and `step()` methods exist
    reset_returns = env.reset()
    assert isinstance(reset_returns, tuple), "`reset()` must return a tuple (obs, info)"
    assert len(reset_returns) == 2, f"`reset()` must return a tuple of size 2 (obs, info), not {len(reset_returns)}"
    obs, info = reset_returns
    assert isinstance(info, dict), f"The second element of the tuple returned by `reset()` must be a dictionary, not {info}"

    if _is_goal_env(env):
        # Make mypy happy, already checked
        assert isinstance(observation_space, spaces.Dict)
        _check_goal_env_obs(obs, observation_space, "reset")
    elif isinstance(observation_space, spaces.Dict):
        assert isinstance(obs, dict), "The observation returned by `reset()` must be a dictionary"

        if not obs.keys() == observation_space.spaces.keys():
            raise AssertionError(
                "The observation keys returned by `reset()` must match the observation "
                f"space keys: {obs.keys()} != {observation_space.spaces.keys()}"
            )

        for key in observation_space.spaces.keys():
            try:
                _check_obs(obs[key], observation_space.spaces[key], "reset")
            except AssertionError as e:
                raise AssertionError(f"Error while checking key={key}: " + str(e)) from e
    else:
        _check_obs(obs, observation_space, "reset")

    # Sample a random action
    action = action_space.sample()
    data = env.step(action)

    assert len(data) == 5, (
        "The `step()` method must return five values: "
        f"obs, reward, terminated, truncated, info. Actual: {len(data)} values returned."
    )

    # Unpack
    obs, reward, terminated, truncated, info = data

    if isinstance(observation_space, spaces.Dict):
        assert isinstance(obs, dict), "The observation returned by `step()` must be a dictionary"

        # Additional checks for GoalEnvs
        if _is_goal_env(env):
            # Make mypy happy, already checked
            assert isinstance(observation_space, spaces.Dict)
            _check_goal_env_obs(obs, observation_space, "step")
            _check_goal_env_compute_reward(obs, env, float(reward), info)

        if not obs.keys() == observation_space.spaces.keys():
            raise AssertionError(
                "The observation keys returned by `step()` must match the observation "
                f"space keys: {obs.keys()} != {observation_space.spaces.keys()}"
            )

        for key in observation_space.spaces.keys():
            try:
                _check_obs(obs[key], observation_space.spaces[key], "step")
            except AssertionError as e:
                raise AssertionError(f"Error while checking key={key}: " + str(e)) from e

    else:
        _check_obs(obs, observation_space, "step")

    # We also allow int because the reward will be cast to float
    assert isinstance(reward, (float, int)), "The reward returned by `step()` must be a float"
    assert isinstance(terminated, bool), "The `terminated` signal must be a boolean"
    assert isinstance(truncated, bool), "The `truncated` signal must be a boolean"
    assert isinstance(info, dict), "The `info` returned by `step()` must be a python dictionary"

    # Goal conditioned env
    if _is_goal_env(env):
        # for mypy, env.unwrapped was checked by _is_goal_env()
        assert hasattr(env, "compute_reward")
        assert reward == env.compute_reward(obs["achieved_goal"], obs["desired_goal"], info)


def _check_spaces(env: gym.Env) -> None:
    """
    Check that the observation and action spaces are defined and inherit from spaces.Space. For
    envs that follow the goal-conditioned standard (previously, the gym.GoalEnv interface) we check
    the observation space is gym.spaces.Dict
    """
    # Helper to link to the code, because gym has no proper documentation
    gym_spaces = " cf https://github.com/openai/gym/blob/master/gym/spaces/"

    assert hasattr(env, "observation_space"), "You must specify an observation space (cf gym.spaces)" + gym_spaces
    assert hasattr(env, "action_space"), "You must specify an action space (cf gym.spaces)" + gym_spaces

    assert isinstance(env.observation_space, spaces.Space), (
        "The observation space must inherit from gymnasium.spaces" + gym_spaces
    )
    assert isinstance(env.action_space, spaces.Space), "The action space must inherit from gymnasium.spaces" + gym_spaces

    if _is_goal_env(env):
        assert isinstance(
            env.observation_space, spaces.Dict
        ), "Goal conditioned envs (previously gym.GoalEnv) require the observation space to be gym.spaces.Dict"


# Check render cannot be covered by CI
def _check_render(env: gym.Env, warn: bool = False) -> None:  # pragma: no cover
    """
    Check the instantiated render mode (if any) by calling the `render()`/`close()`
    method of the environment.

    :param env: The environment to check
    :param warn: Whether to output additional warnings
    """
    render_modes = env.metadata.get("render_modes")
    if render_modes is None:
        if warn:
            warnings.warn(
                "No render modes were declared in the environment "
                "(env.metadata['render_modes'] is None or not defined), "
                "you may have trouble when calling `.render()`"
            )

    # Only check current render mode
    if env.render_mode:
        env.render()
    env.close()


def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) -> None:
    """
    Check that an environment follows the Gym API.
    This is particularly useful when using a custom environment.
    Please take a look at https://github.com/openai/gym/blob/master/gym/core.py
    for more information about the API.

    It also optionally checks that the environment is compatible with Stable-Baselines.

    :param env: The Gym environment that will be checked
    :param warn: Whether to output additional warnings
        mainly related to the interaction with Stable Baselines
    :param skip_render_check: Whether to skip the checks for the render method.
        True by default (useful for the CI)
    """
    assert isinstance(
        env, gym.Env
    ), "Your environment must inherit from the gym.Env class cf https://github.com/openai/gym/blob/master/gym/core.py"

    # ============= Check the spaces (observation and action) ================
    _check_spaces(env)

    # Define aliases for convenience
    observation_space = env.observation_space
    action_space = env.action_space

    # Warn the user if needed.
    # A warning means that the environment may run but not work properly with Stable Baselines algorithms
    if warn:
        _check_unsupported_spaces(env, observation_space, action_space)

        obs_spaces = observation_space.spaces if isinstance(observation_space, spaces.Dict) else {"": observation_space}
        for key, space in obs_spaces.items():
            if isinstance(space, spaces.Box):
                _check_box_obs(space, key)

        # Check for the action space, it may lead to hard-to-debug issues
        if isinstance(action_space, spaces.Box) and (
            np.any(np.abs(action_space.low) != np.abs(action_space.high))
            or np.any(action_space.low != -1)
            or np.any(action_space.high != 1)
        ):
            warnings.warn(
                "We recommend you to use a symmetric and normalized Box action space (range=[-1, 1]) "
                "cf https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html"
            )

        if isinstance(action_space, spaces.Box):
            assert np.all(
                np.isfinite(np.array([action_space.low, action_space.high]))
            ), "Continuous action space must have a finite lower and upper bound"

        if isinstance(action_space, spaces.Box) and action_space.dtype != np.dtype(np.float32):
            warnings.warn(
                f"Your action space has dtype {action_space.dtype}, we recommend using np.float32 to avoid cast errors."
            )

    # ============ Check the returned values ===============
    _check_returned_values(env, observation_space, action_space)

    # ==== Check the render method and the declared render modes ====
    if not skip_render_check:
        _check_render(env, warn)  # pragma: no cover

    try:
        check_for_nested_spaces(env.observation_space)
        # The check doesn't support nested observations/dict actions
        # A warning about it has already been emitted
        _check_nan(env)
    except NotImplementedError:
        pass
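

# ---------------------------------------------------------------------------
# Illustrative sketch only, NOT part of the original module.
# The type checks that `_check_returned_values()` applies to the value returned
# by `step()` can be restated with the standard library alone. The helper name
# `_example_validate_step_return` and any values passed to it are hypothetical;
# unlike the real checker, this sketch does not validate the observation
# against an observation space.
# ---------------------------------------------------------------------------
def _example_validate_step_return(data: tuple) -> None:
    """Assert that ``data`` looks like ``(obs, reward, terminated, truncated, info)``."""
    assert len(data) == 5, f"`step()` must return five values, got {len(data)}"
    _obs, reward, terminated, truncated, info = data
    # As in the checker above, an int reward is accepted because it can be cast to float
    assert isinstance(reward, (float, int)), "The reward returned by `step()` must be a float"
    assert isinstance(terminated, bool), "The `terminated` signal must be a boolean"
    assert isinstance(truncated, bool), "The `truncated` signal must be a boolean"
    assert isinstance(info, dict), "The `info` returned by `step()` must be a python dictionary"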