* create Hparam class & support in all OutputFormats
* add hparams documentation & example
* add hparam tests
* remove unnecessary test & fix name
* format changes
* support hyperparameters logging to tensorboard
* fix HParams class docstring
* use more explicit variable names
* raise error instead of warning
* Unpin protobuf
* Add test for logging hparams

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
83 lines
2.6 KiB
Python
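The commit introduces an `HParam` value type for the logger; in the test below only the TensorBoard writer renders it, and the remaining output formats are excluded. As a minimal sketch of the same logging path outside a callback (assuming the `configure` helper from `stable_baselines3.common.logger`; the folder name is illustrative, not part of this test):

# Minimal sketch: log hyperparameters to TensorBoard via a standalone logger.
from stable_baselines3.common.logger import HParam, configure

logger = configure("./hparams_demo", ["tensorboard"])  # illustrative folder
logger.record(
    "hparams",
    HParam(
        hparam_dict={"algorithm": "PPO", "learning rate": 3e-4, "gamma": 0.99},
        # tags of SCALARS that TensorBoard joins with these hyperparameters
        metric_dict={"rollout/ep_len_mean": 0},
    ),
    # excluded here since only the TensorBoard format handles HParam
    exclude=("stdout", "log", "json", "csv"),
)
logger.dump()  # flush the hparams summary to the event file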
import os

import pytest

from stable_baselines3 import A2C, PPO, SAC, TD3
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import HParam
from stable_baselines3.common.utils import get_latest_run_id

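# Pair each algorithm with an environment it supports: A2C and PPO run on the
# discrete-action CartPole, while SAC and TD3 require the continuous Pendulum.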
MODEL_DICT = {
    "a2c": (A2C, "CartPole-v1"),
    "ppo": (PPO, "CartPole-v1"),
    "sac": (SAC, "Pendulum-v1"),
    "td3": (TD3, "Pendulum-v1"),
}

N_STEPS = 100


class HParamCallback(BaseCallback):
    def __init__(self):
        """
        Saves the hyperparameters and metrics at the start of the training, and logs them to TensorBoard.
        """
        super().__init__()

    def _on_training_start(self) -> None:
        hparam_dict = {
            "algorithm": self.model.__class__.__name__,
            "learning rate": self.model.learning_rate,
            "gamma": self.model.gamma,
        }
        # define the metrics that will appear in the `HPARAMS` Tensorboard tab by referencing their tag
        # Tensorboard will find & display metrics from the `SCALARS` tab
        metric_dict = {
            "rollout/ep_len_mean": 0,
        }
        self.logger.record(
            "hparams",
            HParam(hparam_dict, metric_dict),
            exclude=("stdout", "log", "json", "csv"),
        )

    def _on_step(self) -> bool:
        return True

@pytest.mark.parametrize("model_name", MODEL_DICT.keys())
def test_tensorboard(tmp_path, model_name):
    # Skip if tensorboard is not installed
    pytest.importorskip("tensorboard")

    logname = model_name.upper()
    algo, env_id = MODEL_DICT[model_name]
    kwargs = {}
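    # Use a small rollout size / training frequency, presumably so the
    # 100-step budget stays fast while still triggering updates and logging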
    if model_name == "ppo":
        kwargs["n_steps"] = 64
    elif model_name in {"sac", "td3"}:
        kwargs["train_freq"] = 2
    model = algo("MlpPolicy", env_id, verbose=1, tensorboard_log=tmp_path, **kwargs)
    model.learn(N_STEPS, callback=HParamCallback())
    model.learn(N_STEPS, reset_num_timesteps=False)
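    # reset_num_timesteps=False continues the previous run, so the second
    # learn() call logs into the same <LOGNAME>_1 directory (checked below)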

    assert os.path.isdir(tmp_path / str(logname + "_1"))
    assert not os.path.isdir(tmp_path / str(logname + "_2"))

    logname = "tb_multiple_runs_" + model_name
    model.learn(N_STEPS, tb_log_name=logname)
    model.learn(N_STEPS, tb_log_name=logname)

    assert os.path.isdir(tmp_path / str(logname + "_1"))
    # Check that the log dir name increments correctly
    assert os.path.isdir(tmp_path / str(logname + "_2"))


def test_escape_log_name(tmp_path):
    # Log name that must be escaped
    log_name = "filename[16, 16]"
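    # ("[16, 16]" would otherwise be read as a glob character class when
    # get_latest_run_id scans the log folder for existing runs)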
    # Create folders that look like two previous runs
    os.makedirs(str(tmp_path) + f"/{log_name}_1", exist_ok=True)
    os.makedirs(str(tmp_path) + f"/{log_name}_2", exist_ok=True)
    last_run_id = get_latest_run_id(tmp_path, log_name)
    assert last_run_id == 2