From c4304029a2efa93c321104ebcae403ca679eb1a1 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Thu, 15 Apr 2021 14:50:43 +0200 Subject: [PATCH] Fixed loading of ``ent_coef`` for ``SAC`` and ``TQC``, it was not optimized anymore (#392) * Fix ent coef loading bug * Add test * Add comment * Reuse save path --- docs/misc/changelog.rst | 5 +++-- stable_baselines3/common/base_class.py | 4 +++- stable_baselines3/version.txt | 2 +- tests/test_save_load.py | 18 ++++++++++++++++++ 4 files changed, 25 insertions(+), 4 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index a7b780f..b0228a7 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== -Release 1.1.0a2 (WIP) +Release 1.1.0a3 (WIP) --------------------------- Breaking Changes: @@ -21,6 +21,7 @@ New Features: Bug Fixes: ^^^^^^^^^^ - Fixed potential issue when calling off-policy algorithms with default arguments multiple times (the size of the replay buffer would be the same) +- Fixed loading of ``ent_coef`` for ``SAC`` and ``TQC``, it was not optimized anymore (thanks @Atlis) Deprecations: ^^^^^^^^^^^^^ @@ -649,4 +650,4 @@ And all the contributors: @tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag @xicocaio @diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 @SwamyDev @wmmc88 @megan-klaiber @thisray @tfederico @hn2 @LucasAlegre @AptX395 @zampanteymedio @decodyng @ardabbour @lorenz-h @mschweizer @lorepieri8 @vwxyzjn -@ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr +@ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr @Atlis diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index 06bef17..3164df1 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -650,7 +650,9 @@ class BaseAlgorithm(ABC): # put other pytorch variables back in place if pytorch_variables is not None: for name in pytorch_variables: - recursive_setattr(model, name, pytorch_variables[name]) + # Set the data attribute directly to avoid issue when using optimizers + # See https://github.com/DLR-RM/stable-baselines3/issues/391 + recursive_setattr(model, name + ".data", pytorch_variables[name].data) # Sample gSDE exploration matrix, so it uses the right device # see issue #44 diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index c733209..55edcad 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.1.0a2 +1.1.0a3 diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 1815877..9b629ef 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -233,6 +233,24 @@ def test_exclude_include_saved_params(tmp_path, model_class): os.remove(tmp_path / "test_save.zip") +def test_save_load_pytorch_var(tmp_path): + model = SAC("MlpPolicy", "Pendulum-v0", seed=3, policy_kwargs=dict(net_arch=[64], n_critics=1)) + model.learn(200) + save_path = str(tmp_path / "sac_pendulum") + model.save(save_path) + env = model.get_env() + ent_coef_before = model.log_ent_coef + + del model + + model = SAC.load(save_path, env=env) + assert th.allclose(ent_coef_before, model.log_ent_coef) + model.learn(200) + ent_coef_after = model.log_ent_coef + # Check that the entropy coefficient is still optimized + assert not th.allclose(ent_coef_before, ent_coef_after) + + @pytest.mark.parametrize("model_class", [A2C, TD3]) def test_save_load_env_cnn(tmp_path, model_class): """