diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 9a875bf..33c772f 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -16,6 +16,7 @@ New Features: ^^^^^^^^^^^^^ - Makes the length of keys and values in ``HumanOutputFormat`` configurable, depending on desired maximum width of output. +- Allow PPO to turn off advantage normalization (see `PR #763 <https://github.com/DLR-RM/stable-baselines3/pull/763>`_) @vwxyzjn SB3-Contrib ^^^^^^^^^^^ diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py index 9e16e04..088bab3 100644 --- a/stable_baselines3/ppo/ppo.py +++ b/stable_baselines3/ppo/ppo.py @@ -42,6 +42,7 @@ class PPO(OnPolicyAlgorithm): This is a parameter specific to the OpenAI implementation. If None is passed (default), no clipping will be done on the value function. IMPORTANT: this clipping depends on the reward scaling. + :param normalize_advantage: Whether to normalize the advantage or not :param ent_coef: Entropy coefficient for the loss calculation :param vf_coef: Value function coefficient for the loss calculation :param max_grad_norm: The maximum value for the gradient clipping @@ -76,6 +77,7 @@ class PPO(OnPolicyAlgorithm): gae_lambda: float = 0.95, clip_range: Union[float, Schedule] = 0.2, clip_range_vf: Union[None, float, Schedule] = None, + normalize_advantage: bool = True, ent_coef: float = 0.0, vf_coef: float = 0.5, max_grad_norm: float = 0.5, @@ -120,9 +122,10 @@ class PPO(OnPolicyAlgorithm): # Sanity check, otherwise it will lead to noisy gradient and NaN # because of the advantage normalization - assert ( - batch_size > 1 - ), "`batch_size` must be greater than 1. See https://github.com/DLR-RM/stable-baselines3/issues/440" + if normalize_advantage: + assert ( + batch_size > 1 + ), "`batch_size` must be greater than 1. 
See https://github.com/DLR-RM/stable-baselines3/issues/440" if self.env is not None: # Check that `n_steps * n_envs > 1` to avoid NaN @@ -146,6 +149,7 @@ class PPO(OnPolicyAlgorithm): self.n_epochs = n_epochs self.clip_range = clip_range self.clip_range_vf = clip_range_vf + self.normalize_advantage = normalize_advantage self.target_kl = target_kl if _init_setup_model: @@ -200,7 +204,8 @@ class PPO(OnPolicyAlgorithm): values = values.flatten() # Normalize advantage advantages = rollout_data.advantages - advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) + if self.normalize_advantage: + advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) # ratio between old and new policy, should be one at the first iteration ratio = th.exp(log_prob - rollout_data.old_log_prob) diff --git a/tests/test_run.py b/tests/test_run.py index 223776d..e4e8a2e 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -34,6 +34,13 @@ def test_a2c(env_id): model.learn(total_timesteps=1000, eval_freq=500) +@pytest.mark.parametrize("model_class", [A2C, PPO]) +@pytest.mark.parametrize("normalize_advantage", [False, True]) +def test_advantage_normalization(model_class, normalize_advantage): + model = model_class("MlpPolicy", "CartPole-v1", n_steps=64, normalize_advantage=normalize_advantage) + model.learn(64) + + @pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"]) @pytest.mark.parametrize("clip_range_vf", [None, 0.2, -0.2]) def test_ppo(env_id, clip_range_vf):