mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-16 21:10:08 +00:00
* Add auto formatting with black and isort
* Reformat code
* Ignore typing errors
* Add note about line length
* Add minimum version for isort
* Add commit-checks
* Update docker image
* Fixed lost import (during last merge)
* Fix opencv dependency
37 lines · 1.1 KiB · Python
import os
|
|
|
|
import pytest
|
|
|
|
from stable_baselines3 import A2C, PPO, SAC, TD3
|
|
|
|
# Algorithms covered by the TensorBoard test.
# Each key is the lower-case model name used for parametrization; each value is
# the pair (algorithm class, Gym environment id) used to build the model.
MODEL_DICT = dict(
    a2c=(A2C, "CartPole-v1"),
    ppo=(PPO, "CartPole-v1"),
    sac=(SAC, "Pendulum-v0"),
    td3=(TD3, "Pendulum-v0"),
)

# Value passed as the first argument to every ``model.learn`` call in the test.
N_STEPS = 100
|
|
|
|
|
|
@pytest.mark.parametrize("model_name", list(MODEL_DICT))
def test_tensorboard(tmp_path, model_name):
    """Check how TensorBoard run directories are numbered for ``model_name``.

    ``tmp_path`` (pytest's per-test temporary directory) is used as the
    TensorBoard log root. Two scenarios are checked:

    1. A second ``learn`` call with ``reset_num_timesteps=False`` must keep
       logging into the first run directory ``<MODEL>_1`` and must not open
       ``<MODEL>_2``. The test expects the default run name to be the
       upper-cased model name.
    2. Two ``learn`` calls sharing an explicit ``tb_log_name`` must produce
       the consecutive directories ``<name>_1`` and ``<name>_2``.

    The test is skipped when the ``tensorboard`` package is not installed.
    """
    # Skip if no tensorboard installed
    pytest.importorskip("tensorboard")

    model_class, env_id = MODEL_DICT[model_name]
    model = model_class("MlpPolicy", env_id, verbose=1, tensorboard_log=tmp_path)

    # Scenario 1: continuing training must stay in the same run directory.
    default_name = model_name.upper()
    model.learn(N_STEPS)
    model.learn(N_STEPS, reset_num_timesteps=False)
    assert os.path.isdir(tmp_path / f"{default_name}_1")
    assert not os.path.isdir(tmp_path / f"{default_name}_2")

    # Scenario 2: repeated runs under one custom name get increasing suffixes.
    custom_name = f"tb_multiple_runs_{model_name}"
    for _ in range(2):
        model.learn(N_STEPS, tb_log_name=custom_name)
    # Check that the log dir name increments correctly
    for suffix in ("_1", "_2"):
        assert os.path.isdir(tmp_path / f"{custom_name}{suffix}")
|