stable-baselines3/tests/test_logger.py
Bernhard Raml 97b81f9e9e
Fix ignoring the exclude in the logger's record function for json, csv and log logging formats (#190)
* Fix ignoring the exclude in logger record

For the json, csv, and log logging formats, the exclude parameter of the
logger's record function was ignored. The necessary checks were
missing from some of the format writer classes. Regression tests have
been added to prevent this error in the future.

* Fix docstring for filter_excluded_keys

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>

* Added missing type hints to local functions

* Update stable_baselines3/common/logger.py

Co-authored-by: Bernhard Raml <raml.bernhard@gmail.com>
Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2020-10-16 17:34:49 +02:00

164 lines
4.3 KiB
Python

from typing import Sequence
import numpy as np
import pytest
from pandas.errors import EmptyDataError
from stable_baselines3.common.logger import (
DEBUG,
ScopedConfigure,
configure,
debug,
dump,
error,
info,
make_output_format,
read_csv,
read_json,
record,
record_dict,
record_mean,
reset,
set_level,
warn,
)
# Sample key/value pairs written by the output-format tests. They cover ints, floats,
# a digit-only key, a plain Python list and numpy arrays of rank 1, 0 and 3.
KEY_VALUES = {
    "test": 1,
    "b": -3.14,
    "8": 9.9,
    "l": [1, 2],
    "a": np.array([1, 2, 3]),
    "f": np.array(1),
    "g": np.array([[[1]]]),
}

# Exclusion map passed alongside KEY_VALUES: every key maps to None, i.e. no exclusion
# is requested for any key. dict.fromkeys keeps KEY_VALUES' key order and, unlike a
# module-level for-loop, does not leak a loop variable into the module namespace.
KEY_EXCLUDED = dict.fromkeys(KEY_VALUES)
class LogContent:
    """
    Uniform view over whatever a single log format produced.

    Lets the tests ask "was anything written?" in the same way for every format,
    and names the format in the repr so failing assertions are easy to read.

    :param _format: name of the log format the entries were read from
    :param lines: entries read back for that format
    """

    def __init__(self, _format: str, lines: Sequence):
        self.lines = lines
        self.format = _format

    @property
    def empty(self):
        """True when no entries were read back."""
        return not len(self.lines)

    def __repr__(self):
        fmt_name, entries = self.format, self.lines
        return f"LogContent(_format={fmt_name}, lines={entries})"
@pytest.fixture
def read_log(tmp_path, capsys):
    """
    Fixture returning a reader for whatever a given output format wrote.

    The returned callable takes a format name ("csv", "json", "stdout", "log" or
    "tensorboard") and returns a ``LogContent`` built from the files written into
    ``tmp_path`` or, for "stdout", from the captured standard output.

    :param tmp_path: pytest temporary directory the writers log into
    :param capsys: pytest fixture used to capture stdout
    :return: function mapping a format name to its ``LogContent``
    """

    def _dataframe_content(_format: str, read_df, path) -> LogContent:
        # Read a tabular log with ``read_df``; keep only non-empty rows.
        try:
            df = read_df(path)
        except EmptyDataError:
            # Raised for an empty file (presumably nothing was logged) -> no content.
            return LogContent(_format, [])
        return LogContent(_format, [r for _, r in df.iterrows() if not r.empty])

    def read_fn(_format: str) -> LogContent:
        if _format == "csv":
            return _dataframe_content(_format, read_csv, tmp_path / "progress.csv")
        if _format == "json":
            return _dataframe_content(_format, read_json, tmp_path / "progress.json")
        if _format == "stdout":
            captured = capsys.readouterr()
            return LogContent(_format, captured.out.splitlines())
        if _format == "log":
            return LogContent(_format, (tmp_path / "log.txt").read_text().splitlines())
        if _format == "tensorboard":
            # Imported lazily: tensorboard is optional and tests skip when it is missing.
            from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

            acc = EventAccumulator(str(tmp_path))
            acc.Reload()
            tb_values_logged = []
            for reservoir in [acc.scalars, acc.tensors, acc.images, acc.histograms, acc.compressed_histograms]:
                for k in reservoir.Keys():
                    tb_values_logged.append(f"{k}: {str(reservoir.Items(k))}")
            return LogContent(_format, tb_values_logged)
        # Previously this fell through and returned None, which only surfaced later as an
        # AttributeError on ``.empty``; fail loudly with the offending name instead.
        raise ValueError(f"Unknown log format: {_format}")

    return read_fn
