From 97b81f9e9ee2e5ba7eb37328bbd21f8eade44e72 Mon Sep 17 00:00:00 2001 From: Bernhard Raml Date: Fri, 16 Oct 2020 17:34:49 +0200 Subject: [PATCH] Fix ignoring the exclude in the logger's record function for json, csv and log logging formats (#190) * Fix ignoring the exclude in logger record For the logging formats json, csv, and log the exclude parameter of the logger's record function has been ignored. The necessary checks were missing from some of the format writer classes. Regression tests have been added to prevent this error in the future. * Fix docstring for filter_excluded_keys Co-authored-by: Antonin RAFFIN * Added missing type hints to local functions * Update stable_baselines3/common/logger.py Co-authored-by: Bernhard Raml Co-authored-by: Antonin RAFFIN --- docs/misc/changelog.rst | 3 +- stable_baselines3/common/logger.py | 37 +++++++++++---- tests/test_logger.py | 76 ++++++++++++++++++++++++++++-- 3 files changed, 102 insertions(+), 14 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index b245f2c..9181878 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -17,6 +17,7 @@ New Features: Bug Fixes: ^^^^^^^^^^ - Fix GAE computation for on-policy algorithms (off-by one for the last value) (thanks @Wovchena) +- Fix ignoring the exclude parameter when recording logs using json, csv or log as logging format (@SwamyDev) - Make ``make_vec_env`` support the ``env_kwargs`` argument when using an env ID str (@ManifoldFR) Deprecations: @@ -455,4 +456,4 @@ And all the contributors: @MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching @flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3 @tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag @xicocaio -@diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 +@diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 @SwamyDev diff --git a/stable_baselines3/common/logger.py b/stable_baselines3/common/logger.py index 0b201a7..f97b69a 100644 --- a/stable_baselines3/common/logger.py +++ b/stable_baselines3/common/logger.py @@ -80,7 +80,7 @@ class HumanOutputFormat(KVWriter, SeqWriter): tag = None for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())): - if excluded is not None and "stdout" in excluded: + if excluded is not None and ("stdout" in excluded or "log" in excluded): continue if isinstance(value, float): @@ -140,6 +140,24 @@ class HumanOutputFormat(KVWriter, SeqWriter): self.file.close() +def filter_excluded_keys( + key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], _format: str +) -> Dict[str, Any]: + """ + Filters the keys specified by ``key_exclude`` for the specified format + + :param key_values: log dictionary to be filtered + :param key_excluded: keys to be excluded per format + :param _format: format for which this filter is run + :return: dict without the excluded keys + """ + + def is_excluded(key: str) -> bool: + return key in key_excluded and key_excluded[key] is not None and _format in key_excluded[key] + + return {key: value for key, value in key_values.items() if not is_excluded(key)} + + class JSONOutputFormat(KVWriter): def __init__(self, filename: str): """ @@ -150,18 +168,20 @@ class JSONOutputFormat(KVWriter): self.file = open(filename, "wt") def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None: - for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())): - - if excluded is not None and "json" in excluded: - continue - + def cast_to_json_serializable(value: Any): if hasattr(value, "dtype"): if value.shape == () or len(value) == 1: # if value is a dimensionless numpy array or of length 1, serialize as a float - key_values[key] = float(value) + return float(value) else: # otherwise, a value is a numpy array, serialize as a list or nested lists - key_values[key] = value.tolist() + return value.tolist() + return value + + key_values = { + key: cast_to_json_serializable(value) + for key, value in filter_excluded_keys(key_values, key_excluded, "json").items() + } self.file.write(json.dumps(key_values) + "\n") self.file.flush() @@ -187,6 +207,7 @@ class CSVOutputFormat(KVWriter): def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None: # Add our current row to the history + key_values = filter_excluded_keys(key_values, key_excluded, "csv") extra_keys = key_values.keys() - self.keys if extra_keys: self.keys.extend(extra_keys) diff --git a/tests/test_logger.py b/tests/test_logger.py index c399c9e..5ab540f 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -1,5 +1,8 @@ +from typing import Sequence + import numpy as np import pytest +from pandas.errors import EmptyDataError from stable_baselines3.common.logger import ( DEBUG, @@ -35,6 +38,60 @@ for key in KEY_VALUES.keys(): KEY_EXCLUDED[key] = None +class LogContent: + """ + A simple wrapper class to provide a common interface to check content for emptiness and report the log format + """ + + def __init__(self, _format: str, lines: Sequence): + self.format = _format + self.lines = lines + + @property + def empty(self): + return len(self.lines) == 0 + + def __repr__(self): + return f"LogContent(_format={self.format}, lines={self.lines})" + + +@pytest.fixture +def read_log(tmp_path, capsys): + def read_fn(_format): + if _format == "csv": + try: + df = read_csv(tmp_path / "progress.csv") + except EmptyDataError: + return LogContent(_format, []) + return LogContent(_format, [r for _, r in df.iterrows() if not r.empty]) + elif _format == "json": + try: + df = read_json(tmp_path / "progress.json") + except EmptyDataError: + return LogContent(_format, []) + return LogContent(_format, [r for _, r in df.iterrows() if not r.empty]) + elif _format == "stdout": + captured = capsys.readouterr() + return LogContent(_format, captured.out.splitlines()) + elif _format == "log": + return LogContent(_format, (tmp_path / "log.txt").read_text().splitlines()) + elif _format == "tensorboard": + from tensorboard.backend.event_processing.event_accumulator import EventAccumulator + + acc = EventAccumulator(str(tmp_path)) + acc.Reload() + + tb_values_logged = [] + for reservoir in [acc.scalars, acc.tensors, acc.images, acc.histograms, acc.compressed_histograms]: + for k in reservoir.Keys(): + tb_values_logged.append(f"{k}: {str(reservoir.Items(k))}") + + content = LogContent(_format, tb_values_logged) + return content + + return read_fn + + def test_main(tmp_path): """ tests for the logger module @@ -71,7 +128,7 @@ def test_main(tmp_path): @pytest.mark.parametrize("_format", ["stdout", "log", "json", "csv", "tensorboard"]) -def test_make_output(tmp_path, _format): +def test_make_output(tmp_path, read_log, _format): """ test make output @@ -83,10 +140,7 @@ def test_make_output(tmp_path, _format): writer = make_output_format(_format, tmp_path) writer.write(KEY_VALUES, KEY_EXCLUDED) - if _format == "csv": - read_csv(tmp_path / "progress.csv") - elif _format == "json": - read_json(tmp_path / "progress.json") + assert not read_log(_format).empty writer.close() @@ -96,3 +150,15 @@ def test_make_output_fail(tmp_path): """ with pytest.raises(ValueError): make_output_format("dummy_format", tmp_path) + + +@pytest.mark.parametrize("_format", ["stdout", "log", "json", "csv", "tensorboard"]) +def test_exclude_keys(tmp_path, read_log, _format): + if _format == "tensorboard": + # Skip if no tensorboard installed + pytest.importorskip("tensorboard") + + writer = make_output_format(_format, tmp_path) + writer.write(dict(some_tag=42), key_excluded=dict(some_tag=(_format))) + writer.close() + assert read_log(_format).empty