From c99d65c664bfc04acfaa59b2e7a4a1d2bbd9a438 Mon Sep 17 00:00:00 2001
From: PatrickHelm <91552435+PatrickHelm@users.noreply.github.com>
Date: Wed, 30 Aug 2023 12:37:14 +0200
Subject: [PATCH] Fix `VectorizedActionNoise` in `OffPolicyAlgorithm` (#1657)

* moves VectorizedActionNoise into _setup_learn()

* update changelog

---------

Co-authored-by: Antonin RAFFIN
Co-authored-by: Antonin Raffin
---
 docs/misc/changelog.rst                          |  3 ++-
 stable_baselines3/common/off_policy_algorithm.py | 12 ++++++++----
 2 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index 53b222c..84ec850 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -20,7 +20,8 @@ New Features:
 
 Bug Fixes:
 ^^^^^^^^^^
-- Prevents OOB error on Windows if no seed is passed (@PatrickHelm)
+- Moves VectorizedActionNoise into ``_setup_learn()`` in OffPolicyAlgorithm (@PatrickHelm)
+- Prevents out of bound error on Windows if no seed is passed (@PatrickHelm)
 
 Deprecations:
 ^^^^^^^^^^^^^
diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py
index e3e6c59..bac3953 100644
--- a/stable_baselines3/common/off_policy_algorithm.py
+++ b/stable_baselines3/common/off_policy_algorithm.py
@@ -281,6 +281,14 @@ class OffPolicyAlgorithm(BaseAlgorithm):
             pos = (replay_buffer.pos - 1) % replay_buffer.buffer_size
             replay_buffer.dones[pos] = True
 
+        # Vectorize action noise if needed
+        if (
+            self.action_noise is not None
+            and self.env.num_envs > 1
+            and not isinstance(self.action_noise, VectorizedActionNoise)
+        ):
+            self.action_noise = VectorizedActionNoise(self.action_noise, self.env.num_envs)
+
         return super()._setup_learn(
             total_timesteps,
             callback,
@@ -523,10 +531,6 @@ class OffPolicyAlgorithm(BaseAlgorithm):
         if env.num_envs > 1:
             assert train_freq.unit == TrainFrequencyUnit.STEP, "You must use only one env when doing episodic training."
 
-        # Vectorize action noise if needed
-        if action_noise is not None and env.num_envs > 1 and not isinstance(action_noise, VectorizedActionNoise):
-            action_noise = VectorizedActionNoise(action_noise, env.num_envs)
-
         if self.use_sde:
             self.actor.reset_noise(env.num_envs)
 