diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index 412f9dd..96c0311 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -865,3 +865,9 @@ class BaseAlgorithm(ABC): params_to_save = self.get_parameters() save_to_zip_file(path, data=data, params=params_to_save, pytorch_variables=pytorch_variables) + + @abstractmethod + def _dump_logs(self) -> None: + """ + Write log data. (Implemented by OffPolicyAlgorithm and OnPolicyAlgorithm) + """ diff --git a/stable_baselines3/common/callbacks.py b/stable_baselines3/common/callbacks.py index 0e73879..50da532 100644 --- a/stable_baselines3/common/callbacks.py +++ b/stable_baselines3/common/callbacks.py @@ -591,6 +591,21 @@ class EveryNTimesteps(EventCallback): return True +class LogEveryNTimesteps(EveryNTimesteps): + """ + Log data every ``n_steps`` timesteps + + :param n_steps: Number of timesteps between two triggers. + """ + + def __init__(self, n_steps: int): + super().__init__(n_steps, callback=ConvertCallback(self._log_data)) + + def _log_data(self, _locals: dict[str, Any], _globals: dict[str, Any]) -> bool: + self.model._dump_logs() + return True + + class StopTrainingOnMaxEpisodes(BaseCallback): """ Stop the training once a maximum number of episodes are played. diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index c3e1c66..3bb47f1 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -408,7 +408,7 @@ class OffPolicyAlgorithm(BaseAlgorithm): def _dump_logs(self) -> None: """ - Write log. + Write log data. 
""" assert self.ep_info_buffer is not None assert self.ep_success_buffer is not None diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py index ac4c097..b6b43f0 100644 --- a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -274,7 +274,7 @@ class OnPolicyAlgorithm(BaseAlgorithm): """ raise NotImplementedError - def _dump_logs(self, iteration: int) -> None: + def _dump_logs(self, iteration: int = 0) -> None: """ Write log. @@ -285,7 +285,8 @@ class OnPolicyAlgorithm(BaseAlgorithm): time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon) fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed) - self.logger.record("time/iterations", iteration, exclude="tensorboard") + if iteration > 0: + self.logger.record("time/iterations", iteration, exclude="tensorboard") if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0: self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer])) self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer])) diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index ffc3732..81d84ac 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -13,6 +13,7 @@ from stable_baselines3.common.callbacks import ( CheckpointCallback, EvalCallback, EveryNTimesteps, + LogEveryNTimesteps, StopTrainingOnMaxEpisodes, StopTrainingOnNoModelImprovement, StopTrainingOnRewardThreshold, @@ -62,11 +63,12 @@ def test_callbacks(tmp_path, model_class): checkpoint_on_event = CheckpointCallback(save_freq=1, save_path=log_folder, name_prefix="event") event_callback = EveryNTimesteps(n_steps=500, callback=checkpoint_on_event) + log_callback = LogEveryNTimesteps(n_steps=250) # Stop training if max number of episodes is reached callback_max_episodes = StopTrainingOnMaxEpisodes(max_episodes=100, 
verbose=1) - callback = CallbackList([checkpoint_callback, eval_callback, event_callback, callback_max_episodes]) + callback = CallbackList([checkpoint_callback, eval_callback, event_callback, log_callback, callback_max_episodes]) model.learn(500, callback=callback) # Check access to local variables