diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index cb3a17f..d8b6d6d 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== -Release 1.5.1a5 (WIP) +Release 1.5.1a6 (WIP) --------------------------- Breaking Changes: @@ -25,6 +25,7 @@ Bug Fixes: - Fixed final goal selection strategy that did not sample the final achieved goal (@qgallouedec) - Fixed a bug with special characters in the tensorboard log name (@quantitative-technologies) - Fixed a bug in ``DummyVecEnv``'s and ``SubprocVecEnv``'s seeding function. None value was unchecked (@ScheiklP) +- Fixed a bug where ``EvalCallback`` would crash when trying to synchronize ``VecNormalize`` stats when observation normalization was disabled Deprecations: ^^^^^^^^^^^^^ diff --git a/stable_baselines3/common/vec_env/__init__.py b/stable_baselines3/common/vec_env/__init__.py index 37ebc36..3880fbd 100644 --- a/stable_baselines3/common/vec_env/__init__.py +++ b/stable_baselines3/common/vec_env/__init__.py @@ -66,7 +66,9 @@ def sync_envs_normalization(env: "GymEnv", eval_env: "GymEnv") -> None: env_tmp, eval_env_tmp = env, eval_env while isinstance(env_tmp, VecEnvWrapper): if isinstance(env_tmp, VecNormalize): - eval_env_tmp.obs_rms = deepcopy(env_tmp.obs_rms) + # Only synchronize if observation normalization exists + if hasattr(env_tmp, "obs_rms"): + eval_env_tmp.obs_rms = deepcopy(env_tmp.obs_rms) eval_env_tmp.ret_rms = deepcopy(env_tmp.ret_rms) env_tmp = env_tmp.venv eval_env_tmp = eval_env_tmp.venv diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index bccb8c6..1e5deca 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.5.1a5 +1.5.1a6 diff --git a/tests/test_vec_normalize.py b/tests/test_vec_normalize.py index 86a0d84..a363e40 100644 --- a/tests/test_vec_normalize.py +++ b/tests/test_vec_normalize.py @@ -388,11 +388,11 @@ def test_offpolicy_normalization(model_class, online_sampling): @pytest.mark.parametrize("make_env", [make_env, make_dict_env]) def test_sync_vec_normalize(make_env): - env = DummyVecEnv([make_env]) + original_env = DummyVecEnv([make_env]) - assert unwrap_vec_normalize(env) is None + assert unwrap_vec_normalize(original_env) is None - env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=100.0, clip_reward=100.0) + env = VecNormalize(original_env, norm_obs=True, norm_reward=True, clip_obs=100.0, clip_reward=100.0) assert isinstance(unwrap_vec_normalize(env), VecNormalize) @@ -433,6 +433,17 @@ def test_sync_vec_normalize(make_env): assert allclose(obs, eval_env.normalize_obs(original_obs)) assert allclose(env.normalize_reward(dummy_rewards), eval_env.normalize_reward(dummy_rewards)) + # Check synchronization when only reward is normalized + env = VecNormalize(original_env, norm_obs=False, norm_reward=True, clip_reward=100.0) + eval_env = DummyVecEnv([make_env]) + eval_env = VecNormalize(eval_env, training=False, norm_obs=False, norm_reward=False) + env.reset() + env.step([env.action_space.sample()]) + assert not np.allclose(env.ret_rms.mean, eval_env.ret_rms.mean) + sync_envs_normalization(env, eval_env) + assert np.allclose(env.ret_rms.mean, eval_env.ret_rms.mean) + assert np.allclose(env.ret_rms.var, eval_env.ret_rms.var) + def test_discrete_obs(): with pytest.raises(ValueError, match=".*only supports.*"):