mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-22 22:10:16 +00:00
Add skip option for VecTransposeImage and bug fix in frame stack (#700)
* Update doc * Add comment * Add skip option to VecTransposeImage and fix bug in frame stack
This commit is contained in:
parent
d496cd4d95
commit
bb16645c4e
8 changed files with 56 additions and 15 deletions
|
|
@ -56,10 +56,10 @@ In the following example, we will train, save and load a DQN model on the Lunar
|
|||
LunarLander requires the python package ``box2d``.
|
||||
You can install it using ``apt install swig`` and then ``pip install box2d box2d-kengz``
|
||||
|
||||
.. note::
|
||||
.. warning::
|
||||
``load`` method re-creates the model from scratch and should be called on the Algorithm without instantiating it first,
|
||||
e.g. ``model = DQN.load("dqn_lunar", env=env)`` instead of ``model = DQN(env=env)`` followed by ``model.load("dqn_lunar")``. The latter **will not work** as ``load`` does not work by reference.
|
||||
If you want to load parameters without re-creating the model, e.g. to evaluate the same model
|
||||
e.g. ``model = DQN.load("dqn_lunar", env=env)`` instead of ``model = DQN(env=env)`` followed by ``model.load("dqn_lunar")``. The latter **will not work** as ``load`` is not an in-place operation.
|
||||
If you want to load parameters without re-creating the model, e.g. to evaluate the same model
|
||||
with multiple different sets of parameters, consider using ``set_parameters`` instead.
|
||||
|
||||
.. code-block:: python
|
||||
|
|
@ -163,7 +163,7 @@ Multiprocessing with off-policy algorithms
|
|||
------------------------------------------
|
||||
|
||||
.. warning::
|
||||
|
||||
|
||||
When using multiple environments with off-policy algorithms, you should update the ``gradient_steps``
|
||||
parameter too. Set it to ``gradient_steps=-1`` to perform as many gradient steps as transitions collected.
|
||||
There is usually a compromise between wall-clock time and sample efficiency,
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ Changelog
|
|||
==========
|
||||
|
||||
|
||||
Release 1.3.1a6 (WIP)
|
||||
Release 1.3.1a7 (WIP)
|
||||
---------------------------
|
||||
|
||||
Breaking Changes:
|
||||
|
|
@ -20,7 +20,7 @@ New Features:
|
|||
- Added ``norm_obs_keys`` param for ``VecNormalize`` wrapper to configure which observation keys to normalize (@kachayev)
|
||||
- Added experimental support to train off-policy algorithms with multiple envs (note: ``HerReplayBuffer`` currently not supported)
|
||||
- Handle timeout termination properly for on-policy algorithms (when using ``TimeLimit``)
|
||||
|
||||
- Added ``skip`` option to ``VecTransposeImage`` to skip transforming the channel order when the heuristic is wrong
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
|
|
@ -29,6 +29,7 @@ Bug Fixes:
|
|||
- Fixed evaluation script for recurrent policies (experimental feature in SB3 contrib)
|
||||
- Fixed a bug where the observation would be incorrectly detected as non-vectorized instead of throwing an error
|
||||
- The env checker now properly checks and warns about potential issues for continuous action spaces when the boundaries are too small or when the dtype is not float32
|
||||
- Fixed a bug in ``VecFrameStack`` with channel first image envs, where the terminal observation would be wrongly created.
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
@ -50,8 +51,8 @@ Documentation:
|
|||
- Add documentation on exporting to TFLite/Coral
|
||||
- Added JMLR paper and updated citation
|
||||
- Added link to RL Tips and Tricks video
|
||||
- Update ``BaseAlgorithm.load`` docstring
|
||||
- Add a Note on ``load`` behavior in the examples
|
||||
- Updated ``BaseAlgorithm.load`` docstring (@Demetrio92)
|
||||
- Added a note on ``load`` behavior in the examples (@Demetrio92)
|
||||
|
||||
|
||||
Release 1.3.0 (2021-10-23)
|
||||
|
|
@ -857,4 +858,4 @@ And all the contributors:
|
|||
@ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr @Atlis @liusida @09tangriro @amy12xx @juancroldan
|
||||
@benblack769 @bstee615 @c-rizz @skandermoalla @MihaiAnca13 @davidblom603 @ayeright @cyprienc
|
||||
@wkirgsn @AechPro @CUN-bjy @batu @IljaAvadiev @timokau @kachayev @cleversonahum
|
||||
@eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch
|
||||
@eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92
|
||||
|
|
|
|||
|
|
@ -658,8 +658,8 @@ class BaseAlgorithm(ABC):
|
|||
) -> "BaseAlgorithm":
|
||||
"""
|
||||
Load the model from a zip-file.
|
||||
Note: `load` re-creates the model from scratch, it does not update it in-place!
|
||||
For an in-place load use `set_parameters` instead.
|
||||
Warning: ``load`` re-creates the model from scratch, it does not update it in-place!
|
||||
For an in-place load use ``set_parameters`` instead.
|
||||
|
||||
:param path: path to the file (or a file-like) where to
|
||||
load the agent from
|
||||
|
|
|
|||
|
|
@ -178,6 +178,7 @@ class ResultsWriter:
|
|||
filename = os.path.join(filename, Monitor.EXT)
|
||||
else:
|
||||
filename = filename + "." + Monitor.EXT
|
||||
# Prevent newline issue on Windows, see GH issue #692
|
||||
self.file_handler = open(filename, "wt", newline="\n")
|
||||
self.file_handler.write("#%s\n" % json.dumps(header))
|
||||
self.logger = csv.DictWriter(self.file_handler, fieldnames=("r", "l", "t") + extra_keys)
|
||||
|
|
|
|||
|
|
@ -126,7 +126,7 @@ class StackedObservations(object):
|
|||
if self.channels_first:
|
||||
new_terminal = np.concatenate(
|
||||
(self.stackedobs[i, :-stack_ax_size, ...], old_terminal),
|
||||
axis=self.stack_dimension,
|
||||
axis=0, # self.stack_dimension - 1, as there is not batch dim
|
||||
)
|
||||
else:
|
||||
new_terminal = np.concatenate(
|
||||
|
|
|
|||
|
|
@ -14,13 +14,21 @@ class VecTransposeImage(VecEnvWrapper):
|
|||
It is required for PyTorch convolution layers.
|
||||
|
||||
:param venv:
|
||||
:param skip: Skip this wrapper if needed as we rely on heuristic to apply it or not,
|
||||
which may result in unwanted behavior, see GH issue #671.
|
||||
"""
|
||||
|
||||
def __init__(self, venv: VecEnv):
|
||||
def __init__(self, venv: VecEnv, skip: bool = False):
|
||||
assert is_image_space(venv.observation_space) or isinstance(
|
||||
venv.observation_space, spaces.dict.Dict
|
||||
), "The observation space must be an image or dictionary observation space"
|
||||
|
||||
self.skip = skip
|
||||
# Do nothing
|
||||
if skip:
|
||||
super(VecTransposeImage, self).__init__(venv)
|
||||
return
|
||||
|
||||
if isinstance(venv.observation_space, spaces.dict.Dict):
|
||||
self.image_space_keys = []
|
||||
observation_space = deepcopy(venv.observation_space)
|
||||
|
|
@ -70,6 +78,10 @@ class VecTransposeImage(VecEnvWrapper):
|
|||
:param observations:
|
||||
:return: Transposed observations
|
||||
"""
|
||||
# Do nothing
|
||||
if self.skip:
|
||||
return observations
|
||||
|
||||
if isinstance(observations, dict):
|
||||
# Avoid modifying the original object in place
|
||||
observations = deepcopy(observations)
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.3.1a6
|
||||
1.3.1a7
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
|
|||
from stable_baselines3.common.envs import FakeImageEnv
|
||||
from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
|
||||
from stable_baselines3.common.utils import zip_strict
|
||||
from stable_baselines3.common.vec_env import VecTransposeImage, is_vecenv_wrapped
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecTransposeImage, is_vecenv_wrapped
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN])
|
||||
|
|
@ -57,6 +57,33 @@ def test_cnn(tmp_path, model_class):
|
|||
os.remove(str(tmp_path / SAVE_NAME))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [A2C])
|
||||
def test_vec_transpose_skip(tmp_path, model_class):
|
||||
# Fake grayscale with frameskip
|
||||
env = FakeImageEnv(
|
||||
screen_height=41, screen_width=40, n_channels=10, discrete=model_class not in {SAC, TD3}, channel_first=True
|
||||
)
|
||||
env = DummyVecEnv([lambda: env])
|
||||
# Stack 5 frames so the observation is now (50, 40, 40) but the env is still channel first
|
||||
env = VecFrameStack(env, 5, channels_order="first")
|
||||
obs_shape_before = env.reset().shape
|
||||
# The observation space should be different as the heuristic thinks it is channel last
|
||||
assert not np.allclose(obs_shape_before, VecTransposeImage(env).reset().shape)
|
||||
env = VecTransposeImage(env, skip=True)
|
||||
# The observation space should be the same as we skip the VecTransposeImage
|
||||
assert np.allclose(obs_shape_before, env.reset().shape)
|
||||
|
||||
kwargs = dict(
|
||||
n_steps=64,
|
||||
policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)),
|
||||
seed=1,
|
||||
)
|
||||
model = model_class("CnnPolicy", env, **kwargs).learn(250)
|
||||
|
||||
obs = env.reset()
|
||||
action, _ = model.predict(obs, deterministic=True)
|
||||
|
||||
|
||||
def patch_dqn_names_(model):
|
||||
# Small hack to make the test work with DQN
|
||||
if isinstance(model, DQN):
|
||||
|
|
|
|||
Loading…
Reference in a new issue