mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-14 20:58:03 +00:00
Set CallbackList children's parent correctly (#1939)
* Fixing #1791 * Update test and version * Add test for callback after eval * Fix mypy error * Remove tqdm warnings --------- Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
parent
0b06d8ab20
commit
4efee92fba
7 changed files with 39 additions and 6 deletions
|
|
@ -3,7 +3,7 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Release 2.4.0a2 (WIP)
|
||||
Release 2.4.0a3 (WIP)
|
||||
--------------------------
|
||||
|
||||
Breaking Changes:
|
||||
|
|
@ -17,7 +17,8 @@ Bug Fixes:
|
|||
- Fixed memory leak when loading learner from storage, ``set_parameters()`` does not try to load the object data anymore
|
||||
and only loads the PyTorch parameters (@peteole)
|
||||
- Cast type in compute gae method to avoid error when using torch compile (@amjames)
|
||||
- Fixed error when loading a model that has ``net_arch`` manually set to ``None`` (@jak3122)
|
||||
- ``CallbackList`` now sets the ``.parent`` attribute of child callbacks to its own ``.parent``. (@will-maclean)
|
||||
- Fixed error when loading a model that has ``net_arch`` manually set to ``None`` (@jak3122)
|
||||
|
||||
`SB3-Contrib`_
|
||||
^^^^^^^^^^^^^^
|
||||
|
|
@ -1662,4 +1663,4 @@ And all the contributors:
|
|||
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong @ReHoss
|
||||
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto
|
||||
@lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger
|
||||
@marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche @cschindlbeck @peteole @jak3122
|
||||
@marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche @cschindlbeck @peteole @jak3122 @will-maclean
|
||||
|
|
|
|||
|
|
@ -46,6 +46,8 @@ filterwarnings = [
|
|||
"ignore::DeprecationWarning:tensorboard",
|
||||
# Gymnasium warnings
|
||||
"ignore::UserWarning:gymnasium",
|
||||
# tqdm warning about rich being experimental
|
||||
"ignore:rich is experimental"
|
||||
]
|
||||
markers = [
|
||||
"expensive: marks tests as expensive (deselect with '-m \"not expensive\"')"
|
||||
|
|
|
|||
|
|
@ -419,7 +419,7 @@ class RolloutBuffer(BaseBuffer):
|
|||
:param dones: if the last step was a terminal step (one bool for each env).
|
||||
"""
|
||||
# Convert to numpy
|
||||
last_values = last_values.clone().cpu().numpy().flatten()
|
||||
last_values = last_values.clone().cpu().numpy().flatten() # type: ignore[assignment]
|
||||
|
||||
last_gae_lam = 0
|
||||
for step in reversed(range(self.buffer_size)):
|
||||
|
|
|
|||
|
|
@ -204,6 +204,10 @@ class CallbackList(BaseCallback):
|
|||
for callback in self.callbacks:
|
||||
callback.init_callback(self.model)
|
||||
|
||||
# Fix for https://github.com/DLR-RM/stable-baselines3/issues/1791
|
||||
# pass through the parent callback to all children
|
||||
callback.parent = self.parent
|
||||
|
||||
def _on_training_start(self) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_training_start(self.locals, self.globals)
|
||||
|
|
|
|||
|
|
@ -367,7 +367,7 @@ class BasePolicy(BaseModel, ABC):
|
|||
with th.no_grad():
|
||||
actions = self._predict(obs_tensor, deterministic=deterministic)
|
||||
# Convert to numpy, and reshape to the original action shape
|
||||
actions = actions.cpu().numpy().reshape((-1, *self.action_space.shape)) # type: ignore[misc]
|
||||
actions = actions.cpu().numpy().reshape((-1, *self.action_space.shape)) # type: ignore[misc, assignment]
|
||||
|
||||
if isinstance(self.action_space, spaces.Box):
|
||||
if self.squash_output:
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
2.4.0a2
|
||||
2.4.0a3
|
||||
|
|
|
|||
|
|
@ -264,3 +264,29 @@ def test_checkpoint_additional_info(tmp_path):
|
|||
model = DQN.load(checkpoint_dir / "rl_model_200_steps.zip")
|
||||
model.load_replay_buffer(checkpoint_dir / "rl_model_replay_buffer_200_steps.pkl")
|
||||
VecNormalize.load(checkpoint_dir / "rl_model_vecnormalize_200_steps.pkl", dummy_vec_env)
|
||||
|
||||
|
||||
def test_eval_callback_chaining(tmp_path):
|
||||
class DummyCallback(BaseCallback):
|
||||
def _on_step(self):
|
||||
# Check that the parent callback is an EvalCallback
|
||||
assert isinstance(self.parent, EvalCallback)
|
||||
assert hasattr(self.parent, "best_mean_reward")
|
||||
return True
|
||||
|
||||
stop_on_threshold_callback = StopTrainingOnRewardThreshold(reward_threshold=-200, verbose=1)
|
||||
|
||||
eval_callback = EvalCallback(
|
||||
gym.make("Pendulum-v1"),
|
||||
best_model_save_path=tmp_path,
|
||||
log_path=tmp_path,
|
||||
eval_freq=32,
|
||||
deterministic=True,
|
||||
render=False,
|
||||
callback_on_new_best=CallbackList([DummyCallback(), stop_on_threshold_callback]),
|
||||
callback_after_eval=CallbackList([DummyCallback()]),
|
||||
warn=False,
|
||||
)
|
||||
|
||||
model = PPO("MlpPolicy", "Pendulum-v1", n_steps=64, n_epochs=1)
|
||||
model.learn(64, callback=eval_callback)
|
||||
|
|
|
|||
Loading…
Reference in a new issue