mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-17 21:20:11 +00:00
* init commit tensorboard-integration * Added tb logger to ppo (with output exclusions) * fixed truncated stdout * categorize stdout outputs by tag * separated exclusions from values, added missing logs * saving exclusions as dict instead of list * reformatting, auto run indexing * included renaming suggestions, fixed tests * tb support for sac * linting * moved logging to base class * tb support for td3 * removed histograms, non-verbose output working * modified changelog * linting * fixed type error * moved logger config to utils * removed episode_rewards log from ppo * Enable tensorboard in tests * Remove unused import * Update logger sub titles * Minor edit for PPO * Update logger and tb log folder * Pass correct logger to Callbacks * updated docs * added tb example image to docs * add support for continuing training in tensorboard * added tensorboard to docs index * added tb test * moved logger config to _setup_learn, updated tests * accessing verbose from base class * Update doc and tests * Rename session -> time * Update version * Update logger truncate * Update types * Remove duplicated code Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
83 lines
2 KiB
Python
83 lines
2 KiB
Python
import pytest
|
|
import numpy as np
|
|
|
|
from stable_baselines3.common.logger import (make_output_format, read_csv, read_json, DEBUG, ScopedConfigure,
|
|
info, debug, set_level, configure, record, record_dict,
|
|
dump, record_mean, warn, error, reset)
|
|
|
|
# Sample key/value pairs written to every output format under test. They
# cover an int, a negative float, a numeric-looking key, a Python list and
# numpy arrays of rank 1, 0 and 3.
KEY_VALUES = {
    "test": 1,
    "b": -3.14,
    "8": 9.9,
    "l": [1, 2],
    "a": np.array([1, 2, 3]),
    "f": np.array(1),
    "g": np.array([[[1]]]),
}

# Exclusion mapping passed alongside KEY_VALUES: one entry per key, all set
# to None (presumably meaning "not excluded from any format" -- confirm
# against the logger's write() contract).
KEY_EXCLUDED = dict.fromkeys(KEY_VALUES)
|
|
|
|
|
|
def test_main(tmp_path):
    """
    Smoke-test the module-level logger API.

    Exercises level filtering, ``configure``, ``record``/``record_mean``/``dump``,
    ``ScopedConfigure``, ``reset``, long value output and the ``warn``/``error``
    helpers. It passes as long as none of these calls raise.

    :param tmp_path: (Path) temporary directory provided by pytest
    """
    info("hi")
    # Default level filters debug messages; after set_level(DEBUG) they show
    debug("shouldn't appear")
    set_level(DEBUG)
    debug("should appear")
    configure(folder=str(tmp_path))
    record("a", 3)
    record("b", 2.5)
    dump()
    record("b", -2.5)
    record("a", 5.5)
    dump()
    info("^^^ should see a = 5.5")
    # record_mean averages repeated values for the same key:
    # (-22.5 + -44.4) / 2 = -33.45 (display may round it)
    record_mean("b", -22.5)
    record_mean("b", -44.4)
    record("a", 5.5)
    dump()
    with ScopedConfigure(None, None):
        # The message refers to the previous dump() output
        info("^^^ should see b = -33.45")

    # Temporarily log to a json-only sub-folder
    with ScopedConfigure(str(tmp_path / "test-logger"), ["json"]):
        record("b", -2.5)
        dump()

    reset()
    # Long value to exercise value truncation in the human-readable output
    record("a", "longasslongasslongasslongasslongasslongassvalue")
    dump()
    warn("hey")
    error("oh")
    record_dict({"test": 1})
|
|
|
|
|
|
@pytest.mark.parametrize('_format', ['stdout', 'log', 'json', 'csv', 'tensorboard'])
def test_make_output(tmp_path, _format):
    """
    Write the sample key/values with each output format and, for file-based
    formats, read the result back.

    :param tmp_path: (Path) temporary directory provided by pytest
    :param _format: (str) output format
    """
    if _format == 'tensorboard':
        # Skip if no tensorboard installed
        pytest.importorskip("tensorboard")

    writer = make_output_format(_format, tmp_path)
    try:
        writer.write(KEY_VALUES, KEY_EXCLUDED)
        # Check the file-based formats produced parseable output
        if _format == "csv":
            read_csv(tmp_path / 'progress.csv')
        elif _format == 'json':
            read_json(tmp_path / 'progress.json')
    finally:
        # Release the writer's resources even if writing or reading fails
        writer.close()
|
|
|
|
|
|
def test_make_output_fail(tmp_path):
    """
    An unrecognised format name passed to ``make_output_format`` must raise
    ``ValueError``.

    :param tmp_path: (Path) temporary directory provided by pytest
    """
    unknown_format = 'dummy_format'
    with pytest.raises(ValueError):
        make_output_format(unknown_format, tmp_path)
|