diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 80f1dbd..fada248 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -9,6 +9,7 @@ Release 2.2.0a7 (WIP) Breaking Changes: ^^^^^^^^^^^^^^^^^ - Switched to ``ruff`` for sorting imports (isort is no longer needed), black and ruff version now require a minimum version +- Dropped ``x is False`` in favor of ``not x``, which means that callbacks that wrongly returned None (instead of a boolean) will cause the training to stop (@iwishiwasaneagle) New Features: ^^^^^^^^^^^^^ @@ -1462,7 +1463,7 @@ And all the contributors: @eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP @simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485 @Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede -@carlosluis @arjun-kg @tlpss @JonathanKuelz @Gabo-Tor +@carlosluis @arjun-kg @tlpss @JonathanKuelz @Gabo-Tor @iwishiwasaneagle @Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 @anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong @DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index a230a31..306b435 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -561,7 +561,7 @@ class DictReplayBuffer(ReplayBuffer): if psutil is not None: mem_available = psutil.virtual_memory().available - assert optimize_memory_usage is False, "DictReplayBuffer does not support optimize_memory_usage" + assert not optimize_memory_usage, "DictReplayBuffer does not support optimize_memory_usage" # disabling as this adds quite a bit of complexity # 
https://github.com/DLR-RM/stable-baselines3/pull/243#discussion_r531535702 self.optimize_memory_usage = optimize_memory_usage diff --git a/stable_baselines3/common/callbacks.py b/stable_baselines3/common/callbacks.py index 54f1b97..5089bba 100644 --- a/stable_baselines3/common/callbacks.py +++ b/stable_baselines3/common/callbacks.py @@ -554,7 +554,6 @@ class StopTrainingOnRewardThreshold(BaseCallback): def _on_step(self) -> bool: assert self.parent is not None, "``StopTrainingOnMinimumReward`` callback must be used with an ``EvalCallback``" - # Convert np.bool_ to bool, otherwise callback() is False won't work continue_training = bool(self.parent.best_mean_reward < self.reward_threshold) if self.verbose >= 1 and not continue_training: print( diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index 2caaf8e..e8dcac4 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -330,7 +330,7 @@ class OffPolicyAlgorithm(BaseAlgorithm): log_interval=log_interval, ) - if rollout.continue_training is False: + if not rollout.continue_training: break if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts: @@ -556,7 +556,7 @@ class OffPolicyAlgorithm(BaseAlgorithm): # Give access to local variables callback.update_locals(locals()) - # Only stop training if return value is False, not when it is None. + # Stop training when the callback returns a falsy value (False, None, np.bool_(False), ...). 
- if callback.on_step() is False: + if not callback.on_step(): return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training=False) # Retrieve reward and episode length if using Monitor wrapper diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py index 1e0f9e6..4f9bb08 100644 --- a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -186,7 +186,7 @@ class OnPolicyAlgorithm(BaseAlgorithm): # Give access to local variables callback.update_locals(locals()) - if callback.on_step() is False: + if not callback.on_step(): return False self._update_info_buffer(infos) @@ -265,7 +265,7 @@ class OnPolicyAlgorithm(BaseAlgorithm): while self.num_timesteps < total_timesteps: continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps) - if continue_training is False: + if not continue_training: break iteration += 1 diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 9bb7b11..5f57672 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -90,7 +90,7 @@ class BaseModel(nn.Module): self.features_extractor_class = features_extractor_class self.features_extractor_kwargs = features_extractor_kwargs # Automatically deactivate dtype and bounds checks - if normalize_images is False and issubclass(features_extractor_class, (NatureCNN, CombinedExtractor)): + if not normalize_images and issubclass(features_extractor_class, (NatureCNN, CombinedExtractor)): self.features_extractor_kwargs.update(dict(normalized_image=True)) def _update_features_extractor( diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index f8b0e54..d159c43 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -4,9 +4,11 @@ import shutil import gymnasium as gym import numpy as np import pytest +import torch as th 
from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3, HerReplayBuffer from stable_baselines3.common.callbacks import ( + BaseCallback, CallbackList, CheckpointCallback, EvalCallback, @@ -123,6 +125,40 @@ def test_eval_callback_vec_env(): assert eval_callback.last_mean_reward == 100.0 +class AlwaysFailCallback(BaseCallback): + def __init__(self, *args, callback_false_value, **kwargs): + super().__init__(*args, **kwargs) + self.callback_false_value = callback_false_value + + def _on_step(self) -> bool: + return self.callback_false_value + + +@pytest.mark.parametrize( + "model_class,model_kwargs", + [ + (A2C, dict(n_steps=1, stats_window_size=1)), + ( + SAC, + dict( + learning_starts=1, + buffer_size=1, + batch_size=1, + ), + ), + ], +) +@pytest.mark.parametrize("callback_false_value", [False, np.bool_(0), th.tensor(0, dtype=th.bool)]) +def test_callbacks_can_cancel_runs(model_class, model_kwargs, callback_false_value): + assert not callback_false_value # Sanity check to ensure parametrized values are valid + env_id = select_env(model_class) + model = model_class("MlpPolicy", env_id, **model_kwargs, policy_kwargs=dict(net_arch=[2])) + alwaysfailcallback = AlwaysFailCallback(callback_false_value=callback_false_value) + model.learn(10, callback=alwaysfailcallback) + + assert alwaysfailcallback.n_calls == 1 + + def test_eval_success_logging(tmp_path): n_bits = 2 n_envs = 2