mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-18 21:30:19 +00:00
* Fix the ignored `exclude` parameter in logger `record`. For the json, csv, and log output formats, the `exclude` parameter of the logger's `record` function was ignored: the necessary checks were missing from some of the format writer classes. Regression tests have been added to prevent this error in the future. * Fix docstring for filter_excluded_keys Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org> * Add missing type hints to local functions * Update stable_baselines3/common/logger.py Co-authored-by: Bernhard Raml <raml.bernhard@gmail.com> Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
164 lines
4.3 KiB
Python
164 lines
4.3 KiB
Python
from typing import Sequence
|
|
|
|
import numpy as np
|
|
import pytest
|
|
from pandas.errors import EmptyDataError
|
|
|
|
from stable_baselines3.common.logger import (
|
|
DEBUG,
|
|
ScopedConfigure,
|
|
configure,
|
|
debug,
|
|
dump,
|
|
error,
|
|
info,
|
|
make_output_format,
|
|
read_csv,
|
|
read_json,
|
|
record,
|
|
record_dict,
|
|
record_mean,
|
|
reset,
|
|
set_level,
|
|
warn,
|
|
)
|
|
|
|
# Sample key/values of several types (python scalars, a list, numpy arrays of various ranks)
# written by the output-format tests.
KEY_VALUES = {
    "test": 1,
    "b": -3.14,
    "8": 9.9,
    "l": [1, 2],
    "a": np.array([1, 2, 3]),
    "f": np.array(1),
    "g": np.array([[[1]]]),
}

# Every key of KEY_VALUES mapped to ``None``: no output format excludes any of them.
# ``dict.fromkeys`` avoids leaking a loop variable into the module namespace.
KEY_EXCLUDED = dict.fromkeys(KEY_VALUES)
|
|
|
|
|
|
class LogContent:
    """
    Uniform view over the entries read back from one logger output format,
    so that every format can be checked for emptiness the same way.

    :param _format: (str) name of the output format the entries came from
    :param lines: the entries that were read back for that format
    """

    def __init__(self, _format: str, lines: Sequence):
        self.lines = lines
        self.format = _format

    @property
    def empty(self) -> bool:
        """True when nothing was read back for this format."""
        return not len(self.lines)

    def __repr__(self) -> str:
        return "LogContent(_format={}, lines={})".format(self.format, self.lines)
|
|
|
|
|
|
@pytest.fixture
def read_log(tmp_path, capsys):
    """
    Fixture returning a reader for what a logger output format has written.

    The returned ``read_fn(_format)`` reads the output of ``_format`` ("csv", "json",
    "stdout", "log" or "tensorboard") from ``tmp_path`` (or from the captured stdout)
    and wraps it in a ``LogContent``.

    :raises ValueError: (from ``read_fn``) if ``_format`` is not one of the formats above
    """

    def read_tabular(_format: str, reader, filename: str) -> LogContent:
        # A file pandas cannot parse (nothing was written) counts as empty output
        try:
            df = reader(tmp_path / filename)
        except EmptyDataError:
            return LogContent(_format, [])
        # Skip rows without any column so they are not counted as content
        return LogContent(_format, [row for _, row in df.iterrows() if not row.empty])

    def read_tensorboard(_format: str) -> LogContent:
        from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

        acc = EventAccumulator(str(tmp_path))
        acc.Reload()

        tb_values_logged = []
        for reservoir in [acc.scalars, acc.tensors, acc.images, acc.histograms, acc.compressed_histograms]:
            for k in reservoir.Keys():
                tb_values_logged.append(f"{k}: {reservoir.Items(k)}")
        return LogContent(_format, tb_values_logged)

    def read_fn(_format: str) -> LogContent:
        if _format == "csv":
            return read_tabular(_format, read_csv, "progress.csv")
        if _format == "json":
            return read_tabular(_format, read_json, "progress.json")
        if _format == "stdout":
            return LogContent(_format, capsys.readouterr().out.splitlines())
        if _format == "log":
            return LogContent(_format, (tmp_path / "log.txt").read_text().splitlines())
        if _format == "tensorboard":
            return read_tensorboard(_format)
        # Fail loudly instead of returning None, which would surface later as an AttributeError
        raise ValueError(f"Unknown log format: {_format}")

    return read_fn
|
|
|
|
|
|
def test_main(tmp_path):
    """
    Smoke test for the module-level logger API.

    Exercises the message helpers (info/debug/warn/error), level changes, recording and
    dumping key/values (including averaged values), scoped reconfiguration and reset.
    It contains no assertions: it only checks that none of these calls raise.
    """
    info("hi")
    # Emitted before set_level(DEBUG), so expected to be filtered out
    debug("shouldn't appear")
    set_level(DEBUG)
    debug("should appear")
    configure(folder=str(tmp_path))
    record("a", 3)
    record("b", 2.5)
    dump()
    record("b", -2.5)
    record("a", 5.5)
    dump()
    info("^^^ should see a = 5.5")
    # Averaged value: (-22.5 + -44.4) / 2 = -33.45
    record_mean("b", -22.5)
    record_mean("b", -44.4)
    record("a", 5.5)
    dump()
    with ScopedConfigure(None, None):
        info("^^^ should see b = -33.45")

    with ScopedConfigure(str(tmp_path / "test-logger"), ["json"]):
        record("b", -2.5)
        dump()

    reset()
    record("a", "longasslongasslongasslongasslongasslongassvalue")
    dump()
    warn("hey")
    error("oh")
    record_dict({"test": 1})
|
|
|
|
|
|
@pytest.mark.parametrize("_format", ["stdout", "log", "json", "csv", "tensorboard"])
def test_make_output(tmp_path, read_log, _format):
    """
    Writing KEY_VALUES with nothing excluded must produce output for every format.

    :param tmp_path: temporary directory the writer logs into
    :param read_log: fixture reading back what was written
    :param _format: (str) output format
    """
    if _format == "tensorboard":
        # Skip if no tensorboard installed
        pytest.importorskip("tensorboard")

    writer = make_output_format(_format, tmp_path)
    try:
        writer.write(KEY_VALUES, KEY_EXCLUDED)
        assert not read_log(_format).empty
    finally:
        # Close even if the write or the assertion fails, so the writer is not leaked
        writer.close()
|
|
|
|
|
|
def test_make_output_fail(tmp_path):
    """
    make_output_format must reject an unknown format name with a ValueError.
    """
    pytest.raises(ValueError, make_output_format, "dummy_format", tmp_path)
|
|
|
|
|
|
@pytest.mark.parametrize("_format", ["stdout", "log", "json", "csv", "tensorboard"])
def test_exclude_keys(tmp_path, read_log, _format):
    """
    A key excluded for the format being written must not appear in its output.

    :param tmp_path: temporary directory the writer logs into
    :param read_log: fixture reading back what was written
    :param _format: (str) output format, also used as the exclusion value for the key
    """
    if _format == "tensorboard":
        # Skip if no tensorboard installed
        pytest.importorskip("tensorboard")

    writer = make_output_format(_format, tmp_path)
    try:
        # The exclusion value is the format name itself (a plain str, not a tuple)
        writer.write(dict(some_tag=42), key_excluded=dict(some_tag=_format))
    finally:
        # Close even if the write fails, so the writer is not leaked
        writer.close()
    assert read_log(_format).empty
|