mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-17 21:20:11 +00:00
* init commit tensorboard-integration * Added tb logger to ppo (with output exclusions) * fixed truncated stdout * categorize stdout outputs by tag * separated exclusions from values, added missing logs * saving exclusions as dict instead of list * reformatting, auto run indexing * included renaming suggestions, fixed tests * tb support for sac * linting * moved logging to base class * tb support for td3 * removed histograms, non-verbose output working * modified changelog * linting * fixed type error * moved logger config to utils * removed episode_rewards log from ppo * Enable tensorboard in tests * Remove unused import * Update logger sub titles * Minor edit for PPO * Update logger and tb log folder * Pass correct logger to Callbacks * updated docs * added tb example image to docs * add support for continuing training in tensorboard * added tensorboard to docs index * added tb test * moved logger config to _setup_learn, updated tests * accessing verbose from base class * Update doc and tests * Rename session -> time * Update version * Update logger truncate * Update types * Remove duplicated code Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
36 lines
1.1 KiB
Python
36 lines
1.1 KiB
Python
import os
|
|
import pytest
|
|
|
|
from stable_baselines3 import A2C, PPO, SAC, TD3
|
|
|
|
# Algorithms exercised by the TensorBoard test, keyed by lower-case name.
# Each value pairs the algorithm class with the environment id string that
# is handed to its constructor.
MODEL_DICT = dict(
    a2c=(A2C, 'CartPole-v1'),
    ppo=(PPO, 'CartPole-v1'),
    sac=(SAC, 'Pendulum-v0'),
    td3=(TD3, 'Pendulum-v0'),
)
|
|
|
|
# Passed as the first argument to every ``model.learn`` call in the test
# (presumably the training-timestep budget, kept small so the test stays
# fast -- confirm against the ``learn`` signature).
N_STEPS = 100
|
|
|
|
|
|
@pytest.mark.parametrize("model_name", list(MODEL_DICT))
def test_tensorboard(tmp_path, model_name):
    """Check the TensorBoard log folders created by successive ``learn`` calls.

    Two scenarios are covered for the algorithm registered under
    ``model_name`` in ``MODEL_DICT``:

    1. No explicit log name: after a first ``learn`` call followed by a second
       one with ``reset_num_timesteps=False``, the folder
       ``<MODEL_NAME upper-cased>_1`` must exist and ``..._2`` must not.
    2. Explicit ``tb_log_name``: two plain ``learn`` calls must leave both a
       ``<name>_1`` and a ``<name>_2`` folder behind.

    :param tmp_path: directory used as ``tensorboard_log`` (the pytest
        temporary-directory fixture)
    :param model_name: key into ``MODEL_DICT`` selecting the algorithm class
        and environment id
    """
    # Skip if no tensorboard installed
    pytest.importorskip("tensorboard")

    algo_cls, env_name = MODEL_DICT[model_name]
    agent = algo_cls('MlpPolicy', env_name, verbose=1, tensorboard_log=tmp_path)

    # Scenario 1: default log name, second call continues without resetting
    # the timestep counter.
    default_run = model_name.upper()
    agent.learn(N_STEPS)
    agent.learn(N_STEPS, reset_num_timesteps=False)

    assert os.path.isdir(tmp_path / f"{default_run}_1")
    assert not os.path.isdir(tmp_path / f"{default_run}_2")

    # Scenario 2: explicit log name, two independent calls.
    custom_run = f"tb_multiple_runs_{model_name}"
    for _ in range(2):
        agent.learn(N_STEPS, tb_log_name=custom_run)

    assert os.path.isdir(tmp_path / f"{custom_run}_1")
    # Check that the log dir name increments correctly
    assert os.path.isdir(tmp_path / f"{custom_run}_2")
|