Fixed loading of `ent_coef for SAC and TQC`, it was not optimized anymore (#392)

* Fix ent coef loading bug

* Add test

* Add comment

* Reuse save path
This commit is contained in:
Antonin RAFFIN 2021-04-15 14:50:43 +02:00 committed by GitHub
parent ddbe0e93f9
commit c4304029a2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 25 additions and 4 deletions

View file

@ -4,7 +4,7 @@ Changelog
==========
Release 1.1.0a2 (WIP)
Release 1.1.0a3 (WIP)
---------------------------
Breaking Changes:
@ -21,6 +21,7 @@ New Features:
Bug Fixes:
^^^^^^^^^^
- Fixed potential issue when calling off-policy algorithms with default arguments multiple times (the size of the replay buffer would be the same)
- Fixed loading of ``ent_coef`` for ``SAC`` and ``TQC``, it was not optimized anymore (thanks @Atlis)
Deprecations:
^^^^^^^^^^^^^
@ -649,4 +650,4 @@ And all the contributors:
@tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag @xicocaio
@diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 @SwamyDev @wmmc88 @megan-klaiber @thisray
@tfederico @hn2 @LucasAlegre @AptX395 @zampanteymedio @decodyng @ardabbour @lorenz-h @mschweizer @lorepieri8 @vwxyzjn
@ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr
@ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr @Atlis

View file

@ -650,7 +650,9 @@ class BaseAlgorithm(ABC):
# put other pytorch variables back in place
if pytorch_variables is not None:
for name in pytorch_variables:
recursive_setattr(model, name, pytorch_variables[name])
# Set the data attribute directly to avoid issue when using optimizers
# See https://github.com/DLR-RM/stable-baselines3/issues/391
recursive_setattr(model, name + ".data", pytorch_variables[name].data)
# Sample gSDE exploration matrix, so it uses the right device
# see issue #44

View file

@ -1 +1 @@
1.1.0a2
1.1.0a3

View file

@ -233,6 +233,24 @@ def test_exclude_include_saved_params(tmp_path, model_class):
os.remove(tmp_path / "test_save.zip")
def test_save_load_pytorch_var(tmp_path):
model = SAC("MlpPolicy", "Pendulum-v0", seed=3, policy_kwargs=dict(net_arch=[64], n_critics=1))
model.learn(200)
save_path = str(tmp_path / "sac_pendulum")
model.save(save_path)
env = model.get_env()
ent_coef_before = model.log_ent_coef
del model
model = SAC.load(save_path, env=env)
assert th.allclose(ent_coef_before, model.log_ent_coef)
model.learn(200)
ent_coef_after = model.log_ent_coef
# Check that the entropy coefficient is still optimized
assert not th.allclose(ent_coef_before, ent_coef_after)
@pytest.mark.parametrize("model_class", [A2C, TD3])
def test_save_load_env_cnn(tmp_path, model_class):
"""