diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 126de2e..4643011 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 1.6.1a2 (WIP) +Release 1.6.1a3 (WIP) --------------------------- Breaking Changes: @@ -20,6 +20,7 @@ SB3-Contrib Bug Fixes: ^^^^^^^^^^ +- Fixed issue where ``PPO`` gives NaN if rollout buffer provides a batch of size 1 (@hughperkins) - Fixed the issue that ``predict`` does not always return action as ``np.ndarray`` (@qgallouedec) - Fixed division by zero error when computing FPS when a small number of time has elapsed in operating systems with low-precision timers. - Added multidimensional action space support (@qgallouedec) @@ -1029,4 +1030,4 @@ And all the contributors: @simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485 @Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede @Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 -@anand-bala +@anand-bala @hughperkins diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py index 5b8d9e2..0f7f8e4 100644 --- a/stable_baselines3/ppo/ppo.py +++ b/stable_baselines3/ppo/ppo.py @@ -137,8 +137,8 @@ class PPO(OnPolicyAlgorithm): # Check that `n_steps * n_envs > 1` to avoid NaN # when doing advantage normalization buffer_size = self.env.num_envs * self.n_steps - assert ( - buffer_size > 1 + assert buffer_size > 1 or ( + not normalize_advantage ), f"`n_steps * n_envs` must be greater than 1. Currently n_steps={self.n_steps} and n_envs={self.env.num_envs}" # Check that the rollout buffer size is a multiple of the mini-batch size untruncated_batches = buffer_size // batch_size @@ -210,7 +210,8 @@ class PPO(OnPolicyAlgorithm): values = values.flatten() # Normalize advantage advantages = rollout_data.advantages - if self.normalize_advantage: + # Normalization does not make sense if mini batchsize == 1, see GH issue #325 + if self.normalize_advantage and len(advantages) > 1: advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) # ratio between old and new policy, should be one at the first iteration diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py index 6969ef1..ba27998 100644 --- a/stable_baselines3/sac/sac.py +++ b/stable_baselines3/sac/sac.py @@ -261,7 +261,7 @@ class SAC(OffPolicyAlgorithm): # Compute actor loss # Alternative: actor_loss = th.mean(log_prob - qf1_pi) - # Mean over all critic networks + # Min over all critic networks q_values_pi = th.cat(self.critic(replay_data.observations, actions_pi), dim=1) min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True) actor_loss = (ent_coef * log_prob - min_qf_pi).mean() diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 51cf83a..7a35b06 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.6.1a2 +1.6.1a3 diff --git a/tests/test_run.py b/tests/test_run.py index b0a9a11..655182d 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -213,3 +213,29 @@ def test_warn_dqn_multi_env(): buffer_size=100, target_update_interval=1, ) + + +def test_ppo_warnings(): + """Test that PPO warns and errors correctly on + problematic rollout buffer sizes""" + + # Only 1 step: advantage normalization will return NaN + with pytest.raises(AssertionError): + PPO("MlpPolicy", "Pendulum-v1", n_steps=1) + + # batch_size of 1 is allowed when normalize_advantage=False + model = PPO("MlpPolicy", "Pendulum-v1", n_steps=1, batch_size=1, normalize_advantage=False) + model.learn(4) + + # Truncated mini-batch + # Batch size 1 yields NaN with normalized advantage because + # torch.std(some_length_1_tensor) == NaN + # advantage normalization is automatically deactivated + # in that case + with pytest.warns(UserWarning, match="there will be a truncated mini-batch of size 1"): + model = PPO("MlpPolicy", "Pendulum-v1", n_steps=64, batch_size=63, verbose=1) + model.learn(64) + + loss = model.logger.name_to_value["train/loss"] + assert loss > 0 + assert not np.isnan(loss) # check not nan (since nan does not equal nan) diff --git a/tests/test_utils.py b/tests/test_utils.py index 57b4b39..2a9eade 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -8,7 +8,7 @@ import torch as th from gym import spaces import stable_baselines3 as sb3 -from stable_baselines3 import A2C, PPO +from stable_baselines3 import A2C from stable_baselines3.common.atari_wrappers import ClipRewardEnv, MaxAndSkipEnv from stable_baselines3.common.env_util import is_wrapped, make_atari_env, make_vec_env, unwrap_wrapper from stable_baselines3.common.evaluation import evaluate_policy @@ -388,19 +388,6 @@ def test_is_wrapped(): assert unwrap_wrapper(env, Monitor) == monitor_env -def test_ppo_warnings(): - """Test that PPO warns and errors correctly on - problematic rollour buffer sizes""" - - # Only 1 step: advantage normalization will return NaN - with pytest.raises(AssertionError): - PPO("MlpPolicy", "Pendulum-v1", n_steps=1) - - # Truncated mini-batch - with pytest.warns(UserWarning): - PPO("MlpPolicy", "Pendulum-v1", n_steps=6, batch_size=8) - - def test_get_system_info(): info, info_str = get_system_info(print_info=True) assert info["Stable-Baselines3"] == str(sb3.__version__)