mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-17 21:20:11 +00:00
Add support for dict/tuple obs space for VecCheckNaN (#1348)
* Add support for dict/tuple obs space for VecCheckNaN
* Handle list too
* Address comments from code review
* Ignore B028 (explicit stack level)
This commit is contained in:
parent
085bdd5a68
commit
ed8783cb73
5 changed files with 45 additions and 14 deletions
|
|
@ -4,7 +4,7 @@ Changelog
|
|||
==========
|
||||
|
||||
|
||||
Release 1.8.0a6 (WIP)
|
||||
Release 1.8.0a7 (WIP)
|
||||
--------------------------
|
||||
|
||||
|
||||
|
|
@ -18,6 +18,7 @@ New Features:
|
|||
^^^^^^^^^^^^^
|
||||
- Added ``repeat_action_probability`` argument in ``AtariWrapper``.
|
||||
- Only use ``NoopResetEnv`` and ``MaxAndSkipEnv`` when needed in ``AtariWrapper``
|
||||
- Added support for dict/tuple observation spaces for ``VecCheckNan``; the check is now active in the ``env_checker()`` (@DavyMorgan)
|
||||
|
||||
`SB3-Contrib`_
|
||||
^^^^^^^^^^^^^^
|
||||
|
|
@ -1230,4 +1231,4 @@ And all the contributors:
|
|||
@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede
|
||||
@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 @yuanmingqi
|
||||
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong
|
||||
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini
|
||||
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan
|
||||
|
|
|
|||
|
|
@ -70,7 +70,8 @@ exclude = (?x)(
|
|||
|
||||
[flake8]
|
||||
# line breaks before and after binary operators
|
||||
ignore = W503,W504,E203,E231
|
||||
# ignore explicit stack level
|
||||
ignore = W503,W504,E203,E231,B028
|
||||
# Ignore import not used when aliases are defined
|
||||
per-file-ignores =
|
||||
# Default implementation in abstract methods
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import gym
|
|||
import numpy as np
|
||||
from gym import spaces
|
||||
|
||||
from stable_baselines3.common.preprocessing import is_image_space_channels_first
|
||||
from stable_baselines3.common.preprocessing import check_for_nested_spaces, is_image_space_channels_first
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan
|
||||
|
||||
|
||||
|
|
@ -380,6 +380,10 @@ def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) -
|
|||
if not skip_render_check:
|
||||
_check_render(env, warn=warn) # pragma: no cover
|
||||
|
||||
# The check only works with numpy arrays
|
||||
if _is_numpy_array_space(observation_space) and _is_numpy_array_space(action_space):
|
||||
try:
|
||||
check_for_nested_spaces(env.observation_space)
|
||||
# The check doesn't support nested observations/dict actions
|
||||
# A warning about it has already been emitted
|
||||
_check_nan(env)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
import warnings
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
|
||||
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper
|
||||
|
||||
|
|
@ -26,6 +28,8 @@ class VecCheckNan(VecEnvWrapper):
|
|||
|
||||
self._actions: np.ndarray
|
||||
self._observations: VecEnvObs
|
||||
if isinstance(venv.action_space, spaces.Dict):
|
||||
raise NotImplementedError("VecCheckNan doesn't support dict action spaces")
|
||||
|
||||
def step_async(self, actions: np.ndarray) -> None:
|
||||
self._check_val(event="step_async", actions=actions)
|
||||
|
|
@ -44,19 +48,40 @@ class VecCheckNan(VecEnvWrapper):
|
|||
self._observations = observations
|
||||
return observations
|
||||
|
||||
def check_array_value(self, name: str, value: np.ndarray) -> List[Tuple[str, str]]:
    """
    Check for inf and NaN for a single numpy array.

    :param name: Name of the value being checked
    :param value: Value (numpy array, or array-like convertible to one) to check
    :return: A list of issues found, as ``(name, kind)`` tuples where ``kind``
        is ``"inf"`` or ``"nan"``. The list is empty when nothing was found.
    """
    # Convert once so that both checks below operate on the same ndarray,
    # whether the caller passed an array or an array-like (e.g. a list).
    array = np.asarray(value)
    found: List[Tuple[str, str]] = []
    # inf is only reported when ``self.check_inf`` is enabled
    if self.check_inf and np.any(np.isinf(array)):
        found.append((name, "inf"))
    # NaN is always reported
    if np.any(np.isnan(array)):
        found.append((name, "nan"))
    return found
|
||||
|
||||
def _check_val(self, event: str, **kwargs) -> None:
|
||||
# if warn and warn once and have warned once: then stop checking
|
||||
if not self.raise_exception and self.warn_once and self._user_warned:
|
||||
return
|
||||
|
||||
found = []
|
||||
for name, val in kwargs.items():
|
||||
has_nan = np.any(np.isnan(val))
|
||||
has_inf = self.check_inf and np.any(np.isinf(val))
|
||||
if has_inf:
|
||||
found.append((name, "inf"))
|
||||
if has_nan:
|
||||
found.append((name, "nan"))
|
||||
for name, value in kwargs.items():
|
||||
if isinstance(value, (np.ndarray, list)):
|
||||
found += self.check_array_value(name, np.asarray(value))
|
||||
elif isinstance(value, dict):
|
||||
for inner_name, inner_val in value.items():
|
||||
found += self.check_array_value(f"{name}.{inner_name}", inner_val)
|
||||
elif isinstance(value, tuple):
|
||||
for idx, inner_val in enumerate(value):
|
||||
found += self.check_array_value(f"{name}.{idx}", inner_val)
|
||||
else:
|
||||
raise TypeError(f"Unsupported observation type {type(value)}.")
|
||||
|
||||
if found:
|
||||
self._user_warned = True
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.8.0a6
|
||||
1.8.0a7
|
||||
|
|
|
|||
Loading…
Reference in a new issue