Allow PPO to turn off advantage normalization (#763)

* Allow PPO to turn of advantage normalization

* update changelog

* Add a test case

* Update test and sanity check

* Fix tests

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
Costa Huang 2022-02-22 09:29:21 -05:00 committed by GitHub
parent 7ce4bb8016
commit d2ebd2eeaa
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 17 additions and 4 deletions

View file

@ -16,6 +16,7 @@ New Features:
^^^^^^^^^^^^^
- Makes the length of keys and values in ``HumanOutputFormat`` configurable,
depending on desired maximum width of output.
- Allow PPO to turn of advantage normalization (see `PR #763 <https://github.com/DLR-RM/stable-baselines3/pull/763>`_) @vwxyzjn
SB3-Contrib
^^^^^^^^^^^

View file

@ -42,6 +42,7 @@ class PPO(OnPolicyAlgorithm):
This is a parameter specific to the OpenAI implementation. If None is passed (default),
no clipping will be done on the value function.
IMPORTANT: this clipping depends on the reward scaling.
:param normalize_advantage: Whether to normalize or not the advantage
:param ent_coef: Entropy coefficient for the loss calculation
:param vf_coef: Value function coefficient for the loss calculation
:param max_grad_norm: The maximum value for the gradient clipping
@ -76,6 +77,7 @@ class PPO(OnPolicyAlgorithm):
gae_lambda: float = 0.95,
clip_range: Union[float, Schedule] = 0.2,
clip_range_vf: Union[None, float, Schedule] = None,
normalize_advantage: bool = True,
ent_coef: float = 0.0,
vf_coef: float = 0.5,
max_grad_norm: float = 0.5,
@ -120,9 +122,10 @@ class PPO(OnPolicyAlgorithm):
# Sanity check, otherwise it will lead to noisy gradient and NaN
# because of the advantage normalization
assert (
batch_size > 1
), "`batch_size` must be greater than 1. See https://github.com/DLR-RM/stable-baselines3/issues/440"
if normalize_advantage:
assert (
batch_size > 1
), "`batch_size` must be greater than 1. See https://github.com/DLR-RM/stable-baselines3/issues/440"
if self.env is not None:
# Check that `n_steps * n_envs > 1` to avoid NaN
@ -146,6 +149,7 @@ class PPO(OnPolicyAlgorithm):
self.n_epochs = n_epochs
self.clip_range = clip_range
self.clip_range_vf = clip_range_vf
self.normalize_advantage = normalize_advantage
self.target_kl = target_kl
if _init_setup_model:
@ -200,7 +204,8 @@ class PPO(OnPolicyAlgorithm):
values = values.flatten()
# Normalize advantage
advantages = rollout_data.advantages
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
if self.normalize_advantage:
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
# ratio between old and new policy, should be one at the first iteration
ratio = th.exp(log_prob - rollout_data.old_log_prob)

View file

@ -34,6 +34,13 @@ def test_a2c(env_id):
model.learn(total_timesteps=1000, eval_freq=500)
@pytest.mark.parametrize("model_class", [A2C, PPO])
@pytest.mark.parametrize("normalize_advantage", [False, True])
def test_advantage_normalization(model_class, normalize_advantage):
model = model_class("MlpPolicy", "CartPole-v1", n_steps=64, normalize_advantage=normalize_advantage)
model.learn(64)
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"])
@pytest.mark.parametrize("clip_range_vf", [None, 0.2, -0.2])
def test_ppo(env_id, clip_range_vf):