fix: Follow PEP8 guidelines and evaluate falsy to truthy with not rather than is False. (#1707)
* fix: Follow PEP8 guidelines and evaluate falsy to truthy with `not` rather than `is False`. See https://docs.python.org/2/library/stdtypes.html#truth-value-testing
* chore: Update changelog in line with the intent of the changes in PR #1707
* fix: Change `is False` to `not` as per PEP8
* chore: Remove superfluous comment about `is False`
* test: One on-policy and one off-policy algorithm (A2C and SAC respectively), with settings chosen to speed up testing
* Update changelog
* chore: Remove EvalCallback as it's not actually required
* Update changelog.rst
* Remove duplicated "Others" section in changelog.rst

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
parent c6bf251d46
commit 2ddf015cd9
7 changed files with 44 additions and 8 deletions
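The gist of the change, as a minimal sketch (illustration only, not part of the commit): `x is False` is an identity test against the Python `False` singleton, so falsy values such as `None`, `0`, or `np.bool_(False)` slip past it, while `not x` catches them all:

    import numpy as np

    for x in (False, np.bool_(False), None, 0):
        print(repr(x), x is False, not x)

    # Only the literal False passes the identity test,
    # but `not` treats every falsy value the same:
    #   False           -> is False: True   not: True
    #   np.bool_(False) -> is False: False  not: True
    #   None            -> is False: False  not: True
    #   0               -> is False: False  not: True

This is why callbacks that returned None instead of a boolean used to keep training alive: `None is False` evaluates to False, so the old stop condition never triggered.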
docs/misc/changelog.rst:

@@ -9,6 +9,7 @@ Release 2.2.0a7 (WIP)
 Breaking Changes:
 ^^^^^^^^^^^^^^^^^
 - Switched to ``ruff`` for sorting imports (isort is no longer needed), black and ruff version now require a minimum version
+- Dropped ``x is False`` in favor of ``not x``, which means that callbacks that wrongly returned None (instead of a boolean) will cause the training to stop (@iwishiwasaneagle)
 
 New Features:
 ^^^^^^^^^^^^^
@@ -1462,7 +1463,7 @@ And all the contributors:
 @eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP
 @simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485
 @Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede
-@carlosluis @arjun-kg @tlpss @JonathanKuelz @Gabo-Tor
+@carlosluis @arjun-kg @tlpss @JonathanKuelz @Gabo-Tor @iwishiwasaneagle
 @Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875
 @anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong
 @DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto
stable_baselines3/common/buffers.py:

@@ -561,7 +561,7 @@ class DictReplayBuffer(ReplayBuffer):
         if psutil is not None:
             mem_available = psutil.virtual_memory().available
 
-        assert optimize_memory_usage is False, "DictReplayBuffer does not support optimize_memory_usage"
+        assert not optimize_memory_usage, "DictReplayBuffer does not support optimize_memory_usage"
         # disabling as this adds quite a bit of complexity
         # https://github.com/DLR-RM/stable-baselines3/pull/243#discussion_r531535702
         self.optimize_memory_usage = optimize_memory_usage
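A hedged illustration of where this assert fires (the environment and kwargs below are illustrative, not from the commit): dict observation spaces make SB3 pick `DictReplayBuffer` automatically, so requesting `optimize_memory_usage=True` fails at model construction. Note the practical nuance of the rewrite: `assert not x` still rejects `True`, but now lets other falsy values (e.g. `0` or `None`) pass where `assert x is False` did not.

    from stable_baselines3 import SAC
    from stable_baselines3.common.envs import BitFlippingEnv

    env = BitFlippingEnv(n_bits=2, continuous=True)  # dict observation space
    model = SAC(
        "MultiInputPolicy",
        env,
        buffer_size=100,
        replay_buffer_kwargs=dict(optimize_memory_usage=True),
    )
    # AssertionError: DictReplayBuffer does not support optimize_memory_usage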
stable_baselines3/common/callbacks.py:

@@ -554,7 +554,6 @@ class StopTrainingOnRewardThreshold(BaseCallback):
 
     def _on_step(self) -> bool:
         assert self.parent is not None, "``StopTrainingOnMinimumReward`` callback must be used with an ``EvalCallback``"
-        # Convert np.bool_ to bool, otherwise callback() is False won't work
         continue_training = bool(self.parent.best_mean_reward < self.reward_threshold)
         if self.verbose >= 1 and not continue_training:
             print(
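For orientation, a hedged usage sketch (environment id and threshold are illustrative, not from the commit): `StopTrainingOnRewardThreshold` is meant to be passed to an `EvalCallback` as `callback_on_new_best`, which is why `_on_step` reads `self.parent.best_mean_reward`:

    import gymnasium as gym

    from stable_baselines3 import A2C
    from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold

    eval_env = gym.make("CartPole-v1")
    # Stop training once the best mean evaluation reward crosses the threshold
    stop_cb = StopTrainingOnRewardThreshold(reward_threshold=450, verbose=1)
    eval_cb = EvalCallback(eval_env, callback_on_new_best=stop_cb, eval_freq=500)

    model = A2C("MlpPolicy", "CartPole-v1")
    model.learn(total_timesteps=50_000, callback=eval_cb)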
stable_baselines3/common/off_policy_algorithm.py:

@@ -330,7 +330,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
                 log_interval=log_interval,
             )
 
-            if rollout.continue_training is False:
+            if not rollout.continue_training:
                 break
 
             if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts:
@@ -556,7 +556,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
             # Give access to local variables
             callback.update_locals(locals())
             # Only stop training if return value is False, not when it is None.
-            if callback.on_step() is False:
+            if not callback.on_step():
                 return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training=False)
 
             # Retrieve reward and episode length if using Monitor wrapper
stable_baselines3/common/on_policy_algorithm.py:

@@ -186,7 +186,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
 
             # Give access to local variables
             callback.update_locals(locals())
-            if callback.on_step() is False:
+            if not callback.on_step():
                 return False
 
             self._update_info_buffer(infos)
@@ -265,7 +265,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
         while self.num_timesteps < total_timesteps:
             continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
 
-            if continue_training is False:
+            if not continue_training:
                 break
 
             iteration += 1
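These four call sites (off-policy `learn`/`collect_rollouts`, on-policy `collect_rollouts`/`learn`) are where the breaking change bites. A minimal sketch of the failure mode it now surfaces (`BadCallback` is hypothetical, not from the commit):

    from stable_baselines3 import A2C
    from stable_baselines3.common.callbacks import BaseCallback

    class BadCallback(BaseCallback):
        def _on_step(self):
            # Forgot to `return True`, so this implicitly returns None.
            # Old check: `None is False` -> False, training kept going.
            # New check: `not None` -> True, training stops immediately.
            pass

    model = A2C("MlpPolicy", "CartPole-v1", n_steps=4)
    model.learn(total_timesteps=100, callback=BadCallback())
    print(model.num_timesteps)  # well below 100: the run was cancelled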
stable_baselines3/common/policies.py:

@@ -90,7 +90,7 @@ class BaseModel(nn.Module):
         self.features_extractor_class = features_extractor_class
         self.features_extractor_kwargs = features_extractor_kwargs
         # Automatically deactivate dtype and bounds checks
-        if normalize_images is False and issubclass(features_extractor_class, (NatureCNN, CombinedExtractor)):
+        if not normalize_images and issubclass(features_extractor_class, (NatureCNN, CombinedExtractor)):
             self.features_extractor_kwargs.update(dict(normalized_image=True))
 
     def _update_features_extractor(
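A small note on this hunk (an observation, not from the commit): `normalize_images` is annotated as `bool`, so for well-typed callers `not normalize_images` and `normalize_images is False` behave identically; the branch simply pre-configures image extractors. A sketch of its effect, with the `issubclass` check elided:

    # Assuming the extractor class is NatureCNN, so the issubclass test holds
    features_extractor_kwargs = {}
    normalize_images = False
    if not normalize_images:  # previously: normalize_images is False
        features_extractor_kwargs.update(dict(normalized_image=True))
    print(features_extractor_kwargs)  # {'normalized_image': True}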
tests/test_callbacks.py:

@@ -4,9 +4,11 @@ import shutil
 import gymnasium as gym
+import numpy as np
 import pytest
+import torch as th
 
 from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3, HerReplayBuffer
 from stable_baselines3.common.callbacks import (
     BaseCallback,
     CallbackList,
     CheckpointCallback,
     EvalCallback,
@@ -123,6 +125,40 @@ def test_eval_callback_vec_env():
     assert eval_callback.last_mean_reward == 100.0
 
 
+class AlwaysFailCallback(BaseCallback):
+    def __init__(self, *args, callback_false_value, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.callback_false_value = callback_false_value
+
+    def _on_step(self) -> bool:
+        return self.callback_false_value
+
+
+@pytest.mark.parametrize(
+    "model_class,model_kwargs",
+    [
+        (A2C, dict(n_steps=1, stats_window_size=1)),
+        (
+            SAC,
+            dict(
+                learning_starts=1,
+                buffer_size=1,
+                batch_size=1,
+            ),
+        ),
+    ],
+)
+@pytest.mark.parametrize("callback_false_value", [False, np.bool_(0), th.tensor(0, dtype=th.bool)])
+def test_callbacks_can_cancel_runs(model_class, model_kwargs, callback_false_value):
+    assert not callback_false_value  # Sanity check to ensure parametrized values are valid
+    env_id = select_env(model_class)
+    model = model_class("MlpPolicy", env_id, **model_kwargs, policy_kwargs=dict(net_arch=[2]))
+    alwaysfailcallback = AlwaysFailCallback(callback_false_value=callback_false_value)
+    model.learn(10, callback=alwaysfailcallback)
+
+    assert alwaysfailcallback.n_calls == 1
+
+
 def test_eval_success_logging(tmp_path):
     n_bits = 2
     n_envs = 2
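To run just the added test locally, something like this should work (assuming a development install with pytest):

    pytest tests/test_callbacks.py -k test_callbacks_can_cancel_runs

The parametrization covers one on-policy (A2C) and one off-policy (SAC) algorithm with deliberately tiny settings (`net_arch=[2]`, one-step buffers) to keep the runs fast, and `n_calls == 1` asserts that training was cancelled on the very first callback step, for plain `False`, `np.bool_(0)`, and a boolean torch tensor alike.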