mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-17 21:20:11 +00:00
Allow PPO to turn off advantage normalization (#763)
* Allow PPO to turn of advantage normalization * update changelog * Add a test case * Update test and sanity check * Fix tests Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
parent
7ce4bb8016
commit
d2ebd2eeaa
3 changed files with 17 additions and 4 deletions
|
|
@ -16,6 +16,7 @@ New Features:
|
|||
^^^^^^^^^^^^^
|
||||
- Makes the length of keys and values in ``HumanOutputFormat`` configurable,
|
||||
depending on desired maximum width of output.
|
||||
- Allow PPO to turn of advantage normalization (see `PR #763 <https://github.com/DLR-RM/stable-baselines3/pull/763>`_) @vwxyzjn
|
||||
|
||||
SB3-Contrib
|
||||
^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -42,6 +42,7 @@ class PPO(OnPolicyAlgorithm):
|
|||
This is a parameter specific to the OpenAI implementation. If None is passed (default),
|
||||
no clipping will be done on the value function.
|
||||
IMPORTANT: this clipping depends on the reward scaling.
|
||||
:param normalize_advantage: Whether to normalize or not the advantage
|
||||
:param ent_coef: Entropy coefficient for the loss calculation
|
||||
:param vf_coef: Value function coefficient for the loss calculation
|
||||
:param max_grad_norm: The maximum value for the gradient clipping
|
||||
|
|
@ -76,6 +77,7 @@ class PPO(OnPolicyAlgorithm):
|
|||
gae_lambda: float = 0.95,
|
||||
clip_range: Union[float, Schedule] = 0.2,
|
||||
clip_range_vf: Union[None, float, Schedule] = None,
|
||||
normalize_advantage: bool = True,
|
||||
ent_coef: float = 0.0,
|
||||
vf_coef: float = 0.5,
|
||||
max_grad_norm: float = 0.5,
|
||||
|
|
@ -120,9 +122,10 @@ class PPO(OnPolicyAlgorithm):
|
|||
|
||||
# Sanity check, otherwise it will lead to noisy gradient and NaN
|
||||
# because of the advantage normalization
|
||||
assert (
|
||||
batch_size > 1
|
||||
), "`batch_size` must be greater than 1. See https://github.com/DLR-RM/stable-baselines3/issues/440"
|
||||
if normalize_advantage:
|
||||
assert (
|
||||
batch_size > 1
|
||||
), "`batch_size` must be greater than 1. See https://github.com/DLR-RM/stable-baselines3/issues/440"
|
||||
|
||||
if self.env is not None:
|
||||
# Check that `n_steps * n_envs > 1` to avoid NaN
|
||||
|
|
@ -146,6 +149,7 @@ class PPO(OnPolicyAlgorithm):
|
|||
self.n_epochs = n_epochs
|
||||
self.clip_range = clip_range
|
||||
self.clip_range_vf = clip_range_vf
|
||||
self.normalize_advantage = normalize_advantage
|
||||
self.target_kl = target_kl
|
||||
|
||||
if _init_setup_model:
|
||||
|
|
@ -200,7 +204,8 @@ class PPO(OnPolicyAlgorithm):
|
|||
values = values.flatten()
|
||||
# Normalize advantage
|
||||
advantages = rollout_data.advantages
|
||||
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
|
||||
if self.normalize_advantage:
|
||||
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
|
||||
|
||||
# ratio between old and new policy, should be one at the first iteration
|
||||
ratio = th.exp(log_prob - rollout_data.old_log_prob)
|
||||
|
|
|
|||
|
|
@ -34,6 +34,13 @@ def test_a2c(env_id):
|
|||
model.learn(total_timesteps=1000, eval_freq=500)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [A2C, PPO])
|
||||
@pytest.mark.parametrize("normalize_advantage", [False, True])
|
||||
def test_advantage_normalization(model_class, normalize_advantage):
|
||||
model = model_class("MlpPolicy", "CartPole-v1", n_steps=64, normalize_advantage=normalize_advantage)
|
||||
model.learn(64)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"])
|
||||
@pytest.mark.parametrize("clip_range_vf", [None, 0.2, -0.2])
|
||||
def test_ppo(env_id, clip_range_vf):
|
||||
|
|
|
|||
Loading…
Reference in a new issue