mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-25 02:50:59 +00:00
Fixed loading of `ent_coef` for `SAC` and `TQC`: it was not optimized anymore after loading (#392)
* Fix ent coef loading bug * Add test * Add comment * Reuse save path
This commit is contained in:
parent
ddbe0e93f9
commit
c4304029a2
4 changed files with 25 additions and 4 deletions
|
|
@ -4,7 +4,7 @@ Changelog
|
|||
==========
|
||||
|
||||
|
||||
Release 1.1.0a2 (WIP)
|
||||
Release 1.1.0a3 (WIP)
|
||||
---------------------------
|
||||
|
||||
Breaking Changes:
|
||||
|
|
@ -21,6 +21,7 @@ New Features:
|
|||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- Fixed potential issue when calling off-policy algorithms with default arguments multiple times (the size of the replay buffer would be the same)
|
||||
- Fixed loading of ``ent_coef`` for ``SAC`` and ``TQC``: it was no longer optimized after loading a saved model (thanks @Atlis)
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
@ -649,4 +650,4 @@ And all the contributors:
|
|||
@tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag @xicocaio
|
||||
@diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 @SwamyDev @wmmc88 @megan-klaiber @thisray
|
||||
@tfederico @hn2 @LucasAlegre @AptX395 @zampanteymedio @decodyng @ardabbour @lorenz-h @mschweizer @lorepieri8 @vwxyzjn
|
||||
@ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr
|
||||
@ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr @Atlis
|
||||
|
|
|
|||
|
|
@ -650,7 +650,9 @@ class BaseAlgorithm(ABC):
|
|||
# put other pytorch variables back in place
|
||||
if pytorch_variables is not None:
|
||||
for name in pytorch_variables:
|
||||
recursive_setattr(model, name, pytorch_variables[name])
|
||||
# Set the data attribute directly to avoid issues when using optimizers
|
||||
# See https://github.com/DLR-RM/stable-baselines3/issues/391
|
||||
recursive_setattr(model, name + ".data", pytorch_variables[name].data)
|
||||
|
||||
# Sample gSDE exploration matrix, so it uses the right device
|
||||
# see issue #44
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.1.0a2
|
||||
1.1.0a3
|
||||
|
|
|
|||
|
|
@ -233,6 +233,24 @@ def test_exclude_include_saved_params(tmp_path, model_class):
|
|||
os.remove(tmp_path / "test_save.zip")
|
||||
|
||||
|
||||
def test_save_load_pytorch_var(tmp_path):
|
||||
model = SAC("MlpPolicy", "Pendulum-v0", seed=3, policy_kwargs=dict(net_arch=[64], n_critics=1))
|
||||
model.learn(200)
|
||||
save_path = str(tmp_path / "sac_pendulum")
|
||||
model.save(save_path)
|
||||
env = model.get_env()
|
||||
ent_coef_before = model.log_ent_coef
|
||||
|
||||
del model
|
||||
|
||||
model = SAC.load(save_path, env=env)
|
||||
assert th.allclose(ent_coef_before, model.log_ent_coef)
|
||||
model.learn(200)
|
||||
ent_coef_after = model.log_ent_coef
|
||||
# Check that the entropy coefficient is still optimized
|
||||
assert not th.allclose(ent_coef_before, ent_coef_after)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [A2C, TD3])
|
||||
def test_save_load_env_cnn(tmp_path, model_class):
|
||||
"""
|
||||
|
|
|
|||
Loading…
Reference in a new issue