stable-baselines3/tests/test_logger.py
Roland Gavrilescu bb01253261
Tensorboard integration (#30)
* init commit tensorboard-integration

* Added tb logger to ppo (with output exclusions)

* fixed truncated stdout

* categorize stdout outputs by tag

* separated exclusions from values, added missing logs

* saving exclusions as dict instead of list

* reformatting, auto run indexing

* included renaming suggestions, fixed tests

* tb support for sac

* linting

* moved logging to base class

* tb support for td3

* removed histograms, non-verbose output working

* modified changelog

* linting

* fixed type error

* moved logger config to utils

* removed episode_rewards log from ppo

* Enable tensorboard in tests

* Remove unused import

* Update logger sub titles

* Minor edit for PPO

* Update logger and tb log folder

* Pass correct logger to Callbacks

* updated docs

* added tb example image to docs

* add support for continuing training in tensorboard

* added tensorboard to docs index

* added tb test

* moved logger config to _setup_learn, updated tests

* accessing verbose from base class

* Update doc and tests

* Rename session -> time

* Update version

* Update logger truncate

* Update types

* Remove duplicated code

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2020-06-01 11:55:44 +02:00

83 lines
2 KiB
Python

import pytest
import numpy as np
from stable_baselines3.common.logger import (make_output_format, read_csv, read_json, DEBUG, ScopedConfigure,
info, debug, set_level, configure, record, record_dict,
dump, record_mean, warn, error, reset)
# Sample key/value pairs fed to every output format. They cover a plain int,
# floats, a numeric-looking string key, a Python list, and numpy arrays of
# rank 0, 1 and 3.
KEY_VALUES = {
    "test": 1,
    "b": -3.14,
    "8": 9.9,
    "l": [1, 2],
    "a": np.array([1, 2, 3]),
    "f": np.array(1),
    "g": np.array([[[1]]]),
}

# One entry per key of KEY_VALUES, each mapped to None. It is passed next to
# KEY_VALUES to ``writer.write``. Presumably None means "not excluded from any
# output format" -- confirm against the logger implementation.
# dict.fromkeys keeps key order and avoids leaking a loop variable into the
# module namespace.
KEY_EXCLUDED = dict.fromkeys(KEY_VALUES)
def test_main(tmp_path):
    """
    Smoke-test the module-level logger API.

    Exercises level filtering, ``record``/``record_mean`` followed by ``dump``,
    scoped reconfiguration, ``reset``, a long string value and the
    ``warn``/``error``/``record_dict`` helpers. Nothing is asserted: the test
    passes as long as none of these calls raise, and the printed output is
    meant to be checked by eye.

    :param tmp_path: (pathlib.Path) pytest-provided temporary directory used as the log folder
    """
    info("hi")
    # This debug() call runs before set_level(DEBUG), so its message says it should not be shown.
    debug("shouldn't appear")
    set_level(DEBUG)
    debug("should appear")
    configure(folder=str(tmp_path))
    record("a", 3)
    record("b", 2.5)
    dump()
    record("b", -2.5)
    record("a", 5.5)
    dump()
    info("^^^ should see a = 5.5")
    # record_mean is expected to average the values recorded for a key:
    # (-22.5 + -44.4) / 2 = -33.45
    record_mean("b", -22.5)
    record_mean("b", -44.4)
    record("a", 5.5)
    dump()
    # NOTE(review): ScopedConfigure(None, None) presumably falls back to the
    # logger's default folder/formats -- confirm it does not write outside tmp_path.
    with ScopedConfigure(None, None):
        info("^^^ should see b = -33.45")

    with ScopedConfigure(str(tmp_path / "test-logger"), ["json"]):
        record("b", -2.5)
        dump()

    reset()
    # A long string value; presumably exercises value truncation in the
    # human-readable output.
    record("a", "longasslongasslongasslongasslongasslongassvalue")
    dump()
    warn("hey")
    error("oh")
    record_dict({"test": 1})
@pytest.mark.parametrize('_format', ['stdout', 'log', 'json', 'csv', 'tensorboard'])
def test_make_output(tmp_path, _format):
    """
    Write KEY_VALUES through one output format and read back file-based output.

    For the "csv" and "json" formats, the written progress file is parsed
    again with ``read_csv`` / ``read_json``. The writer is always closed, even
    when writing or reading fails.

    :param tmp_path: (pathlib.Path) pytest-provided temporary log directory
    :param _format: (str) output format
    """
    if _format == 'tensorboard':
        # Skip if no tensorboard installed
        pytest.importorskip("tensorboard")

    writer = make_output_format(_format, tmp_path)
    # try/finally so a failing write or read-back does not leak the writer's
    # open file handles into later tests.
    try:
        writer.write(KEY_VALUES, KEY_EXCLUDED)
        if _format == "csv":
            read_csv(tmp_path / 'progress.csv')
        elif _format == 'json':
            read_json(tmp_path / 'progress.json')
    finally:
        writer.close()
def test_make_output_fail(tmp_path):
    """
    An unrecognised output format name must make the factory raise ValueError.

    :param tmp_path: (pathlib.Path) pytest-provided temporary log directory
    """
    unknown_format = 'dummy_format'
    with pytest.raises(ValueError):
        make_output_format(unknown_format, tmp_path)