stable-baselines3/tests/test_tensorboard.py
Antonin RAFFIN daaebd0a52
Drop python 3.8 and add python 3.12 support (#2041)
* Drop python 3.8 support, add python 3.12 support

* Upgrade to python 3.9 syntax

* Fixes for Numpy v2

* Fix doc warning
2024-11-18 15:40:36 +01:00

84 lines
2.9 KiB
Python

import os
from typing import Union
import pytest
from stable_baselines3 import A2C, PPO, SAC, TD3
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import HParam
from stable_baselines3.common.utils import get_latest_run_id
# Algorithms exercised by the tests below, keyed by their short lowercase name.
# Each entry pairs the algorithm class with the id of the environment it is
# trained on here (CartPole-v1 for A2C/PPO, Pendulum-v1 for SAC/TD3).
MODEL_DICT: dict[str, tuple[type, str]] = {
    "a2c": (A2C, "CartPole-v1"),
    "ppo": (PPO, "CartPole-v1"),
    "sac": (SAC, "Pendulum-v1"),
    "td3": (TD3, "Pendulum-v1"),
}

# Number of timesteps passed to every `learn()` call in this module.
N_STEPS = 100
class HParamCallback(BaseCallback):
    """
    Record the model's hyperparameters, together with the metrics to track,
    once at the start of training so that they are logged to TensorBoard.
    """

    def _on_training_start(self) -> None:
        """Collect the hyperparameters of ``self.model`` and record them under the ``hparams`` key."""
        model = self.model

        hyperparams: dict[str, Union[str, float]] = {"algorithm": model.__class__.__name__}
        # Ignore type checking for gamma, see https://github.com/DLR-RM/stable-baselines3/pull/1194/files#r1035006458
        hyperparams["gamma"] = model.gamma  # type: ignore[attr-defined]

        # The learning rate can also be a Schedule; it is only reported when it is a plain float.
        learning_rate = model.learning_rate
        if isinstance(learning_rate, float):
            hyperparams["learning rate"] = learning_rate

        # Metrics that will appear in the `HPARAMS` TensorBoard tab, referenced by their tag.
        # TensorBoard will find & display the matching metrics from the `SCALARS` tab.
        tracked_metrics: dict[str, float] = {"rollout/ep_len_mean": 0}

        # The hparams record is excluded from the stdout, log, json and csv outputs.
        excluded_outputs = ("stdout", "log", "json", "csv")
        self.logger.record("hparams", HParam(hyperparams, tracked_metrics), exclude=excluded_outputs)

    def _on_step(self) -> bool:
        """Do no per-step work; always return ``True``."""
        return True
@pytest.mark.parametrize("model_name", list(MODEL_DICT))
def test_tensorboard(tmp_path, model_name):
    """
    Train one algorithm with TensorBoard logging pointed at ``tmp_path`` and
    check which run folders exist afterwards.
    """
    # Skip if no tensorboard installed
    pytest.importorskip("tensorboard")

    algo_class, env_id = MODEL_DICT[model_name]

    # Per-algorithm constructor overrides.
    extra_kwargs = {}
    if model_name == "ppo":
        extra_kwargs["n_steps"] = 64
    if model_name in ("sac", "td3"):
        extra_kwargs["train_freq"] = 2

    model = algo_class("MlpPolicy", env_id, verbose=1, tensorboard_log=tmp_path, **extra_kwargs)

    # Default log name (upper-cased algorithm name): the second call continues
    # with `reset_num_timesteps=False`, so only the `_1` folder is expected.
    model.learn(N_STEPS, callback=HParamCallback())
    model.learn(N_STEPS, reset_num_timesteps=False)
    default_name = model_name.upper()
    assert (tmp_path / f"{default_name}_1").is_dir()
    assert not (tmp_path / f"{default_name}_2").is_dir()

    # Explicit log name: two separate `learn()` calls are expected to produce
    # two folders, i.e. the run id in the folder name increments correctly.
    custom_name = "tb_multiple_runs_" + model_name
    for _ in range(2):
        model.learn(N_STEPS, tb_log_name=custom_name)
    assert (tmp_path / f"{custom_name}_1").is_dir()
    assert (tmp_path / f"{custom_name}_2").is_dir()
def test_escape_log_name(tmp_path):
    """
    Check that ``get_latest_run_id`` reports run id 2 when folders for runs 1
    and 2 exist under a log name that must be escaped.
    """
    # Log name that must be escaped
    log_name = "filename[16, 16]"

    # Create the folders for runs 1 and 2
    base_dir = str(tmp_path)
    for run_id in (1, 2):
        os.makedirs(os.path.join(base_dir, f"{log_name}_{run_id}"), exist_ok=True)

    assert get_latest_run_id(tmp_path, log_name) == 2