mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-28 22:56:53 +00:00
Fix synchronization bug with EvalCallback (#907)
This commit is contained in:
parent
c2518dc160
commit
0fadc94df3
4 changed files with 20 additions and 6 deletions
|
|
@ -4,7 +4,7 @@ Changelog
|
|||
==========
|
||||
|
||||
|
||||
Release 1.5.1a5 (WIP)
|
||||
Release 1.5.1a6 (WIP)
|
||||
---------------------------
|
||||
|
||||
Breaking Changes:
|
||||
|
|
@ -25,6 +25,7 @@ Bug Fixes:
|
|||
- Fixed final goal selection strategy that did not sample the final achieved goal (@qgallouedec)
|
||||
- Fixed a bug with special characters in the tensorboard log name (@quantitative-technologies)
|
||||
- Fixed a bug in ``DummyVecEnv``'s and ``SubprocVecEnv``'s seeding function. None value was unchecked (@ScheiklP)
|
||||
- Fixed a bug where ``EvalCallback`` would crash while synchronizing ``VecNormalize`` stats if observation normalization was disabled
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -66,7 +66,9 @@ def sync_envs_normalization(env: "GymEnv", eval_env: "GymEnv") -> None:
|
|||
env_tmp, eval_env_tmp = env, eval_env
|
||||
while isinstance(env_tmp, VecEnvWrapper):
|
||||
if isinstance(env_tmp, VecNormalize):
|
||||
eval_env_tmp.obs_rms = deepcopy(env_tmp.obs_rms)
|
||||
# Only synchronize if observation normalization exists
|
||||
if hasattr(env_tmp, "obs_rms"):
|
||||
eval_env_tmp.obs_rms = deepcopy(env_tmp.obs_rms)
|
||||
eval_env_tmp.ret_rms = deepcopy(env_tmp.ret_rms)
|
||||
env_tmp = env_tmp.venv
|
||||
eval_env_tmp = eval_env_tmp.venv
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.5.1a5
|
||||
1.5.1a6
|
||||
|
|
|
|||
|
|
@ -388,11 +388,11 @@ def test_offpolicy_normalization(model_class, online_sampling):
|
|||
|
||||
@pytest.mark.parametrize("make_env", [make_env, make_dict_env])
|
||||
def test_sync_vec_normalize(make_env):
|
||||
env = DummyVecEnv([make_env])
|
||||
original_env = DummyVecEnv([make_env])
|
||||
|
||||
assert unwrap_vec_normalize(env) is None
|
||||
assert unwrap_vec_normalize(original_env) is None
|
||||
|
||||
env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=100.0, clip_reward=100.0)
|
||||
env = VecNormalize(original_env, norm_obs=True, norm_reward=True, clip_obs=100.0, clip_reward=100.0)
|
||||
|
||||
assert isinstance(unwrap_vec_normalize(env), VecNormalize)
|
||||
|
||||
|
|
@ -433,6 +433,17 @@ def test_sync_vec_normalize(make_env):
|
|||
assert allclose(obs, eval_env.normalize_obs(original_obs))
|
||||
assert allclose(env.normalize_reward(dummy_rewards), eval_env.normalize_reward(dummy_rewards))
|
||||
|
||||
# Check synchronization when only reward is normalized
|
||||
env = VecNormalize(original_env, norm_obs=False, norm_reward=True, clip_reward=100.0)
|
||||
eval_env = DummyVecEnv([make_env])
|
||||
eval_env = VecNormalize(eval_env, training=False, norm_obs=False, norm_reward=False)
|
||||
env.reset()
|
||||
env.step([env.action_space.sample()])
|
||||
assert not np.allclose(env.ret_rms.mean, eval_env.ret_rms.mean)
|
||||
sync_envs_normalization(env, eval_env)
|
||||
assert np.allclose(env.ret_rms.mean, eval_env.ret_rms.mean)
|
||||
assert np.allclose(env.ret_rms.var, eval_env.ret_rms.var)
|
||||
|
||||
|
||||
def test_discrete_obs():
|
||||
with pytest.raises(ValueError, match=".*only supports.*"):
|
||||
|
|
|
|||
Loading…
Reference in a new issue