diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index c2449f0..9e2934e 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== -Release 1.8.0a6 (WIP) +Release 1.8.0a7 (WIP) -------------------------- @@ -18,6 +18,7 @@ New Features: ^^^^^^^^^^^^^ - Added ``repeat_action_probability`` argument in ``AtariWrapper``. - Only use ``NoopResetEnv`` and ``MaxAndSkipEnv`` when needed in ``AtariWrapper`` +- Added support for dict/tuple observation spaces for ``VecCheckNan``, the check is now active in the ``env_checker()`` (@DavyMorgan) `SB3-Contrib`_ ^^^^^^^^^^^^^^ @@ -1230,4 +1231,4 @@ And all the contributors: @Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede @Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 @yuanmingqi @anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong -@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini +@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan diff --git a/setup.cfg b/setup.cfg index 3698ded..11bb464 100644 --- a/setup.cfg +++ b/setup.cfg @@ -70,7 +70,8 @@ exclude = (?x)( [flake8] # line breaks before and after binary operators -ignore = W503,W504,E203,E231 +# B028: do not require an explicit stacklevel in warnings.warn() calls +ignore = W503,W504,E203,E231,B028 # Ignore import not used when aliases are defined per-file-ignores = # Default implementation in abstract methods diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index cb682ca..ce01f2e 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -5,7 +5,7 @@ import gym import numpy as np from gym import spaces -from stable_baselines3.common.preprocessing import is_image_space_channels_first +from stable_baselines3.common.preprocessing import check_for_nested_spaces, 
is_image_space_channels_first from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan @@ -380,6 +380,10 @@ def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) - if not skip_render_check: _check_render(env, warn=warn) # pragma: no cover - # The check only works with numpy arrays - if _is_numpy_array_space(observation_space) and _is_numpy_array_space(action_space): + try: + check_for_nested_spaces(env.observation_space) + # The check doesn't support nested observations/dict actions + # A warning about it has already been emitted _check_nan(env) + except NotImplementedError: + pass diff --git a/stable_baselines3/common/vec_env/vec_check_nan.py b/stable_baselines3/common/vec_env/vec_check_nan.py index a6409bc..98ad217 100644 --- a/stable_baselines3/common/vec_env/vec_check_nan.py +++ b/stable_baselines3/common/vec_env/vec_check_nan.py @@ -1,6 +1,8 @@ import warnings +from typing import List, Tuple import numpy as np +from gym import spaces from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper @@ -26,6 +28,8 @@ class VecCheckNan(VecEnvWrapper): self._actions: np.ndarray self._observations: VecEnvObs + if isinstance(venv.action_space, spaces.Dict): + raise NotImplementedError("VecCheckNan doesn't support dict action spaces") def step_async(self, actions: np.ndarray) -> None: self._check_val(event="step_async", actions=actions) @@ -44,19 +48,40 @@ class VecCheckNan(VecEnvWrapper): self._observations = observations return observations + def check_array_value(self, name: str, value: np.ndarray) -> List[Tuple[str, str]]: + """ + Check for inf and NaN for a single numpy array. + + :param name: Name of the value being checked + :param value: Value (numpy array) to check + :return: A list of issues found. 
+ """ + found = [] + has_nan = np.any(np.isnan(value)) + has_inf = self.check_inf and np.any(np.isinf(value)) + if has_inf: + found.append((name, "inf")) + if has_nan: + found.append((name, "nan")) + return found + def _check_val(self, event: str, **kwargs) -> None: # if warn and warn once and have warned once: then stop checking if not self.raise_exception and self.warn_once and self._user_warned: return found = [] - for name, val in kwargs.items(): - has_nan = np.any(np.isnan(val)) - has_inf = self.check_inf and np.any(np.isinf(val)) - if has_inf: - found.append((name, "inf")) - if has_nan: - found.append((name, "nan")) + for name, value in kwargs.items(): + if isinstance(value, (np.ndarray, list)): + found += self.check_array_value(name, np.asarray(value)) + elif isinstance(value, dict): + for inner_name, inner_val in value.items(): + found += self.check_array_value(f"{name}.{inner_name}", inner_val) + elif isinstance(value, tuple): + for idx, inner_val in enumerate(value): + found += self.check_array_value(f"{name}.{idx}", inner_val) + else: + raise TypeError(f"Unsupported observation type {type(value)}.") if found: self._user_warned = True diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 838407c..61ecd05 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.8.0a6 +1.8.0a7