Add skip option for VecTransposeImage and bug fix in frame stack (#700)

* Update doc

* Add comment

* Add skip option to VecTransposeImage and fix bug in frame stack
This commit is contained in:
Antonin RAFFIN 2021-12-23 16:12:49 +01:00 committed by GitHub
parent d496cd4d95
commit bb16645c4e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 56 additions and 15 deletions

View file

@ -56,10 +56,10 @@ In the following example, we will train, save and load a DQN model on the Lunar
LunarLander requires the python package ``box2d``.
You can install it using ``apt install swig`` and then ``pip install box2d box2d-kengz``
.. note::
.. warning::
The ``load`` method re-creates the model from scratch and should be called on the Algorithm without instantiating it first,
e.g. ``model = DQN.load("dqn_lunar", env=env)`` instead of ``model = DQN(env=env)`` followed by ``model.load("dqn_lunar")``. The latter **will not work** as ``load`` does not work by reference.
If you want to load parameters without re-creating the model, e.g. to evaluate the same model
e.g. ``model = DQN.load("dqn_lunar", env=env)`` instead of ``model = DQN(env=env)`` followed by ``model.load("dqn_lunar")``. The latter **will not work** as ``load`` is not an in-place operation.
If you want to load parameters without re-creating the model, e.g. to evaluate the same model
with multiple different sets of parameters, consider using ``set_parameters`` instead.
.. code-block:: python
@ -163,7 +163,7 @@ Multiprocessing with off-policy algorithms
------------------------------------------
.. warning::
When using multiple environments with off-policy algorithms, you should update the ``gradient_steps``
parameter too. Set it to ``gradient_steps=-1`` to perform as many gradient steps as transitions collected.
There is usually a compromise between wall-clock time and sample efficiency,

View file

@ -4,7 +4,7 @@ Changelog
==========
Release 1.3.1a6 (WIP)
Release 1.3.1a7 (WIP)
---------------------------
Breaking Changes:
@ -20,7 +20,7 @@ New Features:
- Added ``norm_obs_keys`` param for ``VecNormalize`` wrapper to configure which observation keys to normalize (@kachayev)
- Added experimental support to train off-policy algorithms with multiple envs (note: ``HerReplayBuffer`` currently not supported)
- Handle timeout termination properly for on-policy algorithms (when using ``TimeLimit``)
- Added ``skip`` option to ``VecTransposeImage`` to skip transforming the channel order when the heuristic is wrong
Bug Fixes:
^^^^^^^^^^
@ -29,6 +29,7 @@ Bug Fixes:
- Fixed evaluation script for recurrent policies (experimental feature in SB3 contrib)
- Fixed a bug where the observation would be incorrectly detected as non-vectorized instead of throwing an error
- The env checker now properly checks and warns about potential issues for continuous action spaces when the boundaries are too small or when the dtype is not float32
- Fixed a bug in ``VecFrameStack`` with channel-first image envs, where the terminal observation would be wrongly created.
Deprecations:
^^^^^^^^^^^^^
@ -50,8 +51,8 @@ Documentation:
- Add documentation on exporting to TFLite/Coral
- Added JMLR paper and updated citation
- Added link to RL Tips and Tricks video
- Update ``BaseAlgorithm.load`` docstring
- Add a Note on ``load`` behavior in the examples
- Updated ``BaseAlgorithm.load`` docstring (@Demetrio92)
- Added a note on ``load`` behavior in the examples (@Demetrio92)
Release 1.3.0 (2021-10-23)
@ -857,4 +858,4 @@ And all the contributors:
@ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr @Atlis @liusida @09tangriro @amy12xx @juancroldan
@benblack769 @bstee615 @c-rizz @skandermoalla @MihaiAnca13 @davidblom603 @ayeright @cyprienc
@wkirgsn @AechPro @CUN-bjy @batu @IljaAvadiev @timokau @kachayev @cleversonahum
@eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch
@eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92

View file

@ -658,8 +658,8 @@ class BaseAlgorithm(ABC):
) -> "BaseAlgorithm":
"""
Load the model from a zip-file.
Note: `load` re-creates the model from scratch, it does not update it in-place!
For an in-place load use `set_parameters` instead.
Warning: ``load`` re-creates the model from scratch, it does not update it in-place!
For an in-place load use ``set_parameters`` instead.
:param path: path to the file (or a file-like) where to
load the agent from

View file

@ -178,6 +178,7 @@ class ResultsWriter:
filename = os.path.join(filename, Monitor.EXT)
else:
filename = filename + "." + Monitor.EXT
# Prevent newline issue on Windows, see GH issue #692
self.file_handler = open(filename, "wt", newline="\n")
self.file_handler.write("#%s\n" % json.dumps(header))
self.logger = csv.DictWriter(self.file_handler, fieldnames=("r", "l", "t") + extra_keys)

View file

@ -126,7 +126,7 @@ class StackedObservations(object):
if self.channels_first:
new_terminal = np.concatenate(
(self.stackedobs[i, :-stack_ax_size, ...], old_terminal),
axis=self.stack_dimension,
axis=0, # self.stack_dimension - 1, as there is not batch dim
)
else:
new_terminal = np.concatenate(

View file

@ -14,13 +14,21 @@ class VecTransposeImage(VecEnvWrapper):
It is required for PyTorch convolution layers.
:param venv:
:param skip: Skip this wrapper if needed, as we rely on a heuristic to decide whether to apply it,
which may result in unwanted behavior; see GH issue #671.
"""
def __init__(self, venv: VecEnv):
def __init__(self, venv: VecEnv, skip: bool = False):
assert is_image_space(venv.observation_space) or isinstance(
venv.observation_space, spaces.dict.Dict
), "The observation space must be an image or dictionary observation space"
self.skip = skip
# Do nothing
if skip:
super(VecTransposeImage, self).__init__(venv)
return
if isinstance(venv.observation_space, spaces.dict.Dict):
self.image_space_keys = []
observation_space = deepcopy(venv.observation_space)
@ -70,6 +78,10 @@ class VecTransposeImage(VecEnvWrapper):
:param observations:
:return: Transposed observations
"""
# Do nothing
if self.skip:
return observations
if isinstance(observations, dict):
# Avoid modifying the original object in place
observations = deepcopy(observations)

View file

@ -1 +1 @@
1.3.1a6
1.3.1a7

View file

@ -10,7 +10,7 @@ from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
from stable_baselines3.common.envs import FakeImageEnv
from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
from stable_baselines3.common.utils import zip_strict
from stable_baselines3.common.vec_env import VecTransposeImage, is_vecenv_wrapped
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecTransposeImage, is_vecenv_wrapped
@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN])
@ -57,6 +57,33 @@ def test_cnn(tmp_path, model_class):
os.remove(str(tmp_path / SAVE_NAME))
@pytest.mark.parametrize("model_class", [A2C])
def test_vec_transpose_skip(tmp_path, model_class):
    """
    Check that ``VecTransposeImage(..., skip=True)`` leaves observations untouched.

    A channel-first fake image env is frame-stacked along the first axis.
    Without ``skip`` the wrapper changes the observation shape; with
    ``skip=True`` the shape must be identical to the unwrapped one, and a
    model with a ``CnnPolicy`` must still be able to train and predict on it.

    :param tmp_path: pytest temporary directory fixture (not used here)
    :param model_class: algorithm class to train on the wrapped env
    """
    # Fake channel-first image env; SAC and TD3 would need continuous actions
    use_discrete_actions = model_class not in {SAC, TD3}
    base_env = FakeImageEnv(
        screen_height=41,
        screen_width=40,
        n_channels=10,
        discrete=use_discrete_actions,
        channel_first=True,
    )
    vec_env = DummyVecEnv([lambda: base_env])

    # Stack 5 frames along the channel axis; the env is still channel first
    stacked_env = VecFrameStack(vec_env, 5, channels_order="first")
    reference_shape = stacked_env.reset().shape

    # Without `skip`, the wrapper transposes the observation, so the shape
    # differs (the heuristic thinks the env is channel last)
    transposed_shape = VecTransposeImage(stacked_env).reset().shape
    assert not np.allclose(reference_shape, transposed_shape)

    # With `skip=True`, the wrapper is a no-op and the shape is preserved
    env = VecTransposeImage(stacked_env, skip=True)
    assert np.allclose(reference_shape, env.reset().shape)

    # Smoke test: training and prediction must run on the skipped wrapper
    model = model_class(
        "CnnPolicy",
        env,
        n_steps=64,
        policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)),
        seed=1,
    ).learn(250)

    first_obs = env.reset()
    _action, _state = model.predict(first_obs, deterministic=True)
def patch_dqn_names_(model):
# Small hack to make the test work with DQN
if isinstance(model, DQN):