mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-14 20:58:03 +00:00
Fix render_mode when loading VecNormalize (#1671)
* Fix render_mode when loading VecNormalize * Switch from isort to ruff, and cap black version * Add test and update changelog
This commit is contained in:
parent
57dbefe80c
commit
99712760c8
8 changed files with 39 additions and 21 deletions
|
|
@ -38,7 +38,7 @@ pip install -e .[docs,tests,extra]
|
|||
|
||||
## Codestyle
|
||||
|
||||
We use [black codestyle](https://github.com/psf/black) (max line length of 127 characters) together with [isort](https://github.com/timothycrosley/isort) to sort the imports.
|
||||
We use [black codestyle](https://github.com/psf/black) (max line length of 127 characters) together with [ruff](https://github.com/astral-sh/ruff) (isort rules) to sort the imports.
|
||||
For the documentation, we use the default line length of 88 characters per line.
|
||||
|
||||
**Please run `make format`** to reformat your code. You can check the codestyle using `make check-codestyle` and `make lint`.
|
||||
|
|
@ -63,7 +63,7 @@ def my_function(arg1: type1, arg2: type2) -> returntype:
|
|||
|
||||
Before proposing a PR, please open an issue, where the feature will be discussed. This prevents duplicate PRs from being proposed and also eases the code review process.
|
||||
|
||||
Each PR need to be reviewed and accepted by at least one of the maintainers (@hill-a, @araffin, @ernestum, @AdamGleave or @Miffyli).
|
||||
Each PR needs to be reviewed and accepted by at least one of the maintainers (@hill-a, @araffin, @ernestum, @AdamGleave, @Miffyli or @qgallouedec).
|
||||
A PR must pass the Continuous Integration tests to be merged with the master branch.
|
||||
|
||||
|
||||
|
|
@ -85,7 +85,7 @@ Type checking with `pytype` and `mypy`:
|
|||
make type
|
||||
```
|
||||
|
||||
Codestyle check with `black`, `isort` and `ruff`:
|
||||
Codestyle check with `black` and `ruff` (`isort` rules):
|
||||
|
||||
```
|
||||
make check-codestyle
|
||||
|
|
|
|||
4
Makefile
4
Makefile
|
|
@ -27,13 +27,13 @@ lint:
|
|||
|
||||
format:
|
||||
# Sort imports
|
||||
isort ${LINT_PATHS}
|
||||
ruff --select I ${LINT_PATHS} --fix
|
||||
# Reformat using black
|
||||
black ${LINT_PATHS}
|
||||
|
||||
check-codestyle:
|
||||
# Check that imports are sorted
|
||||
isort --check ${LINT_PATHS}
|
||||
ruff --select I ${LINT_PATHS}
|
||||
# Check formatting using black
|
||||
black --check ${LINT_PATHS}
|
||||
|
||||
|
|
|
|||
|
|
@ -3,11 +3,12 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Release 2.2.0a1 (WIP)
|
||||
Release 2.2.0a2 (WIP)
|
||||
--------------------------
|
||||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
- Switched to ``ruff`` for sorting imports (isort is no longer needed); black and ruff now require a minimum version
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
@ -18,6 +19,10 @@ New Features:
|
|||
`RL Zoo`_
|
||||
^^^^^^^^^
|
||||
|
||||
`SBX`_
|
||||
^^^^^^^^^
|
||||
- Added ``DDPG`` and ``TD3``
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- Prevents using squash_output and not use_sde in ActorCriticPolicy (@PatrickHelm)
|
||||
|
|
@ -25,7 +30,8 @@ Bug Fixes:
|
|||
- Moves VectorizedActionNoise into ``_setup_learn()`` in OffPolicyAlgorithm (@PatrickHelm)
|
||||
- Prevents out of bound error on Windows if no seed is passed (@PatrickHelm)
|
||||
- Calls ``callback.update_locals()`` before ``callback.on_rollout_end()`` in OnPolicyAlgorithm (@PatrickHelm)
|
||||
- Fixes replay buffer device after loading in OffPolicyAlgorithm (@PatrickHelm)
|
||||
- Fixed replay buffer device after loading in OffPolicyAlgorithm (@PatrickHelm)
|
||||
- Fixed ``render_mode`` which was not properly loaded when using ``VecNormalize.load()``
|
||||
|
||||
|
||||
Deprecations:
|
||||
|
|
@ -1424,6 +1430,7 @@ and `Quentin Gallouédec`_ (aka @qgallouedec).
|
|||
|
||||
.. _SB3-Contrib: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
|
||||
.. _RL Zoo: https://github.com/DLR-RM/rl-baselines3-zoo
|
||||
.. _SBX: https://github.com/araffin/sbx
|
||||
|
||||
Contributors:
|
||||
-------------
|
||||
|
|
|
|||
|
|
@ -24,11 +24,6 @@ max-complexity = 15
|
|||
[tool.black]
|
||||
line-length = 127
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
line_length = 127
|
||||
src_paths = ["stable_baselines3"]
|
||||
|
||||
[tool.pytype]
|
||||
inputs = ["stable_baselines3"]
|
||||
disable = ["pyi-error"]
|
||||
|
|
|
|||
8
setup.py
8
setup.py
|
|
@ -120,12 +120,10 @@ setup(
|
|||
# Type check
|
||||
"pytype",
|
||||
"mypy",
|
||||
# Lint code (flake8 replacement)
|
||||
"ruff",
|
||||
# Sort imports
|
||||
"isort>=5.0",
|
||||
# Lint code and sort imports (flake8 and isort replacement)
|
||||
"ruff>=0.0.288",
|
||||
# Reformat
|
||||
"black",
|
||||
"black>=23.9.1,<24",
|
||||
],
|
||||
"docs": [
|
||||
"sphinx>=5.3,<7.0",
|
||||
|
|
|
|||
|
|
@ -163,6 +163,7 @@ class VecNormalize(VecEnvWrapper):
|
|||
self.venv = venv
|
||||
self.num_envs = venv.num_envs
|
||||
self.class_attributes = dict(inspect.getmembers(self.__class__))
|
||||
self.render_mode = venv.render_mode
|
||||
|
||||
# Check that the observation_space shape match
|
||||
utils.check_shape_equal(self.observation_space, venv.observation_space)
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
2.2.0a1
|
||||
2.2.0a2
|
||||
|
|
|
|||
|
|
@ -123,6 +123,10 @@ def make_env():
|
|||
return Monitor(gym.make(ENV_ID))
|
||||
|
||||
|
||||
def make_env_render():
|
||||
return Monitor(gym.make(ENV_ID, render_mode="rgb_array"))
|
||||
|
||||
|
||||
def make_dict_env():
|
||||
return Monitor(DummyDictEnv())
|
||||
|
||||
|
|
@ -257,14 +261,17 @@ def test_obs_rms_vec_normalize():
|
|||
assert np.allclose(env.ret_rms.mean, 5.688, atol=1e-3)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("make_env", [make_env, make_dict_env, make_image_env])
|
||||
def test_vec_env(tmp_path, make_env):
|
||||
@pytest.mark.parametrize("make_gym_env", [make_env, make_dict_env, make_image_env])
|
||||
def test_vec_env(tmp_path, make_gym_env):
|
||||
"""Test VecNormalize Object"""
|
||||
clip_obs = 0.5
|
||||
clip_reward = 5.0
|
||||
|
||||
orig_venv = DummyVecEnv([make_env])
|
||||
orig_venv = DummyVecEnv([make_gym_env])
|
||||
norm_venv = VecNormalize(orig_venv, norm_obs=True, norm_reward=True, clip_obs=clip_obs, clip_reward=clip_reward)
|
||||
assert orig_venv.render_mode is None
|
||||
assert norm_venv.render_mode is None
|
||||
|
||||
_, done = norm_venv.reset(), [False]
|
||||
while not done[0]:
|
||||
actions = [norm_venv.action_space.sample()]
|
||||
|
|
@ -278,9 +285,19 @@ def test_vec_env(tmp_path, make_env):
|
|||
|
||||
path = tmp_path / "vec_normalize"
|
||||
norm_venv.save(path)
|
||||
assert orig_venv.render_mode is None
|
||||
deserialized = VecNormalize.load(path, venv=orig_venv)
|
||||
assert deserialized.render_mode is None
|
||||
check_vec_norm_equal(norm_venv, deserialized)
|
||||
|
||||
# Check that render mode is properly updated
|
||||
vec_env = DummyVecEnv([make_env_render])
|
||||
assert vec_env.render_mode == "rgb_array"
|
||||
# Test that loading and wrapping keep the correct render mode
|
||||
if make_gym_env == make_env:
|
||||
assert VecNormalize.load(path, venv=vec_env).render_mode == "rgb_array"
|
||||
assert VecNormalize(vec_env).render_mode == "rgb_array"
|
||||
|
||||
|
||||
def test_get_original():
|
||||
venv = _make_warmstart_cartpole()
|
||||
|
|
|
|||
Loading…
Reference in a new issue