From bb16645c4ee170b28245ff70d4c02aad7eb788b4 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Thu, 23 Dec 2021 16:12:49 +0100 Subject: [PATCH] Add `skip` option for `VecTransposeImage` and bug fix in frame stack (#700) * Update doc * Add comment * Add skip option to VecTransposeImage and fix bug in frame stack --- docs/guide/examples.rst | 8 ++--- docs/misc/changelog.rst | 11 +++---- stable_baselines3/common/base_class.py | 4 +-- stable_baselines3/common/monitor.py | 1 + .../common/vec_env/stacked_observations.py | 2 +- .../common/vec_env/vec_transpose.py | 14 ++++++++- stable_baselines3/version.txt | 2 +- tests/test_cnn.py | 29 ++++++++++++++++++- 8 files changed, 56 insertions(+), 15 deletions(-) diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index 67d9a8d..733279b 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -56,10 +56,10 @@ In the following example, we will train, save and load a DQN model on the Lunar LunarLander requires the python package ``box2d``. You can install it using ``apt install swig`` and then ``pip install box2d box2d-kengz`` -.. note:: +.. warning:: ``load`` method re-creates the model from scratch and should be called on the Algorithm without instantiating it first, - e.g. ``model = DQN.load("dqn_lunar", env=env)`` instead of ``model = DQN(env=env)`` followed by ``model.load("dqn_lunar")``. The latter **will not work** as ``load`` does not work by reference. - If you want to load parameters without re-creating the model, e.g. to evaluate the same model + e.g. ``model = DQN.load("dqn_lunar", env=env)`` instead of ``model = DQN(env=env)`` followed by ``model.load("dqn_lunar")``. The latter **will not work** as ``load`` is not an in-place operation. + If you want to load parameters without re-creating the model, e.g. to evaluate the same model with multiple different sets of parameters, consider using ``set_parameters`` instead. .. code-block:: python @@ -163,7 +163,7 @@ Multiprocessing with off-policy algorithms ------------------------------------------ .. warning:: - + When using multiple environments with off-policy algorithms, you should update the ``gradient_steps`` parameter too. Set it to ``gradient_steps=-1`` to perform as many gradient steps as transitions collected. There is usually a compromise between wall-clock time and sample efficiency, diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 3e4c067..989e021 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== -Release 1.3.1a6 (WIP) +Release 1.3.1a7 (WIP) --------------------------- Breaking Changes: @@ -20,7 +20,7 @@ New Features: - Added ``norm_obs_keys`` param for ``VecNormalize`` wrapper to configure which observation keys to normalize (@kachayev) - Added experimental support to train off-policy algorithms with multiple envs (note: ``HerReplayBuffer`` currently not supported) - Handle timeout termination properly for on-policy algorithms (when using ``TimeLimit``) - +- Added ``skip`` option to ``VecTransposeImage`` to skip transforming the channel order when the heuristic is wrong Bug Fixes: ^^^^^^^^^^ @@ -29,6 +29,7 @@ Bug Fixes: - Fixed evaluation script for recurrent policies (experimental feature in SB3 contrib) - Fixed a bug where the observation would be incorrectly detected as non-vectorized instead of throwing an error - The env checker now properly checks and warns about potential issues for continuous action spaces when the boundaries are too small or when the dtype is not float32 +- Fixed a bug in ``VecFrameStack`` with channel first image envs, where the terminal observation would be wrongly created. Deprecations: ^^^^^^^^^^^^^ @@ -50,8 +51,8 @@ Documentation: - Add documentation on exporting to TFLite/Coral - Added JMLR paper and updated citation - Added link to RL Tips and Tricks video -- Update ``BaseAlgorithm.load`` docstring -- Add a Note on ``load`` behavior in the examples +- Updated ``BaseAlgorithm.load`` docstring (@Demetrio92) +- Added a note on ``load`` behavior in the examples (@Demetrio92) Release 1.3.0 (2021-10-23) @@ -857,4 +858,4 @@ And all the contributors: @ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr @Atlis @liusida @09tangriro @amy12xx @juancroldan @benblack769 @bstee615 @c-rizz @skandermoalla @MihaiAnca13 @davidblom603 @ayeright @cyprienc @wkirgsn @AechPro @CUN-bjy @batu @IljaAvadiev @timokau @kachayev @cleversonahum -@eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch +@eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index bdee678..25c2638 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -658,8 +658,8 @@ class BaseAlgorithm(ABC): ) -> "BaseAlgorithm": """ Load the model from a zip-file. - Note: `load` re-creates the model from scratch, it does not update it in-place! - For an in-place load use `set_parameters` instead. + Warning: ``load`` re-creates the model from scratch, it does not update it in-place! + For an in-place load use ``set_parameters`` instead. :param path: path to the file (or a file-like) where to load the agent from diff --git a/stable_baselines3/common/monitor.py b/stable_baselines3/common/monitor.py index 7b51761..04cda22 100644 --- a/stable_baselines3/common/monitor.py +++ b/stable_baselines3/common/monitor.py @@ -178,6 +178,7 @@ class ResultsWriter: filename = os.path.join(filename, Monitor.EXT) else: filename = filename + "." + Monitor.EXT + # Prevent newline issue on Windows, see GH issue #692 self.file_handler = open(filename, "wt", newline="\n") self.file_handler.write("#%s\n" % json.dumps(header)) self.logger = csv.DictWriter(self.file_handler, fieldnames=("r", "l", "t") + extra_keys) diff --git a/stable_baselines3/common/vec_env/stacked_observations.py b/stable_baselines3/common/vec_env/stacked_observations.py index 956231f..affd775 100644 --- a/stable_baselines3/common/vec_env/stacked_observations.py +++ b/stable_baselines3/common/vec_env/stacked_observations.py @@ -126,7 +126,7 @@ class StackedObservations(object): if self.channels_first: new_terminal = np.concatenate( (self.stackedobs[i, :-stack_ax_size, ...], old_terminal), - axis=self.stack_dimension, + axis=0, # self.stack_dimension - 1, as there is not batch dim ) else: new_terminal = np.concatenate( diff --git a/stable_baselines3/common/vec_env/vec_transpose.py b/stable_baselines3/common/vec_env/vec_transpose.py index 399fb31..e6f728b 100644 --- a/stable_baselines3/common/vec_env/vec_transpose.py +++ b/stable_baselines3/common/vec_env/vec_transpose.py @@ -14,13 +14,21 @@ class VecTransposeImage(VecEnvWrapper): It is required for PyTorch convolution layers. :param venv: + :param skip: Skip this wrapper if needed as we rely on heuristic to apply it or not, + which may result in unwanted behavior, see GH issue #671. """ - def __init__(self, venv: VecEnv): + def __init__(self, venv: VecEnv, skip: bool = False): assert is_image_space(venv.observation_space) or isinstance( venv.observation_space, spaces.dict.Dict ), "The observation space must be an image or dictionary observation space" + self.skip = skip + # Do nothing + if skip: + super(VecTransposeImage, self).__init__(venv) + return + if isinstance(venv.observation_space, spaces.dict.Dict): self.image_space_keys = [] observation_space = deepcopy(venv.observation_space) @@ -70,6 +78,10 @@ class VecTransposeImage(VecEnvWrapper): :param observations: :return: Transposed observations """ + # Do nothing + if self.skip: + return observations + if isinstance(observations, dict): # Avoid modifying the original object in place observations = deepcopy(observations) diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index e6eaed8..f625807 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.3.1a6 +1.3.1a7 diff --git a/tests/test_cnn.py b/tests/test_cnn.py index 8ec33fb..03f089d 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -10,7 +10,7 @@ from stable_baselines3 import A2C, DQN, PPO, SAC, TD3 from stable_baselines3.common.envs import FakeImageEnv from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first from stable_baselines3.common.utils import zip_strict -from stable_baselines3.common.vec_env import VecTransposeImage, is_vecenv_wrapped +from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecTransposeImage, is_vecenv_wrapped @pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN]) @@ -57,6 +57,33 @@ def test_cnn(tmp_path, model_class): os.remove(str(tmp_path / SAVE_NAME)) +@pytest.mark.parametrize("model_class", [A2C]) +def test_vec_transpose_skip(tmp_path, model_class): + # Fake grayscale with frameskip + env = FakeImageEnv( + screen_height=41, screen_width=40, n_channels=10, discrete=model_class not in {SAC, TD3}, channel_first=True + ) + env = DummyVecEnv([lambda: env]) + # Stack 5 frames so the observation is now (50, 40, 40) but the env is still channel first + env = VecFrameStack(env, 5, channels_order="first") + obs_shape_before = env.reset().shape + # The observation space should be different as the heuristic thinks it is channel last + assert not np.allclose(obs_shape_before, VecTransposeImage(env).reset().shape) + env = VecTransposeImage(env, skip=True) + # The observation space should be the same as we skip the VecTransposeImage + assert np.allclose(obs_shape_before, env.reset().shape) + + kwargs = dict( + n_steps=64, + policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)), + seed=1, + ) + model = model_class("CnnPolicy", env, **kwargs).learn(250) + + obs = env.reset() + action, _ = model.predict(obs, deterministic=True) + + def patch_dqn_names_(model): # Small hack to make the test work with DQN if isinstance(model, DQN):