fix: Follow PEP8 guidelines and evaluate falsy to truthy with not rather than is False. (#1707)

* fix: Follow PEP8 guidelines and evaluate falsy to truth with `not` rather than `is False`.

https://docs.python.org/2/library/stdtypes.html#truth-value-testing

* chore: Update changelog inline with intent of changes in PR #1707

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* fix: Change `is False` to `not` as per PEP8

* chore: Remove superfluous comment about `is False`

* test: One On- and one Off-Policy algorithm (A2C and SAC respectively), with settings to speed up testing

* Update changelog

* chore: Remove EvalCallback as it's not actually required

* Update changelog.rst

* Rm duplicated "others" section in changelog.rst

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
Jan-Hendrik Ewers 2023-10-09 11:21:12 +01:00 committed by GitHub
parent c6bf251d46
commit 2ddf015cd9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 44 additions and 8 deletions

View file

@ -9,6 +9,7 @@ Release 2.2.0a7 (WIP)
Breaking Changes:
^^^^^^^^^^^^^^^^^
- Switched to ``ruff`` for sorting imports (isort is no longer needed), black and ruff version now require a minimum version
- Dropped ``x is False`` in favor of ``not x``, which means that callbacks that wrongly returned None (instead of a boolean) will cause the training to stop (@iwishiwasaneagle)
New Features:
^^^^^^^^^^^^^
@ -1462,7 +1463,7 @@ And all the contributors:
@eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP
@simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485
@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede
@carlosluis @arjun-kg @tlpss @JonathanKuelz @Gabo-Tor
@carlosluis @arjun-kg @tlpss @JonathanKuelz @Gabo-Tor @iwishiwasaneagle
@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto

View file

@ -561,7 +561,7 @@ class DictReplayBuffer(ReplayBuffer):
if psutil is not None:
mem_available = psutil.virtual_memory().available
assert optimize_memory_usage is False, "DictReplayBuffer does not support optimize_memory_usage"
assert not optimize_memory_usage, "DictReplayBuffer does not support optimize_memory_usage"
# disabling as this adds quite a bit of complexity
# https://github.com/DLR-RM/stable-baselines3/pull/243#discussion_r531535702
self.optimize_memory_usage = optimize_memory_usage

View file

@ -554,7 +554,6 @@ class StopTrainingOnRewardThreshold(BaseCallback):
def _on_step(self) -> bool:
assert self.parent is not None, "``StopTrainingOnMinimumReward`` callback must be used with an ``EvalCallback``"
# Convert np.bool_ to bool, otherwise callback() is False won't work
continue_training = bool(self.parent.best_mean_reward < self.reward_threshold)
if self.verbose >= 1 and not continue_training:
print(

View file

@ -330,7 +330,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
log_interval=log_interval,
)
if rollout.continue_training is False:
if not rollout.continue_training:
break
if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts:
@ -556,7 +556,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
# Give access to local variables
callback.update_locals(locals())
# Only stop training if return value is False, not when it is None.
if callback.on_step() is False:
if not callback.on_step():
return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training=False)
# Retrieve reward and episode length if using Monitor wrapper

View file

@ -186,7 +186,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
# Give access to local variables
callback.update_locals(locals())
if callback.on_step() is False:
if not callback.on_step():
return False
self._update_info_buffer(infos)
@ -265,7 +265,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
while self.num_timesteps < total_timesteps:
continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
if continue_training is False:
if not continue_training:
break
iteration += 1

View file

@ -90,7 +90,7 @@ class BaseModel(nn.Module):
self.features_extractor_class = features_extractor_class
self.features_extractor_kwargs = features_extractor_kwargs
# Automatically deactivate dtype and bounds checks
if normalize_images is False and issubclass(features_extractor_class, (NatureCNN, CombinedExtractor)):
if not normalize_images and issubclass(features_extractor_class, (NatureCNN, CombinedExtractor)):
self.features_extractor_kwargs.update(dict(normalized_image=True))
def _update_features_extractor(

View file

@ -4,9 +4,11 @@ import shutil
import gymnasium as gym
import numpy as np
import pytest
import torch as th
from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3, HerReplayBuffer
from stable_baselines3.common.callbacks import (
BaseCallback,
CallbackList,
CheckpointCallback,
EvalCallback,
@ -123,6 +125,40 @@ def test_eval_callback_vec_env():
assert eval_callback.last_mean_reward == 100.0
class AlwaysFailCallback(BaseCallback):
def __init__(self, *args, callback_false_value, **kwargs):
super().__init__(*args, **kwargs)
self.callback_false_value = callback_false_value
def _on_step(self) -> bool:
return self.callback_false_value
@pytest.mark.parametrize(
"model_class,model_kwargs",
[
(A2C, dict(n_steps=1, stats_window_size=1)),
(
SAC,
dict(
learning_starts=1,
buffer_size=1,
batch_size=1,
),
),
],
)
@pytest.mark.parametrize("callback_false_value", [False, np.bool_(0), th.tensor(0, dtype=th.bool)])
def test_callbacks_can_cancel_runs(model_class, model_kwargs, callback_false_value):
assert not callback_false_value # Sanity check to ensure parametrized values are valid
env_id = select_env(model_class)
model = model_class("MlpPolicy", env_id, **model_kwargs, policy_kwargs=dict(net_arch=[2]))
alwaysfailcallback = AlwaysFailCallback(callback_false_value=callback_false_value)
model.learn(10, callback=alwaysfailcallback)
assert alwaysfailcallback.n_calls == 1
def test_eval_success_logging(tmp_path):
n_bits = 2
n_envs = 2