Add support for dict/tuple obs space for VecCheckNaN (#1348)

* Add support for dict/tuple obs space for VecCheckNaN

* Handle list too

* Address comments from code review

* Ignore B028 (explicit stack level)
This commit is contained in:
Antonin RAFFIN 2023-02-27 13:45:17 +01:00 committed by GitHub
parent 085bdd5a68
commit ed8783cb73
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 45 additions and 14 deletions

View file

@ -4,7 +4,7 @@ Changelog
==========
Release 1.8.0a6 (WIP)
Release 1.8.0a7 (WIP)
--------------------------
@ -18,6 +18,7 @@ New Features:
^^^^^^^^^^^^^
- Added ``repeat_action_probability`` argument in ``AtariWrapper``.
- Only use ``NoopResetEnv`` and ``MaxAndSkipEnv`` when needed in ``AtariWrapper``
- Added support for dict/tuple observations spaces for ``VecCheckNan``, the check is now active in the ``env_checker()`` (@DavyMorgan)
`SB3-Contrib`_
^^^^^^^^^^^^^^
@ -1230,4 +1231,4 @@ And all the contributors:
@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede
@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 @yuanmingqi
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan

View file

@ -70,7 +70,8 @@ exclude = (?x)(
[flake8]
# line breaks before and after binary operators
ignore = W503,W504,E203,E231
# ignore explicit stack level
ignore = W503,W504,E203,E231,B028
# Ignore import not used when aliases are defined
per-file-ignores =
# Default implementation in abstract methods

View file

@ -5,7 +5,7 @@ import gym
import numpy as np
from gym import spaces
from stable_baselines3.common.preprocessing import is_image_space_channels_first
from stable_baselines3.common.preprocessing import check_for_nested_spaces, is_image_space_channels_first
from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan
@ -380,6 +380,10 @@ def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) -
if not skip_render_check:
_check_render(env, warn=warn) # pragma: no cover
# The check only works with numpy arrays
if _is_numpy_array_space(observation_space) and _is_numpy_array_space(action_space):
try:
check_for_nested_spaces(env.observation_space)
# The check doesn't support nested observations/dict actions
# A warning about it has already been emitted
_check_nan(env)
except NotImplementedError:
pass

View file

@ -1,6 +1,8 @@
import warnings
from typing import List, Tuple
import numpy as np
from gym import spaces
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper
@ -26,6 +28,8 @@ class VecCheckNan(VecEnvWrapper):
self._actions: np.ndarray
self._observations: VecEnvObs
if isinstance(venv.action_space, spaces.Dict):
raise NotImplementedError("VecCheckNan doesn't support dict action spaces")
def step_async(self, actions: np.ndarray) -> None:
self._check_val(event="step_async", actions=actions)
@ -44,19 +48,40 @@ class VecCheckNan(VecEnvWrapper):
self._observations = observations
return observations
def check_array_value(self, name: str, value: np.ndarray) -> List[Tuple[str, str]]:
"""
Check for inf and NaN for a single numpy array.
:param name: Name of the value being check
:param value: Value (numpy array) to check
:return: A list of issues found.
"""
found = []
has_nan = np.any(np.isnan(value))
has_inf = self.check_inf and np.any(np.isinf(value))
if has_inf:
found.append((name, "inf"))
if has_nan:
found.append((name, "nan"))
return found
def _check_val(self, event: str, **kwargs) -> None:
# if warn and warn once and have warned once: then stop checking
if not self.raise_exception and self.warn_once and self._user_warned:
return
found = []
for name, val in kwargs.items():
has_nan = np.any(np.isnan(val))
has_inf = self.check_inf and np.any(np.isinf(val))
if has_inf:
found.append((name, "inf"))
if has_nan:
found.append((name, "nan"))
for name, value in kwargs.items():
if isinstance(value, (np.ndarray, list)):
found += self.check_array_value(name, np.asarray(value))
elif isinstance(value, dict):
for inner_name, inner_val in value.items():
found += self.check_array_value(f"{name}.{inner_name}", inner_val)
elif isinstance(value, tuple):
for idx, inner_val in enumerate(value):
found += self.check_array_value(f"{name}.{idx}", inner_val)
else:
raise TypeError(f"Unsupported observation type {type(value)}.")
if found:
self._user_warned = True

View file

@ -1 +1 @@
1.8.0a6
1.8.0a7