Fix render_mode when loading VecNormalize (#1671)

* Fix render_mode when loading VecNormalize

* Switch from isort to ruff, and cap black version

* Add test and update changelog
This commit is contained in:
Antonin RAFFIN 2023-09-12 11:28:32 +02:00 committed by GitHub
parent 57dbefe80c
commit 99712760c8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 39 additions and 21 deletions

View file

@ -38,7 +38,7 @@ pip install -e .[docs,tests,extra]
## Codestyle
We use [black codestyle](https://github.com/psf/black) (max line length of 127 characters) together with [isort](https://github.com/timothycrosley/isort) to sort the imports.
We use [black codestyle](https://github.com/psf/black) (max line length of 127 characters) together with [ruff](https://github.com/astral-sh/ruff) (isort rules) to sort the imports.
For the documentation, we use the default line length of 88 characters per line.
**Please run `make format`** to reformat your code. You can check the codestyle using `make check-codestyle` and `make lint`.
@ -63,7 +63,7 @@ def my_function(arg1: type1, arg2: type2) -> returntype:
Before proposing a PR, please open an issue, where the feature will be discussed. This prevent from duplicated PR to be proposed and also ease the code review process.
Each PR need to be reviewed and accepted by at least one of the maintainers (@hill-a, @araffin, @ernestum, @AdamGleave or @Miffyli).
Each PR need to be reviewed and accepted by at least one of the maintainers (@hill-a, @araffin, @ernestum, @AdamGleave, @Miffyli or @qgallouedec).
A PR must pass the Continuous Integration tests to be merged with the master branch.
@ -85,7 +85,7 @@ Type checking with `pytype` and `mypy`:
make type
```
Codestyle check with `black`, `isort` and `ruff`:
Codestyle check with `black`, and `ruff` (`isort` rules):
```
make check-codestyle

View file

@ -27,13 +27,13 @@ lint:
format:
# Sort imports
isort ${LINT_PATHS}
ruff --select I ${LINT_PATHS} --fix
# Reformat using black
black ${LINT_PATHS}
check-codestyle:
# Sort imports
isort --check ${LINT_PATHS}
ruff --select I ${LINT_PATHS}
# Reformat using black
black --check ${LINT_PATHS}

View file

@ -3,11 +3,12 @@
Changelog
==========
Release 2.2.0a1 (WIP)
Release 2.2.0a2 (WIP)
--------------------------
Breaking Changes:
^^^^^^^^^^^^^^^^^
- Switched to ``ruff`` for sorting imports (isort is no longer needed), black and ruff version now require a minimum version
New Features:
^^^^^^^^^^^^^
@ -18,6 +19,10 @@ New Features:
`RL Zoo`_
^^^^^^^^^
`SBX`_
^^^^^^^^^
- Added ``DDPG`` and ``TD3``
Bug Fixes:
^^^^^^^^^^
- Prevents using squash_output and not use_sde in ActorCritcPolicy (@PatrickHelm)
@ -25,7 +30,8 @@ Bug Fixes:
- Moves VectorizedActionNoise into ``_setup_learn()`` in OffPolicyAlgorithm (@PatrickHelm)
- Prevents out of bound error on Windows if no seed is passed (@PatrickHelm)
- Calls ``callback.update_locals()`` before ``callback.on_rollout_end()`` in OnPolicyAlgorithm (@PatrickHelm)
- Fixes replay buffer device after loading in OffPolicyAlgorithm (@PatrickHelm)
- Fixed replay buffer device after loading in OffPolicyAlgorithm (@PatrickHelm)
- Fixed ``render_mode`` which was not properly loaded when using ``VecNormalize.load()``
Deprecations:
@ -1424,6 +1430,7 @@ and `Quentin Gallouédec`_ (aka @qgallouedec).
.. _SB3-Contrib: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
.. _RL Zoo: https://github.com/DLR-RM/rl-baselines3-zoo
.. _SBX: https://github.com/araffin/sbx
Contributors:
-------------

View file

@ -24,11 +24,6 @@ max-complexity = 15
[tool.black]
line-length = 127
[tool.isort]
profile = "black"
line_length = 127
src_paths = ["stable_baselines3"]
[tool.pytype]
inputs = ["stable_baselines3"]
disable = ["pyi-error"]

View file

@ -120,12 +120,10 @@ setup(
# Type check
"pytype",
"mypy",
# Lint code (flake8 replacement)
"ruff",
# Sort imports
"isort>=5.0",
# Lint code and sort imports (flake8 and isort replacement)
"ruff>=0.0.288",
# Reformat
"black",
"black>=23.9.1,<24",
],
"docs": [
"sphinx>=5.3,<7.0",

View file

@ -163,6 +163,7 @@ class VecNormalize(VecEnvWrapper):
self.venv = venv
self.num_envs = venv.num_envs
self.class_attributes = dict(inspect.getmembers(self.__class__))
self.render_mode = venv.render_mode
# Check that the observation_space shape match
utils.check_shape_equal(self.observation_space, venv.observation_space)

View file

@ -1 +1 @@
2.2.0a1
2.2.0a2

View file

@ -123,6 +123,10 @@ def make_env():
return Monitor(gym.make(ENV_ID))
def make_env_render():
return Monitor(gym.make(ENV_ID, render_mode="rgb_array"))
def make_dict_env():
return Monitor(DummyDictEnv())
@ -257,14 +261,17 @@ def test_obs_rms_vec_normalize():
assert np.allclose(env.ret_rms.mean, 5.688, atol=1e-3)
@pytest.mark.parametrize("make_env", [make_env, make_dict_env, make_image_env])
def test_vec_env(tmp_path, make_env):
@pytest.mark.parametrize("make_gym_env", [make_env, make_dict_env, make_image_env])
def test_vec_env(tmp_path, make_gym_env):
"""Test VecNormalize Object"""
clip_obs = 0.5
clip_reward = 5.0
orig_venv = DummyVecEnv([make_env])
orig_venv = DummyVecEnv([make_gym_env])
norm_venv = VecNormalize(orig_venv, norm_obs=True, norm_reward=True, clip_obs=clip_obs, clip_reward=clip_reward)
assert orig_venv.render_mode is None
assert norm_venv.render_mode is None
_, done = norm_venv.reset(), [False]
while not done[0]:
actions = [norm_venv.action_space.sample()]
@ -278,9 +285,19 @@ def test_vec_env(tmp_path, make_env):
path = tmp_path / "vec_normalize"
norm_venv.save(path)
assert orig_venv.render_mode is None
deserialized = VecNormalize.load(path, venv=orig_venv)
assert deserialized.render_mode is None
check_vec_norm_equal(norm_venv, deserialized)
# Check that render mode is properly updated
vec_env = DummyVecEnv([make_env_render])
assert vec_env.render_mode == "rgb_array"
# Test that loading and wrapping keep the correct render mode
if make_gym_env == make_env:
assert VecNormalize.load(path, venv=vec_env).render_mode == "rgb_array"
assert VecNormalize(vec_env).render_mode == "rgb_array"
def test_get_original():
venv = _make_warmstart_cartpole()