diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index b245f2c..9181878 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -17,6 +17,7 @@ New Features: Bug Fixes: ^^^^^^^^^^ - Fix GAE computation for on-policy algorithms (off-by one for the last value) (thanks @Wovchena) +- Fix ignoring the exclude parameter when recording logs using json, csv or log as logging format (@SwamyDev) - Make ``make_vec_env`` support the ``env_kwargs`` argument when using an env ID str (@ManifoldFR) Deprecations: @@ -455,4 +456,4 @@ And all the contributors: @MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching @flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3 @tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag @xicocaio -@diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 +@diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 @SwamyDev diff --git a/stable_baselines3/common/logger.py b/stable_baselines3/common/logger.py index 0b201a7..f97b69a 100644 --- a/stable_baselines3/common/logger.py +++ b/stable_baselines3/common/logger.py @@ -80,7 +80,7 @@ class HumanOutputFormat(KVWriter, SeqWriter): tag = None for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())): - if excluded is not None and "stdout" in excluded: + if excluded is not None and ("stdout" in excluded or "log" in excluded): continue if isinstance(value, float): @@ -140,6 +140,24 @@ class HumanOutputFormat(KVWriter, SeqWriter): self.file.close() +def filter_excluded_keys( + key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], _format: str +) -> Dict[str, Any]: + """ + Filters the keys specified by ``key_exclude`` for the specified format + + :param key_values: log dictionary to be filtered + :param key_excluded: keys to be excluded per format + :param _format: format for which this filter is run + :return: dict without the excluded keys + """ + + def is_excluded(key: str) -> bool: + return key in key_excluded and key_excluded[key] is not None and _format in key_excluded[key] + + return {key: value for key, value in key_values.items() if not is_excluded(key)} + + class JSONOutputFormat(KVWriter): def __init__(self, filename: str): """ @@ -150,18 +168,20 @@ class JSONOutputFormat(KVWriter): self.file = open(filename, "wt") def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None: - for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())): - - if excluded is not None and "json" in excluded: - continue - + def cast_to_json_serializable(value: Any): if hasattr(value, "dtype"): if value.shape == () or len(value) == 1: # if value is a dimensionless numpy array or of length 1, serialize as a float - key_values[key] = float(value) + return float(value) else: # otherwise, a value is a numpy array, serialize as a list or nested lists - key_values[key] = value.tolist() + return value.tolist() + return value + + key_values = { + key: cast_to_json_serializable(value) + for key, value in filter_excluded_keys(key_values, key_excluded, "json").items() + } self.file.write(json.dumps(key_values) + "\n") self.file.flush() @@ -187,6 +207,7 @@ class CSVOutputFormat(KVWriter): def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None: # Add our current row to the history + key_values = filter_excluded_keys(key_values, key_excluded, "csv") extra_keys = key_values.keys() - self.keys if extra_keys: self.keys.extend(extra_keys) diff --git a/tests/test_logger.py b/tests/test_logger.py index c399c9e..5ab540f 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -1,5 +1,8 @@ +from typing import Sequence + import numpy as np import pytest +from pandas.errors import EmptyDataError from stable_baselines3.common.logger import ( DEBUG, @@ -35,6 +38,60 @@ for key in KEY_VALUES.keys(): KEY_EXCLUDED[key] = None +class LogContent: + """ + A simple wrapper class to provide a common interface to check content for emptiness and report the log format + """ + + def __init__(self, _format: str, lines: Sequence): + self.format = _format + self.lines = lines + + @property + def empty(self): + return len(self.lines) == 0 + + def __repr__(self): + return f"LogContent(_format={self.format}, lines={self.lines})" + + +@pytest.fixture +def read_log(tmp_path, capsys): + def read_fn(_format): + if _format == "csv": + try: + df = read_csv(tmp_path / "progress.csv") + except EmptyDataError: + return LogContent(_format, []) + return LogContent(_format, [r for _, r in df.iterrows() if not r.empty]) + elif _format == "json": + try: + df = read_json(tmp_path / "progress.json") + except EmptyDataError: + return LogContent(_format, []) + return LogContent(_format, [r for _, r in df.iterrows() if not r.empty]) + elif _format == "stdout": + captured = capsys.readouterr() + return LogContent(_format, captured.out.splitlines()) + elif _format == "log": + return LogContent(_format, (tmp_path / "log.txt").read_text().splitlines()) + elif _format == "tensorboard": + from tensorboard.backend.event_processing.event_accumulator import EventAccumulator + + acc = EventAccumulator(str(tmp_path)) + acc.Reload() + + tb_values_logged = [] + for reservoir in [acc.scalars, acc.tensors, acc.images, acc.histograms, acc.compressed_histograms]: + for k in reservoir.Keys(): + tb_values_logged.append(f"{k}: {str(reservoir.Items(k))}") + + content = LogContent(_format, tb_values_logged) + return content + + return read_fn + + def test_main(tmp_path): """ tests for the logger module @@ -71,7 +128,7 @@ def test_main(tmp_path): @pytest.mark.parametrize("_format", ["stdout", "log", "json", "csv", "tensorboard"]) -def test_make_output(tmp_path, _format): +def test_make_output(tmp_path, read_log, _format): """ test make output @@ -83,10 +140,7 @@ def test_make_output(tmp_path, _format): writer = make_output_format(_format, tmp_path) writer.write(KEY_VALUES, KEY_EXCLUDED) - if _format == "csv": - read_csv(tmp_path / "progress.csv") - elif _format == "json": - read_json(tmp_path / "progress.json") + assert not read_log(_format).empty writer.close() @@ -96,3 +150,15 @@ def test_make_output_fail(tmp_path): """ with pytest.raises(ValueError): make_output_format("dummy_format", tmp_path) + + +@pytest.mark.parametrize("_format", ["stdout", "log", "json", "csv", "tensorboard"]) +def test_exclude_keys(tmp_path, read_log, _format): + if _format == "tensorboard": + # Skip if no tensorboard installed + pytest.importorskip("tensorboard") + + writer = make_output_format(_format, tmp_path) + writer.write(dict(some_tag=42), key_excluded=dict(some_tag=(_format))) + writer.close() + assert read_log(_format).empty