Add has_attr for VecEnv (#2077)

* Add `has_attr` for `VecEnv`

* Add special case for gymnasium<1.0

* Update changelog.rst

* Update black version
This commit is contained in:
Antonin RAFFIN 2025-02-03 10:43:56 +01:00 committed by GitHub
parent ee8a77defb
commit b8b2d30a83
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 89 additions and 15 deletions

View file

@ -3,6 +3,41 @@
Changelog
==========
Release 2.6.0a0 (WIP)
--------------------------
Breaking Changes:
^^^^^^^^^^^^^^^^^
New Features:
^^^^^^^^^^^^^
- Added ``has_attr`` method for ``VecEnv`` to check if an attribute exists
Bug Fixes:
^^^^^^^^^^
- `SubProcVecEnv` will now exit gracefully (without big traceback) when using `KeyboardInterrupt`
`SB3-Contrib`_
^^^^^^^^^^^^^^
`RL Zoo`_
^^^^^^^^^
`SBX`_ (SB3 + Jax)
^^^^^^^^^^^^^^^^^^
Deprecations:
^^^^^^^^^^^^^
Others:
^^^^^^^
- Updated black from v24 to v25
Documentation:
^^^^^^^^^^^^^^
Release 2.5.0 (2025-01-27)
--------------------------
@ -19,23 +54,11 @@ New Features:
- Added support for NumPy v2.0: ``VecNormalize`` now cast normalized rewards to float32, updated bit flipping env to avoid overflow issues too
- Added official support for Python 3.12
Bug Fixes:
^^^^^^^^^^
`SB3-Contrib`_
^^^^^^^^^^^^^^
`RL Zoo`_
^^^^^^^^^
`SBX`_ (SB3 + Jax)
^^^^^^^^^^^^^^^^^^
- Added SimBa Policy: Simplicity Bias for Scaling Up Parameters in DRL
- Added support for parameter resets
Deprecations:
^^^^^^^^^^^^^
Others:
^^^^^^^
- Updated Dockerfile

View file

@ -98,7 +98,7 @@ setup(
# Lint code and sort imports (flake8 and isort replacement)
"ruff>=0.3.1",
# Reformat
"black>=24.2.0,<25",
"black>=25.1.0,<26",
],
"docs": [
"sphinx>=5,<9",

View file

@ -147,6 +147,21 @@ class VecEnv(ABC):
"""
raise NotImplementedError()
def has_attr(self, attr_name: str) -> bool:
"""
Check if an attribute exists for a vectorized environment.
:param attr_name: The name of the attribute to check
:return: True if 'attr_name' exists in all environments
"""
# Default implementation, will not work with things that cannot be pickled:
# https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues/49
try:
self.get_attr(attr_name)
return True
except AttributeError:
return False
@abstractmethod
def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]:
"""
@ -392,6 +407,9 @@ class VecEnvWrapper(VecEnv):
def get_images(self) -> Sequence[Optional[np.ndarray]]:
return self.venv.get_images()
def has_attr(self, attr_name: str) -> bool:
return self.venv.has_attr(attr_name)
def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]:
return self.venv.get_attr(attr_name, indices)

View file

@ -17,7 +17,7 @@ from stable_baselines3.common.vec_env.base_vec_env import (
from stable_baselines3.common.vec_env.patch_gym import _patch_env
def _worker(
def _worker( # noqa: C901
remote: mp.connection.Connection,
parent_remote: mp.connection.Connection,
env_fn_wrapper: CloudpickleWrapper,
@ -58,6 +58,12 @@ def _worker(
remote.send(method(*data[1], **data[2]))
elif cmd == "get_attr":
remote.send(env.get_wrapper_attr(data))
elif cmd == "has_attr":
try:
env.get_wrapper_attr(data)
remote.send(True)
except AttributeError:
remote.send(False)
elif cmd == "set_attr":
remote.send(setattr(env, data[0], data[1])) # type: ignore[func-returns-value]
elif cmd == "is_wrapped":
@ -66,6 +72,8 @@ def _worker(
raise NotImplementedError(f"`{cmd}` is not implemented in the worker")
except EOFError:
break
except KeyboardInterrupt:
break
class SubprocVecEnv(VecEnv):
@ -165,6 +173,13 @@ class SubprocVecEnv(VecEnv):
outputs = [pipe.recv() for pipe in self.remotes]
return outputs
def has_attr(self, attr_name: str) -> bool:
"""Check if an attribute exists for a vectorized environment. (see base class)."""
target_remotes = self._get_target_remotes(indices=None)
for remote in target_remotes:
remote.send(("has_attr", attr_name))
return all([remote.recv() for remote in target_remotes])
def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]:
"""Return attribute from vectorized environment (see base class)."""
target_remotes = self._get_target_remotes(indices)

View file

@ -1 +1 @@
2.5.0
2.6.0a0

View file

@ -123,12 +123,30 @@ def test_vecenv_custom_calls(vec_env_class, vec_env_wrapper):
# we need a X server to test the "human" mode (uses OpenCV)
# vec_env.render(mode="human")
# Set a new attribute, on the last wrapper and on the env
assert not vec_env.has_attr("dummy")
# Set value for the last wrapper only
vec_env.set_attr("dummy", 12)
assert vec_env.get_attr("dummy") == [12] * N_ENVS
if vec_env_class == DummyVecEnv:
assert vec_env.envs[0].dummy == 12
assert not vec_env.has_attr("dummy2")
# Set the value on the original env
# `set_wrapper_attr` doesn't exist before v1.0
if gym.__version__ > "1":
vec_env.env_method("set_wrapper_attr", "dummy2", 2)
assert vec_env.get_attr("dummy2") == [2] * N_ENVS
if vec_env_class == DummyVecEnv:
assert vec_env.envs[0].unwrapped.dummy2 == 2
env_method_results = vec_env.env_method("custom_method", 1, indices=None, dim_1=2)
setattr_results = []
# Set current_step to an arbitrary value
for env_idx in range(N_ENVS):
setattr_results.append(vec_env.set_attr("current_step", env_idx, indices=env_idx))
# Retrieve the value for each environment
assert vec_env.has_attr("current_step")
getattr_results = vec_env.get_attr("current_step")
assert len(env_method_results) == N_ENVS