From 2bb4500948dccba3292135b1e295532fbc32f668 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Tue, 2 Nov 2021 12:52:26 +0100 Subject: [PATCH] Fix `set_env` when using `VecNormalize` (#638) * Fix `set_env` when using `VecNormalize` * Update version --- docs/misc/changelog.rst | 9 +++++++-- stable_baselines3/common/base_class.py | 4 ++++ stable_baselines3/version.txt | 2 +- tests/test_vec_normalize.py | 7 +++++++ 4 files changed, 19 insertions(+), 3 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 11e0239..b6bfd2a 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== -Release 1.3.1a0 (WIP) +Release 1.3.1a1 (WIP) --------------------------- Breaking Changes: @@ -16,8 +16,10 @@ New Features: Bug Fixes: ^^^^^^^^^^ +- Fixed a bug where ``set_env()`` with ``VecNormalize`` would result in an error with off-policy algorithms (thanks @cleversonahum) - FPS calculation is now performed based on number of steps performed during last ``learn`` call, even when ``reset_num_timesteps`` is set to ``False`` (@kachayev) + Deprecations: ^^^^^^^^^^^^^ @@ -830,4 +832,7 @@ And all the contributors: @tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag @xicocaio @diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 @SwamyDev @wmmc88 @megan-klaiber @thisray @tfederico @hn2 @LucasAlegre @AptX395 @zampanteymedio @JadenTravnik @decodyng @ardabbour @lorenz-h @mschweizer @lorepieri8 @vwxyzjn -@ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr @Atlis @liusida @09tangriro @amy12xx @juancroldan @benblack769 @bstee615 @c-rizz @skandermoalla @MihaiAnca13 @davidblom603 @ayeright @cyprienc @wkirgsn @AechPro @CUN-bjy @batu @IljaAvadiev @timokau @kachayev @eleurent @ac-93 +@ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr @Atlis @liusida @09tangriro @amy12xx @juancroldan +@benblack769 @bstee615 @c-rizz @skandermoalla @MihaiAnca13 @davidblom603 @ayeright @cyprienc +@wkirgsn @AechPro @CUN-bjy @batu @IljaAvadiev @timokau @kachayev @cleversonahum +@eleurent @ac-93 diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index 21c2748..4bad2c2 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -499,6 +499,10 @@ class BaseAlgorithm(ABC): env = self._wrap_env(env, self.verbose) # Check that the observation spaces match check_for_correct_spaces(env, self.observation_space, self.action_space) + # Update VecNormalize object + # otherwise the wrong env may be used, see https://github.com/DLR-RM/stable-baselines3/issues/637 + self._vec_normalize_env = unwrap_vec_normalize(env) + # Discard `_last_obs`, this will force the env to reset before training # See issue https://github.com/DLR-RM/stable-baselines3/issues/597 if force_reset: diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index e18a0e5..690d925 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.3.1a0 +1.3.1a1 diff --git a/tests/test_vec_normalize.py b/tests/test_vec_normalize.py index b002928..a362365 100644 --- a/tests/test_vec_normalize.py +++ b/tests/test_vec_normalize.py @@ -341,6 +341,13 @@ def test_offpolicy_normalization(model_class, online_sampling): else: model = model_class("MlpPolicy", env, verbose=1, learning_starts=100, policy_kwargs=dict(net_arch=[64])) + # Check that VecNormalize object is correctly updated + assert model.get_vec_normalize_env() is env + model.set_env(eval_env) + assert model.get_vec_normalize_env() is eval_env + model.learn(total_timesteps=10) + model.set_env(env) + model.learn(total_timesteps=150, eval_env=eval_env, eval_freq=75) # Check getter assert isinstance(model.get_vec_normalize_env(), VecNormalize)