Fix ignoring the exclude in the logger's record function for json, csv and log logging formats (#190)

* Fix ignoring the exclude in logger record

For the logging formats json, csv, and log the exclude parameter of the
logger's record function has been ignored. The necessary checks were
missing from some of the format writer classes. Regression tests have
been added to prevent this error in the future.

* Fix docstring for filter_excluded_keys

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>

* Added missing type hints to local functions

* Update stable_baselines3/common/logger.py

Co-authored-by: Bernhard Raml <raml.bernhard@gmail.com>
Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
Bernhard Raml 2020-10-16 17:34:49 +02:00 committed by GitHub
parent fe6ade3089
commit 97b81f9e9e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 102 additions and 14 deletions

View file

@ -17,6 +17,7 @@ New Features:
Bug Fixes:
^^^^^^^^^^
- Fix GAE computation for on-policy algorithms (off-by one for the last value) (thanks @Wovchena)
- Fix ignoring the exclude parameter when recording logs using json, csv or log as logging format (@SwamyDev)
- Make ``make_vec_env`` support the ``env_kwargs`` argument when using an env ID str (@ManifoldFR)
Deprecations:
@ -455,4 +456,4 @@ And all the contributors:
@MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching
@flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3
@tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag @xicocaio
@diditforlulz273 @liorcohen5 @ManifoldFR @mloo3
@diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 @SwamyDev

View file

@ -80,7 +80,7 @@ class HumanOutputFormat(KVWriter, SeqWriter):
tag = None
for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())):
if excluded is not None and "stdout" in excluded:
if excluded is not None and ("stdout" in excluded or "log" in excluded):
continue
if isinstance(value, float):
@ -140,6 +140,24 @@ class HumanOutputFormat(KVWriter, SeqWriter):
self.file.close()
def filter_excluded_keys(
key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], _format: str
) -> Dict[str, Any]:
"""
Filters the keys specified by ``key_exclude`` for the specified format
:param key_values: log dictionary to be filtered
:param key_excluded: keys to be excluded per format
:param _format: format for which this filter is run
:return: dict without the excluded keys
"""
def is_excluded(key: str) -> bool:
return key in key_excluded and key_excluded[key] is not None and _format in key_excluded[key]
return {key: value for key, value in key_values.items() if not is_excluded(key)}
class JSONOutputFormat(KVWriter):
def __init__(self, filename: str):
"""
@ -150,18 +168,20 @@ class JSONOutputFormat(KVWriter):
self.file = open(filename, "wt")
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None:
for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())):
if excluded is not None and "json" in excluded:
continue
def cast_to_json_serializable(value: Any):
if hasattr(value, "dtype"):
if value.shape == () or len(value) == 1:
# if value is a dimensionless numpy array or of length 1, serialize as a float
key_values[key] = float(value)
return float(value)
else:
# otherwise, a value is a numpy array, serialize as a list or nested lists
key_values[key] = value.tolist()
return value.tolist()
return value
key_values = {
key: cast_to_json_serializable(value)
for key, value in filter_excluded_keys(key_values, key_excluded, "json").items()
}
self.file.write(json.dumps(key_values) + "\n")
self.file.flush()
@ -187,6 +207,7 @@ class CSVOutputFormat(KVWriter):
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None:
# Add our current row to the history
key_values = filter_excluded_keys(key_values, key_excluded, "csv")
extra_keys = key_values.keys() - self.keys
if extra_keys:
self.keys.extend(extra_keys)

View file

@ -1,5 +1,8 @@
from typing import Sequence
import numpy as np
import pytest
from pandas.errors import EmptyDataError
from stable_baselines3.common.logger import (
DEBUG,
@ -35,6 +38,60 @@ for key in KEY_VALUES.keys():
KEY_EXCLUDED[key] = None
class LogContent:
"""
A simple wrapper class to provide a common interface to check content for emptiness and report the log format
"""
def __init__(self, _format: str, lines: Sequence):
self.format = _format
self.lines = lines
@property
def empty(self):
return len(self.lines) == 0
def __repr__(self):
return f"LogContent(_format={self.format}, lines={self.lines})"
@pytest.fixture
def read_log(tmp_path, capsys):
def read_fn(_format):
if _format == "csv":
try:
df = read_csv(tmp_path / "progress.csv")
except EmptyDataError:
return LogContent(_format, [])
return LogContent(_format, [r for _, r in df.iterrows() if not r.empty])
elif _format == "json":
try:
df = read_json(tmp_path / "progress.json")
except EmptyDataError:
return LogContent(_format, [])
return LogContent(_format, [r for _, r in df.iterrows() if not r.empty])
elif _format == "stdout":
captured = capsys.readouterr()
return LogContent(_format, captured.out.splitlines())
elif _format == "log":
return LogContent(_format, (tmp_path / "log.txt").read_text().splitlines())
elif _format == "tensorboard":
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
acc = EventAccumulator(str(tmp_path))
acc.Reload()
tb_values_logged = []
for reservoir in [acc.scalars, acc.tensors, acc.images, acc.histograms, acc.compressed_histograms]:
for k in reservoir.Keys():
tb_values_logged.append(f"{k}: {str(reservoir.Items(k))}")
content = LogContent(_format, tb_values_logged)
return content
return read_fn
def test_main(tmp_path):
"""
tests for the logger module
@ -71,7 +128,7 @@ def test_main(tmp_path):
@pytest.mark.parametrize("_format", ["stdout", "log", "json", "csv", "tensorboard"])
def test_make_output(tmp_path, _format):
def test_make_output(tmp_path, read_log, _format):
"""
test make output
@ -83,10 +140,7 @@ def test_make_output(tmp_path, _format):
writer = make_output_format(_format, tmp_path)
writer.write(KEY_VALUES, KEY_EXCLUDED)
if _format == "csv":
read_csv(tmp_path / "progress.csv")
elif _format == "json":
read_json(tmp_path / "progress.json")
assert not read_log(_format).empty
writer.close()
@ -96,3 +150,15 @@ def test_make_output_fail(tmp_path):
"""
with pytest.raises(ValueError):
make_output_format("dummy_format", tmp_path)
@pytest.mark.parametrize("_format", ["stdout", "log", "json", "csv", "tensorboard"])
def test_exclude_keys(tmp_path, read_log, _format):
if _format == "tensorboard":
# Skip if no tensorboard installed
pytest.importorskip("tensorboard")
writer = make_output_format(_format, tmp_path)
writer.write(dict(some_tag=42), key_excluded=dict(some_tag=(_format)))
writer.close()
assert read_log(_format).empty