Set CallbackList children's parent correctly (#1939)

* Fixing #1791

* Update test and version

* Add test for callback after eval

* Fix mypy error

* Remove tqdm warnings

---------

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
will-maclean 2024-06-07 22:07:28 +10:00 committed by GitHub
parent 0b06d8ab20
commit 4efee92fba
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 39 additions and 6 deletions

View file

@@ -3,7 +3,7 @@
Changelog
==========
Release 2.4.0a2 (WIP)
Release 2.4.0a3 (WIP)
--------------------------
Breaking Changes:
@@ -17,7 +17,8 @@ Bug Fixes:
- Fixed memory leak when loading learner from storage, ``set_parameters()`` does not try to load the object data anymore
and only loads the PyTorch parameters (@peteole)
- Cast type in compute gae method to avoid error when using torch compile (@amjames)
- Fixed error when loading a model that has ``net_arch`` manually set to ``None`` (@jak3122)
- ``CallbackList`` now sets the ``.parent`` attribute of child callbacks to its own ``.parent``. (@will-maclean)
- Fixed error when loading a model that has ``net_arch`` manually set to ``None`` (@jak3122)
`SB3-Contrib`_
^^^^^^^^^^^^^^
@@ -1662,4 +1663,4 @@ And all the contributors:
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong @ReHoss
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto
@lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger
@marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche @cschindlbeck @peteole @jak3122
@marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche @cschindlbeck @peteole @jak3122 @will-maclean

View file

@@ -46,6 +46,8 @@ filterwarnings = [
"ignore::DeprecationWarning:tensorboard",
# Gymnasium warnings
"ignore::UserWarning:gymnasium",
# tqdm warning about rich being experimental
"ignore:rich is experimental"
]
markers = [
"expensive: marks tests as expensive (deselect with '-m \"not expensive\"')"

View file

@@ -419,7 +419,7 @@ class RolloutBuffer(BaseBuffer):
:param dones: if the last step was a terminal step (one bool for each env).
"""
# Convert to numpy
last_values = last_values.clone().cpu().numpy().flatten()
last_values = last_values.clone().cpu().numpy().flatten() # type: ignore[assignment]
last_gae_lam = 0
for step in reversed(range(self.buffer_size)):

View file

@@ -204,6 +204,10 @@ class CallbackList(BaseCallback):
for callback in self.callbacks:
callback.init_callback(self.model)
# Fix for https://github.com/DLR-RM/stable-baselines3/issues/1791
# pass through the parent callback to all children
callback.parent = self.parent
def _on_training_start(self) -> None:
for callback in self.callbacks:
callback.on_training_start(self.locals, self.globals)

View file

@@ -367,7 +367,7 @@ class BasePolicy(BaseModel, ABC):
with th.no_grad():
actions = self._predict(obs_tensor, deterministic=deterministic)
# Convert to numpy, and reshape to the original action shape
actions = actions.cpu().numpy().reshape((-1, *self.action_space.shape)) # type: ignore[misc]
actions = actions.cpu().numpy().reshape((-1, *self.action_space.shape)) # type: ignore[misc, assignment]
if isinstance(self.action_space, spaces.Box):
if self.squash_output:

View file

@@ -1 +1 @@
2.4.0a2
2.4.0a3

View file

@@ -264,3 +264,29 @@ def test_checkpoint_additional_info(tmp_path):
model = DQN.load(checkpoint_dir / "rl_model_200_steps.zip")
model.load_replay_buffer(checkpoint_dir / "rl_model_replay_buffer_200_steps.pkl")
VecNormalize.load(checkpoint_dir / "rl_model_vecnormalize_200_steps.pkl", dummy_vec_env)
def test_eval_callback_chaining(tmp_path):
class DummyCallback(BaseCallback):
def _on_step(self):
# Check that the parent callback is an EvalCallback
assert isinstance(self.parent, EvalCallback)
assert hasattr(self.parent, "best_mean_reward")
return True
stop_on_threshold_callback = StopTrainingOnRewardThreshold(reward_threshold=-200, verbose=1)
eval_callback = EvalCallback(
gym.make("Pendulum-v1"),
best_model_save_path=tmp_path,
log_path=tmp_path,
eval_freq=32,
deterministic=True,
render=False,
callback_on_new_best=CallbackList([DummyCallback(), stop_on_threshold_callback]),
callback_after_eval=CallbackList([DummyCallback()]),
warn=False,
)
model = PPO("MlpPolicy", "Pendulum-v1", n_steps=64, n_epochs=1)
model.learn(64, callback=eval_callback)