mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-08 00:23:22 +00:00
Fix Inconsistencies with EvalCallback tensorboard logs (#492)
* Make EvalCallback dump the evaluation logs it records (#457).
* Make test deterministic.

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
parent
066e1409d9
commit
abbf48e93e
3 changed files with 32 additions and 1 deletions
|
|
@ -60,6 +60,7 @@ Bug Fixes:
|
|||
- Fixed saving of ``A2C`` and ``PPO`` policy when using gSDE (thanks @liusida)
|
||||
- Fixed a bug where no output would be shown even if ``verbose>=1`` after passing ``verbose=0`` once
|
||||
- Fixed observation buffers dtype in DictReplayBuffer (@c-rizz)
|
||||
- Fixed EvalCallback tensorboard logs being logged with the incorrect timestep. They are now written with the timestep at which they were recorded. (@skandermoalla)
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
@ -707,4 +708,4 @@ And all the contributors:
|
|||
@diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 @SwamyDev @wmmc88 @megan-klaiber @thisray
|
||||
@tfederico @hn2 @LucasAlegre @AptX395 @zampanteymedio @JadenTravnik @decodyng @ardabbour @lorenz-h @mschweizer @lorepieri8 @vwxyzjn
|
||||
@ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr @Atlis @liusida @09tangriro @amy12xx @juancroldan @benblack769 @bstee615
|
||||
@c-rizz
|
||||
@c-rizz @skandermoalla
|
||||
|
|
|
|||
|
|
@ -414,6 +414,10 @@ class EvalCallback(EventCallback):
|
|||
print(f"Success rate: {100 * success_rate:.2f}%")
|
||||
self.logger.record("eval/success_rate", success_rate)
|
||||
|
||||
# Dump log so the evaluation results are printed with the correct timestep
|
||||
self.logger.record("time/total timesteps", self.num_timesteps, exclude="tensorboard")
|
||||
self.logger.dump(self.num_timesteps)
|
||||
|
||||
if mean_reward > self.best_mean_reward:
|
||||
if self.verbose > 0:
|
||||
print("New best mean reward!")
|
||||
|
|
|
|||
|
|
@ -141,3 +141,29 @@ def test_eval_success_logging(tmp_path):
|
|||
assert len(eval_callback._is_success_buffer) > 0
|
||||
# More than 50% success rate
|
||||
assert np.mean(eval_callback._is_success_buffer) > 0.5
|
||||
|
||||
|
||||
def test_eval_callback_logs_are_written_with_the_correct_timestep(tmp_path):
    """Check that ``eval/mean_reward`` scalars are logged at multiples of ``eval_freq``.

    A small DQN model is trained with an ``EvalCallback`` attached and
    tensorboard logging pointed at ``tmp_path``. The written event file is then
    read back and the step of every ``eval/mean_reward`` scalar is checked.
    """
    # Tensorboard is an optional dependency: skip the test when it is missing
    pytest.importorskip("tensorboard")
    from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

    eval_freq = 101
    env_id = select_env(DQN)

    model = DQN(
        "MlpPolicy",
        env_id,
        policy_kwargs={"net_arch": [32]},
        tensorboard_log=tmp_path,
        verbose=1,
        seed=1,
    )

    # The evaluation environment is created after the model, as a separate instance
    callback = EvalCallback(gym.make(env_id), eval_freq=eval_freq, warn=False)
    model.learn(500, callback=callback)

    # Read back the scalars written during the run
    run_dir = tmp_path / "DQN_1"
    accumulator = EventAccumulator(str(run_dir))
    accumulator.Reload()

    logged_steps = [event.step for event in accumulator.scalars.Items("eval/mean_reward")]
    # Every evaluation entry must sit on an exact multiple of the evaluation frequency
    assert all(step % eval_freq == 0 for step in logged_steps)
|
||||
|
|
|
|||
Loading…
Reference in a new issue