mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-07-02 03:55:39 +00:00
Fixed saving of A2C and PPO policy when using gSDE (#401)
This commit is contained in:
parent
5d47296b8d
commit
c69f7cd5e6
4 changed files with 15 additions and 5 deletions
|
|
@ -4,7 +4,7 @@ Changelog
|
|||
==========
|
||||
|
||||
|
||||
Release 1.1.0a4 (WIP)
|
||||
Release 1.1.0a5 (WIP)
|
||||
---------------------------
|
||||
|
||||
Breaking Changes:
|
||||
|
|
@ -23,6 +23,7 @@ Bug Fixes:
|
|||
^^^^^^^^^^
|
||||
- Fixed potential issue when calling off-policy algorithms with default arguments multiple times (the size of the replay buffer would be the same)
|
||||
- Fixed loading of ``ent_coef`` for ``SAC`` and ``TQC``, it was not optimized anymore (thanks @Atlis)
|
||||
- Fixed saving of ``A2C`` and ``PPO`` policy when using gSDE (thanks @liusida)
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
@ -653,4 +654,4 @@ And all the contributors:
|
|||
@tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag @xicocaio
|
||||
@diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 @SwamyDev @wmmc88 @megan-klaiber @thisray
|
||||
@tfederico @hn2 @LucasAlegre @AptX395 @zampanteymedio @decodyng @ardabbour @lorenz-h @mschweizer @lorepieri8 @vwxyzjn
|
||||
@ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr @Atlis
|
||||
@ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr @Atlis @liusida
|
||||
|
|
|
|||
|
|
@ -435,7 +435,7 @@ class ActorCriticPolicy(BasePolicy):
|
|||
log_std_init=self.log_std_init,
|
||||
squash_output=default_none_kwargs["squash_output"],
|
||||
full_std=default_none_kwargs["full_std"],
|
||||
sde_net_arch=default_none_kwargs["sde_net_arch"],
|
||||
sde_net_arch=self.sde_net_arch,
|
||||
use_expln=default_none_kwargs["use_expln"],
|
||||
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
|
||||
ortho_init=self.ortho_init,
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.1.0a4
|
||||
1.1.0a5
|
||||
|
|
|
|||
|
|
@ -341,7 +341,8 @@ def test_warn_buffer(recwarn, model_class, optimize_memory_usage):
|
|||
|
||||
@pytest.mark.parametrize("model_class", MODEL_LIST)
|
||||
@pytest.mark.parametrize("policy_str", ["MlpPolicy", "CnnPolicy"])
|
||||
def test_save_load_policy(tmp_path, model_class, policy_str):
|
||||
@pytest.mark.parametrize("use_sde", [False, True])
|
||||
def test_save_load_policy(tmp_path, model_class, policy_str, use_sde):
|
||||
"""
|
||||
Test saving and loading policy only.
|
||||
|
||||
|
|
@ -349,6 +350,11 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
|
|||
:param policy_str: (str) Name of the policy.
|
||||
"""
|
||||
kwargs = dict(policy_kwargs=dict(net_arch=[16]))
|
||||
|
||||
# gSDE is only applicable for A2C, PPO and SAC
|
||||
if use_sde and model_class not in [A2C, PPO, SAC]:
|
||||
pytest.skip()
|
||||
|
||||
if policy_str == "MlpPolicy":
|
||||
env = select_env(model_class)
|
||||
else:
|
||||
|
|
@ -360,6 +366,9 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
|
|||
)
|
||||
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == DQN)
|
||||
|
||||
if use_sde:
|
||||
kwargs["use_sde"] = True
|
||||
|
||||
env = DummyVecEnv([lambda: env])
|
||||
|
||||
# create model
|
||||
|
|
|
|||
Loading…
Reference in a new issue