Fix synchronization bug with EvalCallback (#907)

This commit is contained in:
Antonin RAFFIN 2022-05-08 20:54:34 +02:00 committed by GitHub
parent c2518dc160
commit 0fadc94df3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 20 additions and 6 deletions

View file

@ -4,7 +4,7 @@ Changelog
==========
Release 1.5.1a5 (WIP)
Release 1.5.1a6 (WIP)
---------------------------
Breaking Changes:
@ -25,6 +25,7 @@ Bug Fixes:
- Fixed final goal selection strategy that did not sample the final achieved goal (@qgallouedec)
- Fixed a bug with special characters in the tensorboard log name (@quantitative-technologies)
- Fixed a bug in the seeding function of ``DummyVecEnv`` and ``SubprocVecEnv``: a ``None`` value was not checked (@ScheiklP)
- Fixed a bug where ``EvalCallback`` would crash while synchronizing ``VecNormalize`` statistics if observation normalization was disabled
Deprecations:
^^^^^^^^^^^^^

View file

@ -66,7 +66,9 @@ def sync_envs_normalization(env: "GymEnv", eval_env: "GymEnv") -> None:
env_tmp, eval_env_tmp = env, eval_env
while isinstance(env_tmp, VecEnvWrapper):
if isinstance(env_tmp, VecNormalize):
eval_env_tmp.obs_rms = deepcopy(env_tmp.obs_rms)
# Only synchronize if observation normalization exists
if hasattr(env_tmp, "obs_rms"):
eval_env_tmp.obs_rms = deepcopy(env_tmp.obs_rms)
eval_env_tmp.ret_rms = deepcopy(env_tmp.ret_rms)
env_tmp = env_tmp.venv
eval_env_tmp = eval_env_tmp.venv

View file

@ -1 +1 @@
1.5.1a5
1.5.1a6

View file

@ -388,11 +388,11 @@ def test_offpolicy_normalization(model_class, online_sampling):
@pytest.mark.parametrize("make_env", [make_env, make_dict_env])
def test_sync_vec_normalize(make_env):
env = DummyVecEnv([make_env])
original_env = DummyVecEnv([make_env])
assert unwrap_vec_normalize(env) is None
assert unwrap_vec_normalize(original_env) is None
env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=100.0, clip_reward=100.0)
env = VecNormalize(original_env, norm_obs=True, norm_reward=True, clip_obs=100.0, clip_reward=100.0)
assert isinstance(unwrap_vec_normalize(env), VecNormalize)
@ -433,6 +433,17 @@ def test_sync_vec_normalize(make_env):
assert allclose(obs, eval_env.normalize_obs(original_obs))
assert allclose(env.normalize_reward(dummy_rewards), eval_env.normalize_reward(dummy_rewards))
# Check synchronization when only reward is normalized
env = VecNormalize(original_env, norm_obs=False, norm_reward=True, clip_reward=100.0)
eval_env = DummyVecEnv([make_env])
eval_env = VecNormalize(eval_env, training=False, norm_obs=False, norm_reward=False)
env.reset()
env.step([env.action_space.sample()])
assert not np.allclose(env.ret_rms.mean, eval_env.ret_rms.mean)
sync_envs_normalization(env, eval_env)
assert np.allclose(env.ret_rms.mean, eval_env.ret_rms.mean)
assert np.allclose(env.ret_rms.var, eval_env.ret_rms.var)
def test_discrete_obs():
with pytest.raises(ValueError, match=".*only supports.*"):