stable-baselines3/tests/test_tensorboard.py
Roland Gavrilescu bb01253261
Tensorboard integration (#30)
* init commit tensorboard-integration

* Added tb logger to ppo (with output exclusions)

* fixed truncated stdout

* categorize stdout outputs by tag

* separated exclusions from values, added missing logs

* saving exclusions as dict instead of list

* reformatting, auto run indexing

* included renaming suggestions, fixed tests

* tb support for sac

* linting

* moved logging to base class

* tb support for td3

* removed histograms, non-verbose output working

* modified changelog

* linting

* fixed type error

* moved logger config to utils

* removed episode_rewards log from ppo

* Enable tensorboard in tests

* Remove unused import

* Update logger sub titles

* Minor edit for PPO

* Update logger and tb log folder

* Pass correct logger to Callbacks

* updated docs

* added tb example image to docs

* add support for continuing training in tensorboard

* added tensorboard to docs index

* added tb test

* moved logger config to _setup_learn, updated tests

* accessing verbose from base class

* Update doc and tests

* Rename session -> time

* Update version

* Update logger truncate

* Update types

* Remove duplicated code

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2020-06-01 11:55:44 +02:00


import os

import pytest

from stable_baselines3 import A2C, PPO, SAC, TD3

MODEL_DICT = {
    'a2c': (A2C, 'CartPole-v1'),
    'ppo': (PPO, 'CartPole-v1'),
    'sac': (SAC, 'Pendulum-v0'),
    'td3': (TD3, 'Pendulum-v0'),
}

N_STEPS = 100


@pytest.mark.parametrize("model_name", MODEL_DICT.keys())
def test_tensorboard(tmp_path, model_name):
    # Skip if no tensorboard installed
    pytest.importorskip("tensorboard")

    logname = model_name.upper()
    algo, env_id = MODEL_DICT[model_name]
    model = algo('MlpPolicy', env_id, verbose=1, tensorboard_log=tmp_path)
    model.learn(N_STEPS)
    model.learn(N_STEPS, reset_num_timesteps=False)

    assert os.path.isdir(tmp_path / str(logname + "_1"))
    assert not os.path.isdir(tmp_path / str(logname + "_2"))

    logname = "tb_multiple_runs_" + model_name
    model.learn(N_STEPS, tb_log_name=logname)
    model.learn(N_STEPS, tb_log_name=logname)

    assert os.path.isdir(tmp_path / str(logname + "_1"))
    # Check that the log dir name increments correctly
    assert os.path.isdir(tmp_path / str(logname + "_2"))
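
The assertions above rely on run directories being named `<log_name>_<N>`, with `N` incremented on a fresh `learn()` call and reused when training continues with `reset_num_timesteps=False`. The following is a minimal, standard-library-only sketch of how such indexing could work. It is not the library's implementation: the helpers `latest_run_id` and `next_run_dir` and the `continue_run` flag are hypothetical names introduced here for illustration.

```python
import os
import re


def latest_run_id(log_dir, log_name):
    """Return the highest N among sub-directories named '<log_name>_<N>', or 0.

    Hypothetical helper, for illustration only.
    """
    pattern = re.compile(re.escape(log_name) + r"_(\d+)")
    max_id = 0
    if not os.path.isdir(log_dir):
        return max_id
    for entry in os.listdir(log_dir):
        match = pattern.fullmatch(entry)
        # Only count real directories whose name matches exactly
        if match and os.path.isdir(os.path.join(log_dir, entry)):
            max_id = max(max_id, int(match.group(1)))
    return max_id


def next_run_dir(log_dir, log_name, continue_run=False):
    """Create (if needed) and return the directory for the next run.

    With continue_run=True the most recent run directory is reused,
    which mirrors the behaviour the test expects for
    reset_num_timesteps=False. Hypothetical helper, for illustration only.
    """
    run_id = latest_run_id(log_dir, log_name)
    # Start a new index unless we are continuing an existing run
    if not continue_run or run_id == 0:
        run_id += 1
    path = os.path.join(log_dir, "{}_{}".format(log_name, run_id))
    os.makedirs(path, exist_ok=True)
    return path
```

Under these assumptions, calling `next_run_dir(d, "PPO")` and then `next_run_dir(d, "PPO", continue_run=True)` leaves only `PPO_1` on disk, while two plain calls with the same name produce `_1` and `_2`, matching the two halves of the test.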