mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-14 20:58:03 +00:00
Add has_attr for VecEnv (#2077)
* Add `has_attr` for `VecEnv` * Add special case for gymnasium<1.0 * Update changelog.rst * Update black version
This commit is contained in:
parent
ee8a77defb
commit
b8b2d30a83
6 changed files with 89 additions and 15 deletions
|
|
@ -3,6 +3,41 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Release 2.6.0a0 (WIP)
|
||||
--------------------------
|
||||
|
||||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
- Added ``has_attr`` method for ``VecEnv`` to check if an attribute exists
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- `SubProcVecEnv` will now exit gracefully (without big traceback) when using `KeyboardInterrupt`
|
||||
|
||||
`SB3-Contrib`_
|
||||
^^^^^^^^^^^^^^
|
||||
|
||||
`RL Zoo`_
|
||||
^^^^^^^^^
|
||||
|
||||
`SBX`_ (SB3 + Jax)
|
||||
^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
Others:
|
||||
^^^^^^^
|
||||
- Updated black from v24 to v25
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
|
||||
|
||||
Release 2.5.0 (2025-01-27)
|
||||
--------------------------
|
||||
|
||||
|
|
@ -19,23 +54,11 @@ New Features:
|
|||
- Added support for NumPy v2.0: ``VecNormalize`` now cast normalized rewards to float32, updated bit flipping env to avoid overflow issues too
|
||||
- Added official support for Python 3.12
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
|
||||
`SB3-Contrib`_
|
||||
^^^^^^^^^^^^^^
|
||||
|
||||
`RL Zoo`_
|
||||
^^^^^^^^^
|
||||
|
||||
`SBX`_ (SB3 + Jax)
|
||||
^^^^^^^^^^^^^^^^^^
|
||||
- Added SimBa Policy: Simplicity Bias for Scaling Up Parameters in DRL
|
||||
- Added support for parameter resets
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
Others:
|
||||
^^^^^^^
|
||||
- Updated Dockerfile
|
||||
|
|
|
|||
2
setup.py
2
setup.py
|
|
@ -98,7 +98,7 @@ setup(
|
|||
# Lint code and sort imports (flake8 and isort replacement)
|
||||
"ruff>=0.3.1",
|
||||
# Reformat
|
||||
"black>=24.2.0,<25",
|
||||
"black>=25.1.0,<26",
|
||||
],
|
||||
"docs": [
|
||||
"sphinx>=5,<9",
|
||||
|
|
|
|||
|
|
@ -147,6 +147,21 @@ class VecEnv(ABC):
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def has_attr(self, attr_name: str) -> bool:
|
||||
"""
|
||||
Check if an attribute exists for a vectorized environment.
|
||||
|
||||
:param attr_name: The name of the attribute to check
|
||||
:return: True if 'attr_name' exists in all environments
|
||||
"""
|
||||
# Default implementation, will not work with things that cannot be pickled:
|
||||
# https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues/49
|
||||
try:
|
||||
self.get_attr(attr_name)
|
||||
return True
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]:
|
||||
"""
|
||||
|
|
@ -392,6 +407,9 @@ class VecEnvWrapper(VecEnv):
|
|||
def get_images(self) -> Sequence[Optional[np.ndarray]]:
|
||||
return self.venv.get_images()
|
||||
|
||||
def has_attr(self, attr_name: str) -> bool:
|
||||
return self.venv.has_attr(attr_name)
|
||||
|
||||
def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]:
|
||||
return self.venv.get_attr(attr_name, indices)
|
||||
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from stable_baselines3.common.vec_env.base_vec_env import (
|
|||
from stable_baselines3.common.vec_env.patch_gym import _patch_env
|
||||
|
||||
|
||||
def _worker(
|
||||
def _worker( # noqa: C901
|
||||
remote: mp.connection.Connection,
|
||||
parent_remote: mp.connection.Connection,
|
||||
env_fn_wrapper: CloudpickleWrapper,
|
||||
|
|
@ -58,6 +58,12 @@ def _worker(
|
|||
remote.send(method(*data[1], **data[2]))
|
||||
elif cmd == "get_attr":
|
||||
remote.send(env.get_wrapper_attr(data))
|
||||
elif cmd == "has_attr":
|
||||
try:
|
||||
env.get_wrapper_attr(data)
|
||||
remote.send(True)
|
||||
except AttributeError:
|
||||
remote.send(False)
|
||||
elif cmd == "set_attr":
|
||||
remote.send(setattr(env, data[0], data[1])) # type: ignore[func-returns-value]
|
||||
elif cmd == "is_wrapped":
|
||||
|
|
@ -66,6 +72,8 @@ def _worker(
|
|||
raise NotImplementedError(f"`{cmd}` is not implemented in the worker")
|
||||
except EOFError:
|
||||
break
|
||||
except KeyboardInterrupt:
|
||||
break
|
||||
|
||||
|
||||
class SubprocVecEnv(VecEnv):
|
||||
|
|
@ -165,6 +173,13 @@ class SubprocVecEnv(VecEnv):
|
|||
outputs = [pipe.recv() for pipe in self.remotes]
|
||||
return outputs
|
||||
|
||||
def has_attr(self, attr_name: str) -> bool:
|
||||
"""Check if an attribute exists for a vectorized environment. (see base class)."""
|
||||
target_remotes = self._get_target_remotes(indices=None)
|
||||
for remote in target_remotes:
|
||||
remote.send(("has_attr", attr_name))
|
||||
return all([remote.recv() for remote in target_remotes])
|
||||
|
||||
def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]:
|
||||
"""Return attribute from vectorized environment (see base class)."""
|
||||
target_remotes = self._get_target_remotes(indices)
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
2.5.0
|
||||
2.6.0a0
|
||||
|
|
|
|||
|
|
@ -123,12 +123,30 @@ def test_vecenv_custom_calls(vec_env_class, vec_env_wrapper):
|
|||
# we need a X server to test the "human" mode (uses OpenCV)
|
||||
# vec_env.render(mode="human")
|
||||
|
||||
# Set a new attribute, on the last wrapper and on the env
|
||||
assert not vec_env.has_attr("dummy")
|
||||
# Set value for the last wrapper only
|
||||
vec_env.set_attr("dummy", 12)
|
||||
assert vec_env.get_attr("dummy") == [12] * N_ENVS
|
||||
if vec_env_class == DummyVecEnv:
|
||||
assert vec_env.envs[0].dummy == 12
|
||||
|
||||
assert not vec_env.has_attr("dummy2")
|
||||
# Set the value on the original env
|
||||
# `set_wrapper_attr` doesn't exist before v1.0
|
||||
if gym.__version__ > "1":
|
||||
vec_env.env_method("set_wrapper_attr", "dummy2", 2)
|
||||
assert vec_env.get_attr("dummy2") == [2] * N_ENVS
|
||||
if vec_env_class == DummyVecEnv:
|
||||
assert vec_env.envs[0].unwrapped.dummy2 == 2
|
||||
|
||||
env_method_results = vec_env.env_method("custom_method", 1, indices=None, dim_1=2)
|
||||
setattr_results = []
|
||||
# Set current_step to an arbitrary value
|
||||
for env_idx in range(N_ENVS):
|
||||
setattr_results.append(vec_env.set_attr("current_step", env_idx, indices=env_idx))
|
||||
# Retrieve the value for each environment
|
||||
assert vec_env.has_attr("current_step")
|
||||
getattr_results = vec_env.get_attr("current_step")
|
||||
|
||||
assert len(env_method_results) == N_ENVS
|
||||
|
|
|
|||
Loading…
Reference in a new issue