diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index cf0db00..c7967a1 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,6 +3,41 @@ Changelog ========== +Release 2.6.0a0 (WIP) +-------------------------- + + +Breaking Changes: +^^^^^^^^^^^^^^^^^ + +New Features: +^^^^^^^^^^^^^ +- Added ``has_attr`` method for ``VecEnv`` to check if an attribute exists + +Bug Fixes: +^^^^^^^^^^ +- ``SubprocVecEnv`` will now exit gracefully (without a big traceback) when using ``KeyboardInterrupt`` + +`SB3-Contrib`_ +^^^^^^^^^^^^^^ + +`RL Zoo`_ +^^^^^^^^^ + +`SBX`_ (SB3 + Jax) +^^^^^^^^^^^^^^^^^^ + +Deprecations: +^^^^^^^^^^^^^ + +Others: +^^^^^^^ +- Updated black from v24 to v25 + +Documentation: +^^^^^^^^^^^^^^ + + Release 2.5.0 (2025-01-27) -------------------------- @@ -19,23 +54,11 @@ New Features: - Added support for NumPy v2.0: ``VecNormalize`` now cast normalized rewards to float32, updated bit flipping env to avoid overflow issues too - Added official support for Python 3.12 -Bug Fixes: -^^^^^^^^^^ - -`SB3-Contrib`_ -^^^^^^^^^^^^^^ - -`RL Zoo`_ -^^^^^^^^^ - `SBX`_ (SB3 + Jax) ^^^^^^^^^^^^^^^^^^ - Added SimBa Policy: Simplicity Bias for Scaling Up Parameters in DRL - Added support for parameter resets -Deprecations: -^^^^^^^^^^^^^ - Others: ^^^^^^^ - Updated Dockerfile diff --git a/setup.py b/setup.py index fa24fc8..8123cf4 100644 --- a/setup.py +++ b/setup.py @@ -98,7 +98,7 @@ setup( # Lint code and sort imports (flake8 and isort replacement) "ruff>=0.3.1", # Reformat - "black>=24.2.0,<25", + "black>=25.1.0,<26", ], "docs": [ "sphinx>=5,<9", diff --git a/stable_baselines3/common/vec_env/base_vec_env.py b/stable_baselines3/common/vec_env/base_vec_env.py index 71ee15e..3701131 100644 --- a/stable_baselines3/common/vec_env/base_vec_env.py +++ b/stable_baselines3/common/vec_env/base_vec_env.py @@ -147,6 +147,21 @@ class VecEnv(ABC): """ raise NotImplementedError() + def has_attr(self, attr_name: str) -> bool: + """ + Check if an attribute exists 
for a vectorized environment. + + :param attr_name: The name of the attribute to check + :return: True if 'attr_name' exists in all environments + """ + # Default implementation, will not work with things that cannot be pickled: + # https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues/49 + try: + self.get_attr(attr_name) + return True + except AttributeError: + return False + @abstractmethod def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]: """ @@ -392,6 +407,9 @@ class VecEnvWrapper(VecEnv): def get_images(self) -> Sequence[Optional[np.ndarray]]: return self.venv.get_images() + def has_attr(self, attr_name: str) -> bool: + return self.venv.has_attr(attr_name) + def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]: return self.venv.get_attr(attr_name, indices) diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index 225eadd..1563d70 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -17,7 +17,7 @@ from stable_baselines3.common.vec_env.base_vec_env import ( from stable_baselines3.common.vec_env.patch_gym import _patch_env -def _worker( +def _worker( # noqa: C901 remote: mp.connection.Connection, parent_remote: mp.connection.Connection, env_fn_wrapper: CloudpickleWrapper, @@ -58,6 +58,12 @@ def _worker( remote.send(method(*data[1], **data[2])) elif cmd == "get_attr": remote.send(env.get_wrapper_attr(data)) + elif cmd == "has_attr": + try: + env.get_wrapper_attr(data) + remote.send(True) + except AttributeError: + remote.send(False) elif cmd == "set_attr": remote.send(setattr(env, data[0], data[1])) # type: ignore[func-returns-value] elif cmd == "is_wrapped": @@ -66,6 +72,8 @@ def _worker( raise NotImplementedError(f"`{cmd}` is not implemented in the worker") except EOFError: break + except KeyboardInterrupt: + break class SubprocVecEnv(VecEnv): @@ 
-165,6 +173,13 @@ class SubprocVecEnv(VecEnv): outputs = [pipe.recv() for pipe in self.remotes] return outputs + def has_attr(self, attr_name: str) -> bool: + """Check if an attribute exists for a vectorized environment. (see base class).""" + target_remotes = self._get_target_remotes(indices=None) + for remote in target_remotes: + remote.send(("has_attr", attr_name)) + return all([remote.recv() for remote in target_remotes]) + def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]: """Return attribute from vectorized environment (see base class).""" target_remotes = self._get_target_remotes(indices) diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 437459c..3d87ca9 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.5.0 +2.6.0a0 diff --git a/tests/test_vec_envs.py b/tests/test_vec_envs.py index 7e4e5ec..43a693d 100644 --- a/tests/test_vec_envs.py +++ b/tests/test_vec_envs.py @@ -123,12 +123,30 @@ def test_vecenv_custom_calls(vec_env_class, vec_env_wrapper): # we need a X server to test the "human" mode (uses OpenCV) # vec_env.render(mode="human") + # Set a new attribute, on the last wrapper and on the env + assert not vec_env.has_attr("dummy") + # Set value for the last wrapper only + vec_env.set_attr("dummy", 12) + assert vec_env.get_attr("dummy") == [12] * N_ENVS + if vec_env_class == DummyVecEnv: + assert vec_env.envs[0].dummy == 12 + + assert not vec_env.has_attr("dummy2") + # Set the value on the original env + # `set_wrapper_attr` doesn't exist before v1.0 + if gym.__version__ > "1": + vec_env.env_method("set_wrapper_attr", "dummy2", 2) + assert vec_env.get_attr("dummy2") == [2] * N_ENVS + if vec_env_class == DummyVecEnv: + assert vec_env.envs[0].unwrapped.dummy2 == 2 + env_method_results = vec_env.env_method("custom_method", 1, indices=None, dim_1=2) setattr_results = [] # Set current_step to an arbitrary value for env_idx in range(N_ENVS): 
setattr_results.append(vec_env.set_attr("current_step", env_idx, indices=env_idx)) # Retrieve the value for each environment + assert vec_env.has_attr("current_step") getattr_results = vec_env.get_attr("current_step") assert len(env_method_results) == N_ENVS