mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-06 00:03:28 +00:00
Improve logger testing + add readers
This commit is contained in:
parent
c542009641
commit
89db65b1fb
3 changed files with 152 additions and 2 deletions
7
setup.py
7
setup.py
|
|
@ -24,8 +24,11 @@ setup(name='torchy_baselines',
|
|||
'sphinx-autobuild',
|
||||
'sphinx-rtd-theme'
|
||||
],
|
||||
'render': [
|
||||
'opencv-python'
|
||||
'extra': [
|
||||
# For render
|
||||
'opencv-python',
|
||||
# For reading logs
|
||||
'pandas'
|
||||
]
|
||||
},
|
||||
description='Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.',
|
||||
|
|
|
|||
82
tests/test_logger.py
Normal file
82
tests/test_logger.py
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
import os
|
||||
import shutil
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
from torchy_baselines.common.logger import (make_output_format, read_csv, read_json, DEBUG, ScopedConfigure,
|
||||
info, debug, set_level, configure, logkv, logkvs, dumpkvs, logkv_mean, warn, error, reset)
|
||||
|
||||
KEY_VALUES = {
|
||||
"test": 1,
|
||||
"b": -3.14,
|
||||
"8": 9.9,
|
||||
"l": [1, 2],
|
||||
"a": np.array([1, 2, 3]),
|
||||
"f": np.array(1),
|
||||
"g": np.array([[[1]]]),
|
||||
}
|
||||
|
||||
LOG_DIR = '/tmp/torchy_baselines/'
|
||||
|
||||
|
||||
def test_main():
|
||||
"""
|
||||
tests for the logger module
|
||||
"""
|
||||
info("hi")
|
||||
debug("shouldn't appear")
|
||||
set_level(DEBUG)
|
||||
debug("should appear")
|
||||
folder = "/tmp/testlogging"
|
||||
if os.path.exists(folder):
|
||||
shutil.rmtree(folder)
|
||||
configure(folder=folder)
|
||||
logkv("a", 3)
|
||||
logkv("b", 2.5)
|
||||
dumpkvs()
|
||||
logkv("b", -2.5)
|
||||
logkv("a", 5.5)
|
||||
dumpkvs()
|
||||
info("^^^ should see a = 5.5")
|
||||
logkv_mean("b", -22.5)
|
||||
logkv_mean("b", -44.4)
|
||||
logkv("a", 5.5)
|
||||
dumpkvs()
|
||||
with ScopedConfigure(None, None):
|
||||
info("^^^ should see b = 33.3")
|
||||
|
||||
with ScopedConfigure("/tmp/test-logger/", ["json"]):
|
||||
logkv("b", -2.5)
|
||||
dumpkvs()
|
||||
|
||||
reset()
|
||||
logkv("a", "longasslongasslongasslongasslongasslongassvalue")
|
||||
dumpkvs()
|
||||
warn("hey")
|
||||
error("oh")
|
||||
logkvs({"test": 1})
|
||||
|
||||
|
||||
@pytest.mark.parametrize('_format', ['stdout', 'log', 'json', 'csv'])
|
||||
def test_make_output(_format):
|
||||
"""
|
||||
test make output
|
||||
|
||||
:param _format: (str) output format
|
||||
"""
|
||||
writer = make_output_format(_format, LOG_DIR)
|
||||
writer.writekvs(KEY_VALUES)
|
||||
if _format == "csv":
|
||||
read_csv(LOG_DIR + 'progress.csv')
|
||||
elif _format == 'json':
|
||||
read_json(LOG_DIR + 'progress.json')
|
||||
writer.close()
|
||||
|
||||
|
||||
def test_make_output_fail():
|
||||
"""
|
||||
test value error on logger
|
||||
"""
|
||||
with pytest.raises(ValueError):
|
||||
make_output_format('dummy_format', LOG_DIR)
|
||||
|
|
@ -505,3 +505,68 @@ def configure(folder=None, format_strs=None):
|
|||
|
||||
Logger.CURRENT = Logger(folder=folder, output_formats=output_formats)
|
||||
log('Logging to %s' % folder)
|
||||
|
||||
|
||||
def reset():
|
||||
"""
|
||||
reset the current logger
|
||||
"""
|
||||
if Logger.CURRENT is not Logger.DEFAULT:
|
||||
Logger.CURRENT.close()
|
||||
Logger.CURRENT = Logger.DEFAULT
|
||||
log('Reset logger')
|
||||
|
||||
|
||||
class ScopedConfigure(object):
|
||||
def __init__(self, folder=None, format_strs=None):
|
||||
"""
|
||||
Class for using context manager while logging
|
||||
|
||||
usage:
|
||||
with ScopedConfigure(folder=None, format_strs=None):
|
||||
{code}
|
||||
|
||||
:param folder: (str) the logging folder
|
||||
:param format_strs: ([str]) the list of output logging format
|
||||
"""
|
||||
self.dir = folder
|
||||
self.format_strs = format_strs
|
||||
self.prevlogger = None
|
||||
|
||||
def __enter__(self):
|
||||
self.prevlogger = Logger.CURRENT
|
||||
configure(folder=self.dir, format_strs=self.format_strs)
|
||||
|
||||
def __exit__(self, *args):
|
||||
Logger.CURRENT.close()
|
||||
Logger.CURRENT = self.prevlogger
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Readers
|
||||
# ================================================================
|
||||
|
||||
def read_json(fname):
|
||||
"""
|
||||
read a json file using pandas
|
||||
|
||||
:param fname: (str) the file path to read
|
||||
:return: (pandas DataFrame) the data in the json
|
||||
"""
|
||||
import pandas
|
||||
data = []
|
||||
with open(fname, 'rt') as file_handler:
|
||||
for line in file_handler:
|
||||
data.append(json.loads(line))
|
||||
return pandas.DataFrame(data)
|
||||
|
||||
|
||||
def read_csv(fname):
|
||||
"""
|
||||
read a csv file using pandas
|
||||
|
||||
:param fname: (str) the file path to read
|
||||
:return: (pandas DataFrame) the data in the csv
|
||||
"""
|
||||
import pandas
|
||||
return pandas.read_csv(fname, index_col=None, comment='#')
|
||||
|
|
|
|||
Loading…
Reference in a new issue