Fixed saving of A2C and PPO policy when using gSDE (#401)

This commit is contained in:
Antonin RAFFIN 2021-04-19 12:23:02 +02:00 committed by GitHub
parent 5d47296b8d
commit c69f7cd5e6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 15 additions and 5 deletions

View file

@ -4,7 +4,7 @@ Changelog
==========
Release 1.1.0a4 (WIP)
Release 1.1.0a5 (WIP)
---------------------------
Breaking Changes:
@ -23,6 +23,7 @@ Bug Fixes:
^^^^^^^^^^
- Fixed potential issue when calling off-policy algorithms with default arguments multiple times (the size of the replay buffer would be the same)
- Fixed loading of ``ent_coef`` for ``SAC`` and ``TQC``, it was not optimized anymore (thanks @Atlis)
- Fixed saving of ``A2C`` and ``PPO`` policy when using gSDE (thanks @liusida)
Deprecations:
^^^^^^^^^^^^^
@ -653,4 +654,4 @@ And all the contributors:
@tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag @xicocaio
@diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 @SwamyDev @wmmc88 @megan-klaiber @thisray
@tfederico @hn2 @LucasAlegre @AptX395 @zampanteymedio @decodyng @ardabbour @lorenz-h @mschweizer @lorepieri8 @vwxyzjn
@ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr @Atlis
@ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr @Atlis @liusida

View file

@ -435,7 +435,7 @@ class ActorCriticPolicy(BasePolicy):
log_std_init=self.log_std_init,
squash_output=default_none_kwargs["squash_output"],
full_std=default_none_kwargs["full_std"],
sde_net_arch=default_none_kwargs["sde_net_arch"],
sde_net_arch=self.sde_net_arch,
use_expln=default_none_kwargs["use_expln"],
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
ortho_init=self.ortho_init,

View file

@ -1 +1 @@
1.1.0a4
1.1.0a5

View file

@ -341,7 +341,8 @@ def test_warn_buffer(recwarn, model_class, optimize_memory_usage):
@pytest.mark.parametrize("model_class", MODEL_LIST)
@pytest.mark.parametrize("policy_str", ["MlpPolicy", "CnnPolicy"])
def test_save_load_policy(tmp_path, model_class, policy_str):
@pytest.mark.parametrize("use_sde", [False, True])
def test_save_load_policy(tmp_path, model_class, policy_str, use_sde):
"""
Test saving and loading policy only.
@ -349,6 +350,11 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
:param policy_str: (str) Name of the policy.
"""
kwargs = dict(policy_kwargs=dict(net_arch=[16]))
# gSDE is only applicable for A2C, PPO and SAC
if use_sde and model_class not in [A2C, PPO, SAC]:
pytest.skip()
if policy_str == "MlpPolicy":
env = select_env(model_class)
else:
@ -360,6 +366,9 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
)
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == DQN)
if use_sde:
kwargs["use_sde"] = True
env = DummyVecEnv([lambda: env])
# create model