stable-baselines3/tests/test_tensorboard.py
Timothé 01cc127d32
Support hparams logging to tensorboard (#984)
* create Hparam class & support in all OutputFormats

* add hparams documentation & example

* add hparam tests

* remove unnecessary test & fix name

* format changes

* support hyperparameters logging to tensorboard

* fix HParams class docstring

* use more explicit variable names

* raise error instead of warning

* Unpin protobuf

* Add test for logging hparams

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2022-08-22 22:06:54 +02:00

83 lines
2.6 KiB
Python

import os
import pytest
from stable_baselines3 import A2C, PPO, SAC, TD3
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import HParam
from stable_baselines3.common.utils import get_latest_run_id
# Algorithm name -> (algorithm class, environment id) pairs used to
# parametrize the TensorBoard tests below.
MODEL_DICT = {
    "a2c": (A2C, "CartPole-v1"),
    "ppo": (PPO, "CartPole-v1"),
    "sac": (SAC, "Pendulum-v1"),
    "td3": (TD3, "Pendulum-v1"),
}

# Number of environment steps for each short training run.
N_STEPS = 100
class HParamCallback(BaseCallback):
    """
    Callback that logs hyperparameters (and the metrics to compare them
    against) to TensorBoard's HPARAMS tab at the start of training.
    """

    def __init__(self):
        """
        Saves the hyperparameters and metrics at the start of the training, and logs them to TensorBoard.
        """
        super().__init__()

    def _on_training_start(self) -> None:
        # Hyperparameters shown in the HPARAMS tab.
        hparam_dict = {
            "algorithm": self.model.__class__.__name__,
            "learning rate": self.model.learning_rate,
            "gamma": self.model.gamma,
        }
        # Define the metrics that will appear in the `HPARAMS` TensorBoard tab
        # by referencing their tag; TensorBoard will find & display their
        # values from the `SCALARS` tab.
        metric_dict = {
            "rollout/ep_len_mean": 0,
        }
        # Exclude the text-based writers: only the TensorBoard output format
        # understands the HParam value type.
        self.logger.record(
            "hparams",
            HParam(hparam_dict, metric_dict),
            exclude=("stdout", "log", "json", "csv"),
        )

    def _on_step(self) -> bool:
        # Nothing to do per step; returning True keeps training going.
        return True
@pytest.mark.parametrize("model_name", MODEL_DICT.keys())
def test_tensorboard(tmp_path, model_name):
    """Check that TensorBoard run folders are created and incremented as expected."""
    # Skip if no tensorboard installed
    pytest.importorskip("tensorboard")

    logname = model_name.upper()
    algo, env_id = MODEL_DICT[model_name]

    # Keep the runs cheap: small rollout / train frequency per algorithm family.
    kwargs = {}
    if model_name == "ppo":
        kwargs["n_steps"] = 64
    elif model_name in {"sac", "td3"}:
        kwargs["train_freq"] = 2

    model = algo("MlpPolicy", env_id, verbose=1, tensorboard_log=tmp_path, **kwargs)
    model.learn(N_STEPS, callback=HParamCallback())
    # Continuing the same run must reuse the existing log directory.
    model.learn(N_STEPS, reset_num_timesteps=False)
    assert os.path.isdir(tmp_path / f"{logname}_1")
    assert not os.path.isdir(tmp_path / f"{logname}_2")

    logname = f"tb_multiple_runs_{model_name}"
    model.learn(N_STEPS, tb_log_name=logname)
    model.learn(N_STEPS, tb_log_name=logname)
    assert os.path.isdir(tmp_path / f"{logname}_1")
    # Check that the log dir name increments correctly
    assert os.path.isdir(tmp_path / f"{logname}_2")
def test_escape_log_name(tmp_path):
    """
    Check that get_latest_run_id() handles log names containing regex
    special characters (they must be escaped before matching folder names).
    """
    # Log name that must be escaped (contains "[", "]" and ",")
    log_name = "filename[16, 16]"
    # Create two pre-existing run folders; tmp_path is a pathlib.Path fixture,
    # so use the `/` operator instead of string concatenation.
    for run_id in (1, 2):
        os.makedirs(tmp_path / f"{log_name}_{run_id}", exist_ok=True)
    last_run_id = get_latest_run_id(tmp_path, log_name)
    assert last_run_id == 2