mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-18 21:30:19 +00:00
Fix advantage normalization with mini-batchsize of 1 (#1028)
* fix nan in advnatages with batch size 1, for ppo * changelog * black * Simplify test * Bump version Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
parent
59af0c1b01
commit
2cc1477fa2
6 changed files with 36 additions and 21 deletions
|
|
@ -3,7 +3,7 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Release 1.6.1a2 (WIP)
|
||||
Release 1.6.1a3 (WIP)
|
||||
---------------------------
|
||||
|
||||
Breaking Changes:
|
||||
|
|
@ -20,6 +20,7 @@ SB3-Contrib
|
|||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- Fixed issue where ``PPO`` gives NaN if rollout buffer provides a batch of size 1 (@hughperkins)
|
||||
- Fixed the issue that ``predict`` does not always return action as ``np.ndarray`` (@qgallouedec)
|
||||
- Fixed division by zero error when computing FPS when a small number of time has elapsed in operating systems with low-precision timers.
|
||||
- Added multidimensional action space support (@qgallouedec)
|
||||
|
|
@ -1029,4 +1030,4 @@ And all the contributors:
|
|||
@simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485
|
||||
@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede
|
||||
@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875
|
||||
@anand-bala
|
||||
@anand-bala @hughperkins
|
||||
|
|
|
|||
|
|
@ -137,8 +137,8 @@ class PPO(OnPolicyAlgorithm):
|
|||
# Check that `n_steps * n_envs > 1` to avoid NaN
|
||||
# when doing advantage normalization
|
||||
buffer_size = self.env.num_envs * self.n_steps
|
||||
assert (
|
||||
buffer_size > 1
|
||||
assert buffer_size > 1 or (
|
||||
not normalize_advantage
|
||||
), f"`n_steps * n_envs` must be greater than 1. Currently n_steps={self.n_steps} and n_envs={self.env.num_envs}"
|
||||
# Check that the rollout buffer size is a multiple of the mini-batch size
|
||||
untruncated_batches = buffer_size // batch_size
|
||||
|
|
@ -210,7 +210,8 @@ class PPO(OnPolicyAlgorithm):
|
|||
values = values.flatten()
|
||||
# Normalize advantage
|
||||
advantages = rollout_data.advantages
|
||||
if self.normalize_advantage:
|
||||
# Normalization does not make sense if mini batchsize == 1, see GH issue #325
|
||||
if self.normalize_advantage and len(advantages) > 1:
|
||||
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
|
||||
|
||||
# ratio between old and new policy, should be one at the first iteration
|
||||
|
|
|
|||
|
|
@ -261,7 +261,7 @@ class SAC(OffPolicyAlgorithm):
|
|||
|
||||
# Compute actor loss
|
||||
# Alternative: actor_loss = th.mean(log_prob - qf1_pi)
|
||||
# Mean over all critic networks
|
||||
# Min over all critic networks
|
||||
q_values_pi = th.cat(self.critic(replay_data.observations, actions_pi), dim=1)
|
||||
min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True)
|
||||
actor_loss = (ent_coef * log_prob - min_qf_pi).mean()
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.6.1a2
|
||||
1.6.1a3
|
||||
|
|
|
|||
|
|
@ -213,3 +213,29 @@ def test_warn_dqn_multi_env():
|
|||
buffer_size=100,
|
||||
target_update_interval=1,
|
||||
)
|
||||
|
||||
|
||||
def test_ppo_warnings():
|
||||
"""Test that PPO warns and errors correctly on
|
||||
problematic rollout buffer sizes"""
|
||||
|
||||
# Only 1 step: advantage normalization will return NaN
|
||||
with pytest.raises(AssertionError):
|
||||
PPO("MlpPolicy", "Pendulum-v1", n_steps=1)
|
||||
|
||||
# batch_size of 1 is allowed when normalize_advantage=False
|
||||
model = PPO("MlpPolicy", "Pendulum-v1", n_steps=1, batch_size=1, normalize_advantage=False)
|
||||
model.learn(4)
|
||||
|
||||
# Truncated mini-batch
|
||||
# Batch size 1 yields NaN with normalized advantage because
|
||||
# torch.std(some_length_1_tensor) == NaN
|
||||
# advantage normalization is automatically deactivated
|
||||
# in that case
|
||||
with pytest.warns(UserWarning, match="there will be a truncated mini-batch of size 1"):
|
||||
model = PPO("MlpPolicy", "Pendulum-v1", n_steps=64, batch_size=63, verbose=1)
|
||||
model.learn(64)
|
||||
|
||||
loss = model.logger.name_to_value["train/loss"]
|
||||
assert loss > 0
|
||||
assert not np.isnan(loss) # check not nan (since nan does not equal nan)
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import torch as th
|
|||
from gym import spaces
|
||||
|
||||
import stable_baselines3 as sb3
|
||||
from stable_baselines3 import A2C, PPO
|
||||
from stable_baselines3 import A2C
|
||||
from stable_baselines3.common.atari_wrappers import ClipRewardEnv, MaxAndSkipEnv
|
||||
from stable_baselines3.common.env_util import is_wrapped, make_atari_env, make_vec_env, unwrap_wrapper
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
|
|
@ -388,19 +388,6 @@ def test_is_wrapped():
|
|||
assert unwrap_wrapper(env, Monitor) == monitor_env
|
||||
|
||||
|
||||
def test_ppo_warnings():
|
||||
"""Test that PPO warns and errors correctly on
|
||||
problematic rollour buffer sizes"""
|
||||
|
||||
# Only 1 step: advantage normalization will return NaN
|
||||
with pytest.raises(AssertionError):
|
||||
PPO("MlpPolicy", "Pendulum-v1", n_steps=1)
|
||||
|
||||
# Truncated mini-batch
|
||||
with pytest.warns(UserWarning):
|
||||
PPO("MlpPolicy", "Pendulum-v1", n_steps=6, batch_size=8)
|
||||
|
||||
|
||||
def test_get_system_info():
|
||||
info, info_str = get_system_info(print_info=True)
|
||||
assert info["Stable-Baselines3"] == str(sb3.__version__)
|
||||
|
|
|
|||
Loading…
Reference in a new issue