Add log every n step callback
parent c5c29a32d9
commit cc46e90cd0
5 changed files with 28 additions and 4 deletions
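
The new callback can be attached like any other SB3 callback; a minimal usage sketch (PPO, CartPole-v1 and the 1000-step interval are illustrative choices, not part of this commit):

    from stable_baselines3 import PPO
    from stable_baselines3.common.callbacks import LogEveryNTimesteps

    # Dump training statistics every 1000 environment steps, independently of
    # the rollout/iteration boundaries that normally drive logging.
    model = PPO("MlpPolicy", "CartPole-v1", verbose=0)
    model.learn(total_timesteps=10_000, callback=LogEveryNTimesteps(n_steps=1000))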

stable_baselines3/common/base_class.py
@@ -865,3 +865,9 @@ class BaseAlgorithm(ABC):
         params_to_save = self.get_parameters()
 
         save_to_zip_file(path, data=data, params=params_to_save, pytorch_variables=pytorch_variables)
+
+    @abstractmethod
+    def _dump_logs(self) -> None:
+        """
+        Write log data. (Implemented by OffPolicyAlgorithm and OnPolicyAlgorithm)
+        """
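
Since _dump_logs is now abstract on BaseAlgorithm, custom algorithms that subclass it directly (rather than OnPolicyAlgorithm or OffPolicyAlgorithm) must provide an implementation. A minimal sketch of such an override; the recorded key is illustrative, only self.logger and self.num_timesteps come from the base class:

    def _dump_logs(self) -> None:
        # Record whatever this algorithm tracks, then flush it to the configured logger.
        self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
        self.logger.dump(step=self.num_timesteps)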

stable_baselines3/common/callbacks.py
@@ -591,6 +591,21 @@ class EveryNTimesteps(EventCallback):
         return True
 
 
+class LogEveryNTimesteps(EveryNTimesteps):
+    """
+    Log data every ``n_steps`` timesteps
+
+    :param n_steps: Number of timesteps between two triggers.
+    """
+
+    def __init__(self, n_steps: int):
+        super().__init__(n_steps, callback=ConvertCallback(self._log_data))
+
+    def _log_data(self, _locals: dict[str, Any], _globals: dict[str, Any]) -> bool:
+        self.model._dump_logs()
+        return True
+
+
 class StopTrainingOnMaxEpisodes(BaseCallback):
     """
     Stop the training once a maximum number of episodes are played.

stable_baselines3/common/off_policy_algorithm.py
@@ -408,7 +408,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
 
     def _dump_logs(self) -> None:
         """
-        Write log.
+        Write log data.
         """
         assert self.ep_info_buffer is not None
         assert self.ep_success_buffer is not None

stable_baselines3/common/on_policy_algorithm.py
@@ -274,7 +274,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
         """
         raise NotImplementedError
 
-    def _dump_logs(self, iteration: int) -> None:
+    def _dump_logs(self, iteration: int = 0) -> None:
         """
         Write log.
 
@@ -285,7 +285,8 @@ class OnPolicyAlgorithm(BaseAlgorithm):
 
         time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
         fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
-        self.logger.record("time/iterations", iteration, exclude="tensorboard")
+        if iteration > 0:
+            self.logger.record("time/iterations", iteration, exclude="tensorboard")
         if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
             self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
             self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
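
The new iteration = 0 default exists so that _dump_logs() can be called from outside the training loop, which is exactly what LogEveryNTimesteps does; together with the "if iteration > 0:" guard, such calls simply omit the time/iterations entry, while the regular end-of-rollout call (which passes the real iteration) still records it. A small illustration, with PPO and the step counts chosen arbitrarily:

    from stable_baselines3 import PPO

    model = PPO("MlpPolicy", "CartPole-v1", n_steps=256, verbose=0)
    model.learn(total_timesteps=512)
    # Called with no iteration (as LogEveryNTimesteps does): fps, time/total_timesteps
    # and rollout/* are dumped, but "time/iterations" is skipped by the new guard.
    model._dump_logs()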

tests/test_callbacks.py
@@ -13,6 +13,7 @@ from stable_baselines3.common.callbacks import (
     CheckpointCallback,
     EvalCallback,
     EveryNTimesteps,
+    LogEveryNTimesteps,
     StopTrainingOnMaxEpisodes,
     StopTrainingOnNoModelImprovement,
     StopTrainingOnRewardThreshold,
@@ -62,11 +63,12 @@ def test_callbacks(tmp_path, model_class):
     checkpoint_on_event = CheckpointCallback(save_freq=1, save_path=log_folder, name_prefix="event")
 
     event_callback = EveryNTimesteps(n_steps=500, callback=checkpoint_on_event)
+    log_callback = LogEveryNTimesteps(n_steps=250)
 
     # Stop training if max number of episodes is reached
     callback_max_episodes = StopTrainingOnMaxEpisodes(max_episodes=100, verbose=1)
 
-    callback = CallbackList([checkpoint_callback, eval_callback, event_callback, callback_max_episodes])
+    callback = CallbackList([checkpoint_callback, eval_callback, event_callback, log_callback, callback_max_episodes])
     model.learn(500, callback=callback)
 
     # Check access to local variables