mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-18 21:30:19 +00:00
Fix ignoring the exclude in the logger's record function for json, csv and log logging formats (#190)
* Fix ignoring the exclude in logger record For the logging formats json, csv, and log the exclude parameter of the logger's record function has been ignored. The necessary checks were missing from some of the format writer classes. Regression tests have been added to prevent this error in the future. * Fix docstring for filter_excluded_keys Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org> * Added missing type hints to local functions * Update stable_baselines3/common/logger.py Co-authored-by: Bernhard Raml <raml.bernhard@gmail.com> Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
parent
fe6ade3089
commit
97b81f9e9e
3 changed files with 102 additions and 14 deletions
|
|
@ -17,6 +17,7 @@ New Features:
|
|||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- Fix GAE computation for on-policy algorithms (off-by one for the last value) (thanks @Wovchena)
|
||||
- Fix ignoring the exclude parameter when recording logs using json, csv or log as logging format (@SwamyDev)
|
||||
- Make ``make_vec_env`` support the ``env_kwargs`` argument when using an env ID str (@ManifoldFR)
|
||||
|
||||
Deprecations:
|
||||
|
|
@ -455,4 +456,4 @@ And all the contributors:
|
|||
@MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching
|
||||
@flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3
|
||||
@tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag @xicocaio
|
||||
@diditforlulz273 @liorcohen5 @ManifoldFR @mloo3
|
||||
@diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 @SwamyDev
|
||||
|
|
|
|||
|
|
@ -80,7 +80,7 @@ class HumanOutputFormat(KVWriter, SeqWriter):
|
|||
tag = None
|
||||
for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())):
|
||||
|
||||
if excluded is not None and "stdout" in excluded:
|
||||
if excluded is not None and ("stdout" in excluded or "log" in excluded):
|
||||
continue
|
||||
|
||||
if isinstance(value, float):
|
||||
|
|
@ -140,6 +140,24 @@ class HumanOutputFormat(KVWriter, SeqWriter):
|
|||
self.file.close()
|
||||
|
||||
|
||||
def filter_excluded_keys(
|
||||
key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], _format: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Filters the keys specified by ``key_exclude`` for the specified format
|
||||
|
||||
:param key_values: log dictionary to be filtered
|
||||
:param key_excluded: keys to be excluded per format
|
||||
:param _format: format for which this filter is run
|
||||
:return: dict without the excluded keys
|
||||
"""
|
||||
|
||||
def is_excluded(key: str) -> bool:
|
||||
return key in key_excluded and key_excluded[key] is not None and _format in key_excluded[key]
|
||||
|
||||
return {key: value for key, value in key_values.items() if not is_excluded(key)}
|
||||
|
||||
|
||||
class JSONOutputFormat(KVWriter):
|
||||
def __init__(self, filename: str):
|
||||
"""
|
||||
|
|
@ -150,18 +168,20 @@ class JSONOutputFormat(KVWriter):
|
|||
self.file = open(filename, "wt")
|
||||
|
||||
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None:
|
||||
for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())):
|
||||
|
||||
if excluded is not None and "json" in excluded:
|
||||
continue
|
||||
|
||||
def cast_to_json_serializable(value: Any):
|
||||
if hasattr(value, "dtype"):
|
||||
if value.shape == () or len(value) == 1:
|
||||
# if value is a dimensionless numpy array or of length 1, serialize as a float
|
||||
key_values[key] = float(value)
|
||||
return float(value)
|
||||
else:
|
||||
# otherwise, a value is a numpy array, serialize as a list or nested lists
|
||||
key_values[key] = value.tolist()
|
||||
return value.tolist()
|
||||
return value
|
||||
|
||||
key_values = {
|
||||
key: cast_to_json_serializable(value)
|
||||
for key, value in filter_excluded_keys(key_values, key_excluded, "json").items()
|
||||
}
|
||||
self.file.write(json.dumps(key_values) + "\n")
|
||||
self.file.flush()
|
||||
|
||||
|
|
@ -187,6 +207,7 @@ class CSVOutputFormat(KVWriter):
|
|||
|
||||
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None:
|
||||
# Add our current row to the history
|
||||
key_values = filter_excluded_keys(key_values, key_excluded, "csv")
|
||||
extra_keys = key_values.keys() - self.keys
|
||||
if extra_keys:
|
||||
self.keys.extend(extra_keys)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,8 @@
|
|||
from typing import Sequence
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from pandas.errors import EmptyDataError
|
||||
|
||||
from stable_baselines3.common.logger import (
|
||||
DEBUG,
|
||||
|
|
@ -35,6 +38,60 @@ for key in KEY_VALUES.keys():
|
|||
KEY_EXCLUDED[key] = None
|
||||
|
||||
|
||||
class LogContent:
|
||||
"""
|
||||
A simple wrapper class to provide a common interface to check content for emptiness and report the log format
|
||||
"""
|
||||
|
||||
def __init__(self, _format: str, lines: Sequence):
|
||||
self.format = _format
|
||||
self.lines = lines
|
||||
|
||||
@property
|
||||
def empty(self):
|
||||
return len(self.lines) == 0
|
||||
|
||||
def __repr__(self):
|
||||
return f"LogContent(_format={self.format}, lines={self.lines})"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def read_log(tmp_path, capsys):
|
||||
def read_fn(_format):
|
||||
if _format == "csv":
|
||||
try:
|
||||
df = read_csv(tmp_path / "progress.csv")
|
||||
except EmptyDataError:
|
||||
return LogContent(_format, [])
|
||||
return LogContent(_format, [r for _, r in df.iterrows() if not r.empty])
|
||||
elif _format == "json":
|
||||
try:
|
||||
df = read_json(tmp_path / "progress.json")
|
||||
except EmptyDataError:
|
||||
return LogContent(_format, [])
|
||||
return LogContent(_format, [r for _, r in df.iterrows() if not r.empty])
|
||||
elif _format == "stdout":
|
||||
captured = capsys.readouterr()
|
||||
return LogContent(_format, captured.out.splitlines())
|
||||
elif _format == "log":
|
||||
return LogContent(_format, (tmp_path / "log.txt").read_text().splitlines())
|
||||
elif _format == "tensorboard":
|
||||
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
|
||||
|
||||
acc = EventAccumulator(str(tmp_path))
|
||||
acc.Reload()
|
||||
|
||||
tb_values_logged = []
|
||||
for reservoir in [acc.scalars, acc.tensors, acc.images, acc.histograms, acc.compressed_histograms]:
|
||||
for k in reservoir.Keys():
|
||||
tb_values_logged.append(f"{k}: {str(reservoir.Items(k))}")
|
||||
|
||||
content = LogContent(_format, tb_values_logged)
|
||||
return content
|
||||
|
||||
return read_fn
|
||||
|
||||
|
||||
def test_main(tmp_path):
|
||||
"""
|
||||
tests for the logger module
|
||||
|
|
@ -71,7 +128,7 @@ def test_main(tmp_path):
|
|||
|
||||
|
||||
@pytest.mark.parametrize("_format", ["stdout", "log", "json", "csv", "tensorboard"])
|
||||
def test_make_output(tmp_path, _format):
|
||||
def test_make_output(tmp_path, read_log, _format):
|
||||
"""
|
||||
test make output
|
||||
|
||||
|
|
@ -83,10 +140,7 @@ def test_make_output(tmp_path, _format):
|
|||
|
||||
writer = make_output_format(_format, tmp_path)
|
||||
writer.write(KEY_VALUES, KEY_EXCLUDED)
|
||||
if _format == "csv":
|
||||
read_csv(tmp_path / "progress.csv")
|
||||
elif _format == "json":
|
||||
read_json(tmp_path / "progress.json")
|
||||
assert not read_log(_format).empty
|
||||
writer.close()
|
||||
|
||||
|
||||
|
|
@ -96,3 +150,15 @@ def test_make_output_fail(tmp_path):
|
|||
"""
|
||||
with pytest.raises(ValueError):
|
||||
make_output_format("dummy_format", tmp_path)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("_format", ["stdout", "log", "json", "csv", "tensorboard"])
|
||||
def test_exclude_keys(tmp_path, read_log, _format):
|
||||
if _format == "tensorboard":
|
||||
# Skip if no tensorboard installed
|
||||
pytest.importorskip("tensorboard")
|
||||
|
||||
writer = make_output_format(_format, tmp_path)
|
||||
writer.write(dict(some_tag=42), key_excluded=dict(some_tag=(_format)))
|
||||
writer.close()
|
||||
assert read_log(_format).empty
|
||||
|
|
|
|||
Loading…
Reference in a new issue