Fix advantage normalization with mini-batchsize of 1 (#1028)

* fix nan in advnatages with batch size 1, for ppo

* changelog

* black

* Simplify test

* Bump version

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
Hugh Perkins 2022-08-25 05:50:08 -04:00 committed by GitHub
parent 59af0c1b01
commit 2cc1477fa2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 36 additions and 21 deletions

View file

@ -3,7 +3,7 @@
Changelog
==========
Release 1.6.1a2 (WIP)
Release 1.6.1a3 (WIP)
---------------------------
Breaking Changes:
@ -20,6 +20,7 @@ SB3-Contrib
Bug Fixes:
^^^^^^^^^^
- Fixed issue where ``PPO`` gives NaN if rollout buffer provides a batch of size 1 (@hughperkins)
- Fixed the issue that ``predict`` does not always return action as ``np.ndarray`` (@qgallouedec)
- Fixed division by zero error when computing FPS when a small number of time has elapsed in operating systems with low-precision timers.
- Added multidimensional action space support (@qgallouedec)
@ -1029,4 +1030,4 @@ And all the contributors:
@simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485
@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede
@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875
@anand-bala
@anand-bala @hughperkins

View file

@ -137,8 +137,8 @@ class PPO(OnPolicyAlgorithm):
# Check that `n_steps * n_envs > 1` to avoid NaN
# when doing advantage normalization
buffer_size = self.env.num_envs * self.n_steps
assert (
buffer_size > 1
assert buffer_size > 1 or (
not normalize_advantage
), f"`n_steps * n_envs` must be greater than 1. Currently n_steps={self.n_steps} and n_envs={self.env.num_envs}"
# Check that the rollout buffer size is a multiple of the mini-batch size
untruncated_batches = buffer_size // batch_size
@ -210,7 +210,8 @@ class PPO(OnPolicyAlgorithm):
values = values.flatten()
# Normalize advantage
advantages = rollout_data.advantages
if self.normalize_advantage:
# Normalization does not make sense if mini batchsize == 1, see GH issue #325
if self.normalize_advantage and len(advantages) > 1:
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
# ratio between old and new policy, should be one at the first iteration

View file

@ -261,7 +261,7 @@ class SAC(OffPolicyAlgorithm):
# Compute actor loss
# Alternative: actor_loss = th.mean(log_prob - qf1_pi)
# Mean over all critic networks
# Min over all critic networks
q_values_pi = th.cat(self.critic(replay_data.observations, actions_pi), dim=1)
min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True)
actor_loss = (ent_coef * log_prob - min_qf_pi).mean()

View file

@ -1 +1 @@
1.6.1a2
1.6.1a3

View file

@ -213,3 +213,29 @@ def test_warn_dqn_multi_env():
buffer_size=100,
target_update_interval=1,
)
def test_ppo_warnings():
"""Test that PPO warns and errors correctly on
problematic rollout buffer sizes"""
# Only 1 step: advantage normalization will return NaN
with pytest.raises(AssertionError):
PPO("MlpPolicy", "Pendulum-v1", n_steps=1)
# batch_size of 1 is allowed when normalize_advantage=False
model = PPO("MlpPolicy", "Pendulum-v1", n_steps=1, batch_size=1, normalize_advantage=False)
model.learn(4)
# Truncated mini-batch
# Batch size 1 yields NaN with normalized advantage because
# torch.std(some_length_1_tensor) == NaN
# advantage normalization is automatically deactivated
# in that case
with pytest.warns(UserWarning, match="there will be a truncated mini-batch of size 1"):
model = PPO("MlpPolicy", "Pendulum-v1", n_steps=64, batch_size=63, verbose=1)
model.learn(64)
loss = model.logger.name_to_value["train/loss"]
assert loss > 0
assert not np.isnan(loss) # check not nan (since nan does not equal nan)

View file

@ -8,7 +8,7 @@ import torch as th
from gym import spaces
import stable_baselines3 as sb3
from stable_baselines3 import A2C, PPO
from stable_baselines3 import A2C
from stable_baselines3.common.atari_wrappers import ClipRewardEnv, MaxAndSkipEnv
from stable_baselines3.common.env_util import is_wrapped, make_atari_env, make_vec_env, unwrap_wrapper
from stable_baselines3.common.evaluation import evaluate_policy
@ -388,19 +388,6 @@ def test_is_wrapped():
assert unwrap_wrapper(env, Monitor) == monitor_env
def test_ppo_warnings():
"""Test that PPO warns and errors correctly on
problematic rollour buffer sizes"""
# Only 1 step: advantage normalization will return NaN
with pytest.raises(AssertionError):
PPO("MlpPolicy", "Pendulum-v1", n_steps=1)
# Truncated mini-batch
with pytest.warns(UserWarning):
PPO("MlpPolicy", "Pendulum-v1", n_steps=6, batch_size=8)
def test_get_system_info():
info, info_str = get_system_info(print_info=True)
assert info["Stable-Baselines3"] == str(sb3.__version__)