From 99712760c8f2bb8a40c201158b649eb0041b830f Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Tue, 12 Sep 2023 11:28:32 +0200 Subject: [PATCH] Fix render_mode when loading VecNormalize (#1671) * Fix render_mode when loading VecNormalize * Switch from isort to ruff, and cap black version * Add test and update changelog --- CONTRIBUTING.md | 6 ++--- Makefile | 4 ++-- docs/misc/changelog.rst | 11 +++++++-- pyproject.toml | 5 ---- setup.py | 8 +++---- .../common/vec_env/vec_normalize.py | 1 + stable_baselines3/version.txt | 2 +- tests/test_vec_normalize.py | 23 ++++++++++++++++--- 8 files changed, 39 insertions(+), 21 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 62f47ce..0063cc1 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -38,7 +38,7 @@ pip install -e .[docs,tests,extra] ## Codestyle -We use [black codestyle](https://github.com/psf/black) (max line length of 127 characters) together with [isort](https://github.com/timothycrosley/isort) to sort the imports. +We use [black codestyle](https://github.com/psf/black) (max line length of 127 characters) together with [ruff](https://github.com/astral-sh/ruff) (isort rules) to sort the imports. For the documentation, we use the default line length of 88 characters per line. **Please run `make format`** to reformat your code. You can check the codestyle using `make check-codestyle` and `make lint`. @@ -63,7 +63,7 @@ def my_function(arg1: type1, arg2: type2) -> returntype: Before proposing a PR, please open an issue, where the feature will be discussed. This prevent from duplicated PR to be proposed and also ease the code review process. -Each PR need to be reviewed and accepted by at least one of the maintainers (@hill-a, @araffin, @ernestum, @AdamGleave or @Miffyli). +Each PR need to be reviewed and accepted by at least one of the maintainers (@hill-a, @araffin, @ernestum, @AdamGleave, @Miffyli or @qgallouedec). A PR must pass the Continuous Integration tests to be merged with the master branch. @@ -85,7 +85,7 @@ Type checking with `pytype` and `mypy`: make type ``` -Codestyle check with `black`, `isort` and `ruff`: +Codestyle check with `black`, and `ruff` (`isort` rules): ``` make check-codestyle diff --git a/Makefile b/Makefile index 7fa590c..cb90f31 100644 --- a/Makefile +++ b/Makefile @@ -27,13 +27,13 @@ lint: format: # Sort imports - isort ${LINT_PATHS} + ruff --select I ${LINT_PATHS} --fix # Reformat using black black ${LINT_PATHS} check-codestyle: # Sort imports - isort --check ${LINT_PATHS} + ruff --select I ${LINT_PATHS} # Reformat using black black --check ${LINT_PATHS} diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index cd7d460..70716d5 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,11 +3,12 @@ Changelog ========== -Release 2.2.0a1 (WIP) +Release 2.2.0a2 (WIP) -------------------------- Breaking Changes: ^^^^^^^^^^^^^^^^^ +- Switched to ``ruff`` for sorting imports (isort is no longer needed), black and ruff version now require a minimum version New Features: ^^^^^^^^^^^^^ @@ -18,6 +19,10 @@ New Features: `RL Zoo`_ ^^^^^^^^^ +`SBX`_ +^^^^^^^^^ +- Added ``DDPG`` and ``TD3`` + Bug Fixes: ^^^^^^^^^^ - Prevents using squash_output and not use_sde in ActorCritcPolicy (@PatrickHelm) @@ -25,7 +30,8 @@ Bug Fixes: - Moves VectorizedActionNoise into ``_setup_learn()`` in OffPolicyAlgorithm (@PatrickHelm) - Prevents out of bound error on Windows if no seed is passed (@PatrickHelm) - Calls ``callback.update_locals()`` before ``callback.on_rollout_end()`` in OnPolicyAlgorithm (@PatrickHelm) -- Fixes replay buffer device after loading in OffPolicyAlgorithm (@PatrickHelm) +- Fixed replay buffer device after loading in OffPolicyAlgorithm (@PatrickHelm) +- Fixed ``render_mode`` which was not properly loaded when using ``VecNormalize.load()`` Deprecations: @@ -1424,6 +1430,7 @@ and `Quentin Gallouédec`_ (aka @qgallouedec). .. _SB3-Contrib: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib .. _RL Zoo: https://github.com/DLR-RM/rl-baselines3-zoo +.. _SBX: https://github.com/araffin/sbx Contributors: ------------- diff --git a/pyproject.toml b/pyproject.toml index 1c1837a..7e5d2b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,11 +24,6 @@ max-complexity = 15 [tool.black] line-length = 127 -[tool.isort] -profile = "black" -line_length = 127 -src_paths = ["stable_baselines3"] - [tool.pytype] inputs = ["stable_baselines3"] disable = ["pyi-error"] diff --git a/setup.py b/setup.py index deb9f54..cdaf263 100644 --- a/setup.py +++ b/setup.py @@ -120,12 +120,10 @@ setup( # Type check "pytype", "mypy", - # Lint code (flake8 replacement) - "ruff", - # Sort imports - "isort>=5.0", + # Lint code and sort imports (flake8 and isort replacement) + "ruff>=0.0.288", # Reformat - "black", + "black>=23.9.1,<24", ], "docs": [ "sphinx>=5.3,<7.0", diff --git a/stable_baselines3/common/vec_env/vec_normalize.py b/stable_baselines3/common/vec_env/vec_normalize.py index ebefa82..27c3d43 100644 --- a/stable_baselines3/common/vec_env/vec_normalize.py +++ b/stable_baselines3/common/vec_env/vec_normalize.py @@ -163,6 +163,7 @@ class VecNormalize(VecEnvWrapper): self.venv = venv self.num_envs = venv.num_envs self.class_attributes = dict(inspect.getmembers(self.__class__)) + self.render_mode = venv.render_mode # Check that the observation_space shape match utils.check_shape_equal(self.observation_space, venv.observation_space) diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 8c6ccba..59ead85 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.2.0a1 +2.2.0a2 diff --git a/tests/test_vec_normalize.py b/tests/test_vec_normalize.py index ae59047..2b30d5a 100644 --- a/tests/test_vec_normalize.py +++ b/tests/test_vec_normalize.py @@ -123,6 +123,10 @@ def make_env(): return Monitor(gym.make(ENV_ID)) +def make_env_render(): + return Monitor(gym.make(ENV_ID, render_mode="rgb_array")) + + def make_dict_env(): return Monitor(DummyDictEnv()) @@ -257,14 +261,17 @@ def test_obs_rms_vec_normalize(): assert np.allclose(env.ret_rms.mean, 5.688, atol=1e-3) -@pytest.mark.parametrize("make_env", [make_env, make_dict_env, make_image_env]) -def test_vec_env(tmp_path, make_env): +@pytest.mark.parametrize("make_gym_env", [make_env, make_dict_env, make_image_env]) +def test_vec_env(tmp_path, make_gym_env): """Test VecNormalize Object""" clip_obs = 0.5 clip_reward = 5.0 - orig_venv = DummyVecEnv([make_env]) + orig_venv = DummyVecEnv([make_gym_env]) norm_venv = VecNormalize(orig_venv, norm_obs=True, norm_reward=True, clip_obs=clip_obs, clip_reward=clip_reward) + assert orig_venv.render_mode is None + assert norm_venv.render_mode is None + _, done = norm_venv.reset(), [False] while not done[0]: actions = [norm_venv.action_space.sample()] @@ -278,9 +285,19 @@ def test_vec_env(tmp_path, make_env): path = tmp_path / "vec_normalize" norm_venv.save(path) + assert orig_venv.render_mode is None deserialized = VecNormalize.load(path, venv=orig_venv) + assert deserialized.render_mode is None check_vec_norm_equal(norm_venv, deserialized) + # Check that render mode is properly updated + vec_env = DummyVecEnv([make_env_render]) + assert vec_env.render_mode == "rgb_array" + # Test that loading and wrapping keep the correct render mode + if make_gym_env == make_env: + assert VecNormalize.load(path, venv=vec_env).render_mode == "rgb_array" + assert VecNormalize(vec_env).render_mode == "rgb_array" + def test_get_original(): venv = _make_warmstart_cartpole()