From 89db65b1fb57eee5bd9a6394d5ba43d09b26aa84 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 20 Jan 2020 11:58:00 +0100 Subject: [PATCH] Improve logger testing + add readers --- setup.py | 7 ++- tests/test_logger.py | 82 +++++++++++++++++++++++++++++++ torchy_baselines/common/logger.py | 65 ++++++++++++++++++++++++ 3 files changed, 152 insertions(+), 2 deletions(-) create mode 100644 tests/test_logger.py diff --git a/setup.py b/setup.py index e5e411f..037e4bd 100644 --- a/setup.py +++ b/setup.py @@ -24,8 +24,11 @@ setup(name='torchy_baselines', 'sphinx-autobuild', 'sphinx-rtd-theme' ], - 'render': [ - 'opencv-python' + 'extra': [ + # For render + 'opencv-python', + # For reading logs + 'pandas' ] }, description='Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.', diff --git a/tests/test_logger.py b/tests/test_logger.py new file mode 100644 index 0000000..b55a616 --- /dev/null +++ b/tests/test_logger.py @@ -0,0 +1,82 @@ +import os +import shutil + +import pytest +import numpy as np + +from torchy_baselines.common.logger import (make_output_format, read_csv, read_json, DEBUG, ScopedConfigure, + info, debug, set_level, configure, logkv, logkvs, dumpkvs, logkv_mean, warn, error, reset) + +KEY_VALUES = { + "test": 1, + "b": -3.14, + "8": 9.9, + "l": [1, 2], + "a": np.array([1, 2, 3]), + "f": np.array(1), + "g": np.array([[[1]]]), +} + +LOG_DIR = '/tmp/torchy_baselines/' + + +def test_main(): + """ + tests for the logger module + """ + info("hi") + debug("shouldn't appear") + set_level(DEBUG) + debug("should appear") + folder = "/tmp/testlogging" + if os.path.exists(folder): + shutil.rmtree(folder) + configure(folder=folder) + logkv("a", 3) + logkv("b", 2.5) + dumpkvs() + logkv("b", -2.5) + logkv("a", 5.5) + dumpkvs() + info("^^^ should see a = 5.5") + logkv_mean("b", -22.5) + logkv_mean("b", -44.4) + logkv("a", 5.5) + dumpkvs() + with ScopedConfigure(None, None): + info("^^^ should see b = 33.3") + + with ScopedConfigure("/tmp/test-logger/", ["json"]): + logkv("b", -2.5) + dumpkvs() + + reset() + logkv("a", "longasslongasslongasslongasslongasslongassvalue") + dumpkvs() + warn("hey") + error("oh") + logkvs({"test": 1}) + + +@pytest.mark.parametrize('_format', ['stdout', 'log', 'json', 'csv']) +def test_make_output(_format): + """ + test make output + + :param _format: (str) output format + """ + writer = make_output_format(_format, LOG_DIR) + writer.writekvs(KEY_VALUES) + if _format == "csv": + read_csv(LOG_DIR + 'progress.csv') + elif _format == 'json': + read_json(LOG_DIR + 'progress.json') + writer.close() + + +def test_make_output_fail(): + """ + test value error on logger + """ + with pytest.raises(ValueError): + make_output_format('dummy_format', LOG_DIR) diff --git a/torchy_baselines/common/logger.py b/torchy_baselines/common/logger.py index 15528eb..a909285 100644 --- a/torchy_baselines/common/logger.py +++ b/torchy_baselines/common/logger.py @@ -505,3 +505,68 @@ def configure(folder=None, format_strs=None): Logger.CURRENT = Logger(folder=folder, output_formats=output_formats) log('Logging to %s' % folder) + + +def reset(): + """ + reset the current logger + """ + if Logger.CURRENT is not Logger.DEFAULT: + Logger.CURRENT.close() + Logger.CURRENT = Logger.DEFAULT + log('Reset logger') + + +class ScopedConfigure(object): + def __init__(self, folder=None, format_strs=None): + """ + Class for using context manager while logging + + usage: + with ScopedConfigure(folder=None, format_strs=None): + {code} + + :param folder: (str) the logging folder + :param format_strs: ([str]) the list of output logging format + """ + self.dir = folder + self.format_strs = format_strs + self.prevlogger = None + + def __enter__(self): + self.prevlogger = Logger.CURRENT + configure(folder=self.dir, format_strs=self.format_strs) + + def __exit__(self, *args): + Logger.CURRENT.close() + Logger.CURRENT = self.prevlogger + + +# ================================================================ +# Readers +# ================================================================ + +def read_json(fname): + """ + read a json file using pandas + + :param fname: (str) the file path to read + :return: (pandas DataFrame) the data in the json + """ + import pandas + data = [] + with open(fname, 'rt') as file_handler: + for line in file_handler: + data.append(json.loads(line)) + return pandas.DataFrame(data) + + +def read_csv(fname): + """ + read a csv file using pandas + + :param fname: (str) the file path to read + :return: (pandas DataFrame) the data in the csv + """ + import pandas + return pandas.read_csv(fname, index_col=None, comment='#')