def test_main(tmp_path):
    """
    tests for the logger module

    Smoke test: drives the module-level logging API (info/debug with levels,
    configure, record, record_mean, dump, ScopedConfigure, reset, warn, error,
    record_dict) in sequence. It contains no assertions, so it only checks that
    none of these calls raise.
    """
    info("hi")
    # Per the messages, debug output is expected to be hidden until set_level(DEBUG).
    debug("shouldn't appear")
    set_level(DEBUG)
    debug("should appear")
    configure(folder=str(tmp_path))
    record("a", 3)
    record("b", 2.5)
    dump()
    # Record new values for both keys before the next dump.
    record("b", -2.5)
    record("a", 5.5)
    dump()
    info("^^^ should see a = 5.5")
    # Two record_mean calls for "b", presumably averaged at dump time.
    # NOTE(review): if averaged, b = (-22.5 + -44.4) / 2 = -33.45, so the
    # "b = 33.3" message below looks stale — confirm the intended value.
    record_mean("b", -22.5)
    record_mean("b", -44.4)
    record("a", 5.5)
    dump()
    with ScopedConfigure(None, None):
        info("^^^ should see b = 33.3")
    # Presumably a temporary json-only logger writing into a sub-folder.
    with ScopedConfigure(str(tmp_path / "test-logger"), ["json"]):
        record("b", -2.5)
        dump()
    reset()
    # Long string value; presumably exercises wide/truncated columns in text output — TODO confirm.
    record("a", "longasslongasslongasslongasslongasslongassvalue")
    dump()
    warn("hey")
    error("oh")
    # Recorded but never dumped within this test.
    record_dict({"test": 1})
@pytest.mark.parametrize("_format", ["stdout", "log", "json", "csv", "tensorboard"])
def test_make_output(tmp_path, read_log, _format):
    """
    test make output: each format must write something when no key is excluded

    :param tmp_path: pytest temporary directory the writer logs into
    :param read_log: fixture reading back what the given format wrote
    :param _format: (str) output format
    """
    if _format == "tensorboard":
        # Skip if no tensorboard installed
        pytest.importorskip("tensorboard")
    writer = make_output_format(_format, tmp_path)
    try:
        writer.write(KEY_VALUES, KEY_EXCLUDED)
        assert not read_log(_format).empty
    finally:
        # Close even when the write or the assertion fails, so the writer is not leaked.
        writer.close()
def test_make_output_fail(tmp_path):
    """
    An unsupported output format name must be rejected with a ValueError.

    :param tmp_path: pytest temporary directory passed as the log folder
    """
    unknown_format = "dummy_format"
    with pytest.raises(ValueError):
        make_output_format(unknown_format, tmp_path)
@pytest.mark.parametrize("_format", ["stdout", "log", "json", "csv", "tensorboard"])
def test_exclude_keys(tmp_path, read_log, _format):
    """
    Regression test: a key excluded for a format must not be written by that format.

    :param tmp_path: pytest temporary directory the writer logs into
    :param read_log: fixture reading back what the given format wrote
    :param _format: (str) output format under test
    """
    if _format == "tensorboard":
        # Skip if no tensorboard installed
        pytest.importorskip("tensorboard")
    writer = make_output_format(_format, tmp_path)
    # Exclude the only key for the format under test, passed as a one-element tuple.
    # Without the trailing comma, ``(_format)`` is just the bare string, which would only
    # match if the writer happens to do a substring ``in`` check.
    writer.write(dict(some_tag=42), key_excluded=dict(some_tag=(_format,)))
    writer.close()
    assert read_log(_format).empty