mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-14 20:58:03 +00:00
* Drop python 3.8 support, add python 3.12 support * Upgrade to python 3.9 syntax * Fixes for Numpy v2 * Fix doc warning
84 lines
2.9 KiB
Python
84 lines
2.9 KiB
Python
import os
|
|
from typing import Union
|
|
|
|
import pytest
|
|
|
|
from stable_baselines3 import A2C, PPO, SAC, TD3
|
|
from stable_baselines3.common.callbacks import BaseCallback
|
|
from stable_baselines3.common.logger import HParam
|
|
from stable_baselines3.common.utils import get_latest_run_id
|
|
|
|
# Lookup table used by the tests below:
# short algorithm key -> (algorithm class, environment id passed to its constructor).
MODEL_DICT = {
    "a2c": (A2C, "CartPole-v1"),
    "ppo": (PPO, "CartPole-v1"),
    "sac": (SAC, "Pendulum-v1"),
    "td3": (TD3, "Pendulum-v1"),
}

# Number of timesteps passed to every ``learn()`` call in this module.
N_STEPS = 100
|
|
|
|
|
|
class HParamCallback(BaseCallback):
    """
    Callback that records the model's hyperparameters, together with the metric tags to
    associate with them, once at the start of training so they are logged to TensorBoard.
    """

    def _on_training_start(self) -> None:
        """Collect the hyperparameters from ``self.model`` and record them under the ``hparams`` key."""
        model = self.model

        hyperparameters: dict[str, Union[str, float]] = {}
        hyperparameters["algorithm"] = model.__class__.__name__
        # Type checking is ignored for gamma,
        # see https://github.com/DLR-RM/stable-baselines3/pull/1194/files#r1035006458
        hyperparameters["gamma"] = model.gamma  # type: ignore[attr-defined]

        # The learning rate may also be a Schedule; it is only reported when it is a plain float
        learning_rate = model.learning_rate
        if isinstance(learning_rate, float):
            hyperparameters["learning rate"] = learning_rate

        # Metrics shown in the `HPARAMS` TensorBoard tab are referenced by their tag;
        # TensorBoard will find & display the matching metrics from the `SCALARS` tab
        tracked_metrics: dict[str, float] = {"rollout/ep_len_mean": 0}

        hparam_record = HParam(hyperparameters, tracked_metrics)
        # This record is excluded from every output format listed here
        self.logger.record("hparams", hparam_record, exclude=("stdout", "log", "json", "csv"))

    def _on_step(self) -> bool:
        """Always return ``True``; this callback does no per-step work."""
        return True
|
|
|
|
|
|
@pytest.mark.parametrize("model_name", MODEL_DICT.keys())
def test_tensorboard(tmp_path, model_name):
    """
    Train each algorithm with TensorBoard logging enabled and check the run directories.

    Two consecutive ``learn()`` calls (the second with ``reset_num_timesteps=False``) are
    expected to leave only ``<NAME>_1`` on disk, whereas two calls that pass an explicit
    ``tb_log_name`` are expected to produce both ``<name>_1`` and ``<name>_2``.
    """
    # Nothing to verify when tensorboard is not installed
    pytest.importorskip("tensorboard")

    algo, env_id = MODEL_DICT[model_name]

    # Algorithm-specific constructor overrides
    if model_name == "ppo":
        kwargs = {"n_steps": 64}
    elif model_name in {"sac", "td3"}:
        kwargs = {"train_freq": 2}
    else:
        kwargs = {}

    model = algo("MlpPolicy", env_id, verbose=1, tensorboard_log=tmp_path, **kwargs)

    # First call logs the hyperparameters through the callback,
    # second call continues without resetting the timestep counter
    model.learn(N_STEPS, callback=HParamCallback())
    model.learn(N_STEPS, reset_num_timesteps=False)

    default_name = model_name.upper()
    assert os.path.isdir(tmp_path / f"{default_name}_1")
    assert not os.path.isdir(tmp_path / f"{default_name}_2")

    custom_name = f"tb_multiple_runs_{model_name}"
    for _ in range(2):
        model.learn(N_STEPS, tb_log_name=custom_name)

    # Check that the log dir name increments correctly: both run 1 and run 2 must exist
    for run_id in (1, 2):
        assert os.path.isdir(tmp_path / f"{custom_name}_{run_id}")
|
|
|
|
|
|
def test_escape_log_name(tmp_path):
    """Check that ``get_latest_run_id`` returns 2 for a log name that must be escaped."""
    # Log name that must be escaped
    log_name = "filename[16, 16]"

    # Create the folders for run 1 and run 2
    for run_id in (1, 2):
        os.makedirs(f"{tmp_path}/{log_name}_{run_id}", exist_ok=True)

    assert get_latest_run_id(tmp_path, log_name) == 2
|