Add log every n step callback

This commit is contained in:
Antonin RAFFIN 2025-02-09 14:31:59 +01:00
parent c5c29a32d9
commit cc46e90cd0
5 changed files with 28 additions and 4 deletions

View file

@@ -865,3 +865,9 @@ class BaseAlgorithm(ABC):
params_to_save = self.get_parameters()
save_to_zip_file(path, data=data, params=params_to_save, pytorch_variables=pytorch_variables)
@abstractmethod
def _dump_logs(self) -> None:
"""
Write log data. (Implemented by OffPolicyAlgorithm and OnPolicyAlgorithm)
"""

View file

@@ -591,6 +591,21 @@ class EveryNTimesteps(EventCallback):
return True
class LogEveryNTimesteps(EveryNTimesteps):
    """
    Dump the model's log data every ``n_steps`` timesteps.

    :param n_steps: Number of timesteps between two log dumps.
    """

    def __init__(self, n_steps: int):
        # Wrap the bound method so the event callback machinery can call it.
        log_trigger = ConvertCallback(self._log_data)
        super().__init__(n_steps, callback=log_trigger)

    def _log_data(self, _locals: dict[str, Any], _globals: dict[str, Any]) -> bool:
        # Delegate the actual log writing to the algorithm.
        self.model._dump_logs()
        # Returning True means: do not stop training.
        return True
class StopTrainingOnMaxEpisodes(BaseCallback):
"""
Stop the training once a maximum number of episodes are played.

View file

@@ -408,7 +408,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
def _dump_logs(self) -> None:
"""
Write log.
Write log data.
"""
assert self.ep_info_buffer is not None
assert self.ep_success_buffer is not None

View file

@@ -274,7 +274,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
"""
raise NotImplementedError
def _dump_logs(self, iteration: int) -> None:
def _dump_logs(self, iteration: int = 0) -> None:
"""
Write log.
@@ -285,7 +285,8 @@ class OnPolicyAlgorithm(BaseAlgorithm):
time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
self.logger.record("time/iterations", iteration, exclude="tensorboard")
if iteration > 0:
self.logger.record("time/iterations", iteration, exclude="tensorboard")
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))

View file

@@ -13,6 +13,7 @@ from stable_baselines3.common.callbacks import (
CheckpointCallback,
EvalCallback,
EveryNTimesteps,
LogEveryNTimesteps,
StopTrainingOnMaxEpisodes,
StopTrainingOnNoModelImprovement,
StopTrainingOnRewardThreshold,
@@ -62,11 +63,12 @@ def test_callbacks(tmp_path, model_class):
checkpoint_on_event = CheckpointCallback(save_freq=1, save_path=log_folder, name_prefix="event")
event_callback = EveryNTimesteps(n_steps=500, callback=checkpoint_on_event)
log_callback = LogEveryNTimesteps(n_steps=250)
# Stop training if max number of episodes is reached
callback_max_episodes = StopTrainingOnMaxEpisodes(max_episodes=100, verbose=1)
callback = CallbackList([checkpoint_callback, eval_callback, event_callback, callback_max_episodes])
callback = CallbackList([checkpoint_callback, eval_callback, event_callback, log_callback, callback_max_episodes])
model.learn(500, callback=callback)
# Check access to local variables