From 88cee2ba55bec4a04791efb00bf4d6f0e9b9a0a8 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Tue, 5 May 2020 14:49:32 +0200 Subject: [PATCH] Add type hints and f-strings to logger --- docs/misc/changelog.rst | 5 +- torchy_baselines/common/logger.py | 158 +++++++++++++++--------------- torchy_baselines/version.txt | 2 +- 3 files changed, 86 insertions(+), 79 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 74b5c5b..a4fd097 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,9 +3,11 @@ Changelog ========== -Pre-Release 0.5.0a2 (WIP) +Pre-Release 0.5.0 (2020-05-05) ------------------------------ +**CnnPolicy support for image observations, complete saving/loading for policies** + Breaking Changes: ^^^^^^^^^^^^^^^^^ - Previous loading of policy weights is broken and replace by the new saving/loading for policy @@ -30,6 +32,7 @@ Others: ^^^^^^^ - Cleanup rollout return - Added ``get_device`` util to manage PyTorch devices +- Added type hints to logger + use f-strings Documentation: ^^^^^^^^^^^^^^ diff --git a/torchy_baselines/common/logger.py b/torchy_baselines/common/logger.py index 0da70e6..15074bf 100644 --- a/torchy_baselines/common/logger.py +++ b/torchy_baselines/common/logger.py @@ -1,6 +1,3 @@ -""" -Taken from stable-baselines -""" import sys import datetime import json @@ -8,6 +5,9 @@ import os import tempfile import warnings from collections import defaultdict +from typing import Dict, List, TextIO, Union, Any, Optional + +import pandas DEBUG = 10 INFO = 20 @@ -20,7 +20,7 @@ class KVWriter(object): """ Key Value writer """ - def writekvs(self, kvs): + def writekvs(self, kvs: Dict) -> None: """ write a dictionary to file @@ -28,12 +28,18 @@ class KVWriter(object): """ raise NotImplementedError + def close(self) -> None: + """ + Close owned resources + """ + raise NotImplementedError + class SeqWriter(object): """ sequence writer """ - def writeseq(self, seq): + def writeseq(self, seq: List): """ write an array to file @@ -43,7 +49,7 @@ class SeqWriter(object): class HumanOutputFormat(KVWriter, SeqWriter): - def __init__(self, filename_or_file): + def __init__(self, filename_or_file: Union [str, TextIO]): """ log to a file, in a human readable format @@ -57,34 +63,32 @@ class HumanOutputFormat(KVWriter, SeqWriter): self.file = filename_or_file self.own_file = False - def writekvs(self, kvs): + def writekvs(self, kvs: Dict) -> None: # Create strings for printing key2str = {} for (key, val) in sorted(kvs.items()): if isinstance(val, float): - valstr = '%-8.3g' % (val,) + # Align left + val_str = f'{val:<8.3g}' else: - valstr = str(val) - key2str[self._truncate(key)] = self._truncate(valstr) + val_str = str(val) + key2str[self._truncate(key)] = self._truncate(val_str) # Find max widths if len(key2str) == 0: warnings.warn('Tried to write empty key-value dict') return else: - keywidth = max(map(len, key2str.keys())) - valwidth = max(map(len, key2str.values())) + key_width = max(map(len, key2str.keys())) + val_width = max(map(len, key2str.values())) # Write out the data - dashes = '-' * (keywidth + valwidth + 7) + dashes = '-' * (key_width + val_width + 7) lines = [dashes] for (key, val) in sorted(key2str.items()): - lines.append('| %s%s | %s%s |' % ( - key, - ' ' * (keywidth - len(key)), - val, - ' ' * (valwidth - len(val)), - )) + key_space = ' ' * (key_width - len(key)) + val_space = ' ' * (val_width - len(val)) + lines.append(f"| {key}{key_space} | {val}{val_space} |") lines.append(dashes) self.file.write('\n'.join(lines) + '\n') @@ -92,10 +96,10 @@ class HumanOutputFormat(KVWriter, SeqWriter): self.file.flush() @classmethod - def _truncate(cls, string): + def _truncate(cls, string: str) -> str: return string[:20] + '...' if len(string) > 23 else string - def writeseq(self, seq): + def writeseq(self, seq: List) -> None: seq = list(seq) for (i, elem) in enumerate(seq): self.file.write(elem) @@ -104,7 +108,7 @@ class HumanOutputFormat(KVWriter, SeqWriter): self.file.write('\n') self.file.flush() - def close(self): + def close(self) -> None: """ closes the file """ @@ -113,7 +117,7 @@ class HumanOutputFormat(KVWriter, SeqWriter): class JSONOutputFormat(KVWriter): - def __init__(self, filename): + def __init__(self, filename: str): """ log to a file, in the JSON format @@ -121,7 +125,7 @@ class JSONOutputFormat(KVWriter): """ self.file = open(filename, 'wt') - def writekvs(self, kvs): + def writekvs(self, kvs: Dict) -> None: for key, value in sorted(kvs.items()): if hasattr(value, 'dtype'): if value.shape == () or len(value) == 1: @@ -133,7 +137,7 @@ class JSONOutputFormat(KVWriter): self.file.write(json.dumps(kvs) + '\n') self.file.flush() - def close(self): + def close(self) -> None: """ closes the file """ @@ -141,7 +145,7 @@ class JSONOutputFormat(KVWriter): class CSVOutputFormat(KVWriter): - def __init__(self, filename): + def __init__(self, filename: str): """ log to a file, in a CSV format @@ -151,7 +155,7 @@ class CSVOutputFormat(KVWriter): self.keys = [] self.sep = ',' - def writekvs(self, kvs): + def writekvs(self, kvs: Dict) -> None: # Add our current row to the history extra_keys = kvs.keys() - self.keys if extra_keys: @@ -177,14 +181,14 @@ class CSVOutputFormat(KVWriter): self.file.write('\n') self.file.flush() - def close(self): + def close(self) -> None: """ closes the file """ self.file.close() -def valid_float_value(value): +def valid_float_value(value: Any) -> bool: """ Returns True if the value can be successfully cast into a float @@ -198,33 +202,33 @@ def valid_float_value(value): return False -def make_output_format(_format, ev_dir, log_suffix=''): +def make_output_format(_format: str, log_dir: str, log_suffix: str = '') -> KVWriter: """ return a logger for the requested format :param _format: (str) the requested format to log to ('stdout', 'log', 'json' or 'csv') - :param ev_dir: (str) the logging directory + :param log_dir: (str) the logging directory :param log_suffix: (str) the suffix for the log file - :return: (KVWrite) the logger + :return: (KVWriter) the logger """ - os.makedirs(ev_dir, exist_ok=True) + os.makedirs(log_dir, exist_ok=True) if _format == 'stdout': return HumanOutputFormat(sys.stdout) elif _format == 'log': - return HumanOutputFormat(os.path.join(ev_dir, 'log%s.txt' % log_suffix)) + return HumanOutputFormat(os.path.join(log_dir, f'log{log_suffix}.txt')) elif _format == 'json': - return JSONOutputFormat(os.path.join(ev_dir, 'progress%s.json' % log_suffix)) + return JSONOutputFormat(os.path.join(log_dir, f'progress{log_suffix}.json')) elif _format == 'csv': - return CSVOutputFormat(os.path.join(ev_dir, 'progress%s.csv' % log_suffix)) + return CSVOutputFormat(os.path.join(log_dir, f'progress{log_suffix}.csv')) else: - raise ValueError('Unknown format specified: %s' % (_format,)) + raise ValueError(f'Unknown format specified: {_format}') # ================================================================ # API # ================================================================ -def logkv(key, val): +def logkv(key: Any, val: Any) -> None: """ Log a value of some diagnostic Call this once for each diagnostic quantity, each iteration @@ -236,7 +240,7 @@ def logkv(key, val): Logger.CURRENT.logkv(key, val) -def logkv_mean(key, val): +def logkv_mean(key: Any, val: Union[int, float]) -> None: """ The same as logkv(), but if called many times, values averaged. @@ -246,7 +250,7 @@ def logkv_mean(key, val): Logger.CURRENT.logkv_mean(key, val) -def logkvs(key_values): +def logkvs(key_values: Dict) -> None: """ Log a dictionary of key-value pairs @@ -256,14 +260,14 @@ def logkvs(key_values): logkv(key, value) -def dumpkvs(): +def dumpkvs() -> None: """ Write all of the diagnostics from the current iteration """ Logger.CURRENT.dumpkvs() -def getkvs(): +def getkvs() -> Dict: """ get the key values logs @@ -272,7 +276,7 @@ def getkvs(): return Logger.CURRENT.name2val -def log(*args, level=INFO): +def log(*args, level: int = INFO) -> None: """ Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). @@ -286,7 +290,7 @@ def log(*args, level=INFO): Logger.CURRENT.log(*args, level=level) -def debug(*args): +def debug(*args) -> None: """ Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). @@ -297,7 +301,7 @@ def debug(*args): log(*args, level=DEBUG) -def info(*args): +def info(*args) -> None: """ Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). @@ -308,7 +312,7 @@ def info(*args): log(*args, level=INFO) -def warn(*args): +def warn(*args) -> None: """ Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). @@ -319,7 +323,7 @@ def warn(*args): log(*args, level=WARN) -def error(*args): +def error(*args) -> None: """ Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). @@ -330,7 +334,7 @@ def error(*args): log(*args, level=ERROR) -def set_level(level): +def set_level(level: int) -> None: """ Set logging threshold on current logger. @@ -339,7 +343,7 @@ def set_level(level): Logger.CURRENT.set_level(level) -def get_level(): +def get_level() -> int: """ Get logging threshold on current logger. :return: (int) the logging level (can be DEBUG=10, INFO=20, WARN=30, ERROR=40, DISABLED=50) @@ -347,7 +351,7 @@ def get_level(): return Logger.CURRENT.level -def get_dir(): +def get_dir() -> str: """ Get directory that log files are being written to. will be None if there is no output directory (i.e., if you didn't call start) @@ -371,7 +375,7 @@ class Logger(object): DEFAULT = None CURRENT = None # Current logger being used by the free functions above - def __init__(self, folder, output_formats): + def __init__(self, folder: Optional[str], output_formats: List[KVWriter]): """ the logger class @@ -386,7 +390,7 @@ class Logger(object): # Logging API, forwarded # ---------------------------------------- - def logkv(self, key, val): + def logkv(self, key: Any, val: Any) -> None: """ Log a value of some diagnostic Call this once for each diagnostic quantity, each iteration @@ -397,7 +401,7 @@ class Logger(object): """ self.name2val[key] = val - def logkv_mean(self, key, val): + def logkv_mean(self, key: Any, val: Any) -> None: """ The same as logkv(), but if called many times, values averaged. @@ -411,7 +415,7 @@ class Logger(object): self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) self.name2cnt[key] = cnt + 1 - def dumpkvs(self): + def dumpkvs(self) -> None: """ Write all of the diagnostics from the current iteration """ @@ -423,7 +427,7 @@ class Logger(object): self.name2val.clear() self.name2cnt.clear() - def log(self, *args, level=INFO): + def log(self, *args, level: int = INFO) -> None: """ Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). @@ -439,7 +443,7 @@ class Logger(object): # Configuration # ---------------------------------------- - def set_level(self, level): + def set_level(self, level: int) -> None: """ Set logging threshold on current logger. @@ -447,7 +451,7 @@ class Logger(object): """ self.level = level - def get_dir(self): + def get_dir(self) -> str: """ Get directory that log files are being written to. will be None if there is no output directory (i.e., if you didn't call start) @@ -456,7 +460,7 @@ class Logger(object): """ return self.dir - def close(self): + def close(self) -> None: """ closes the file """ @@ -465,7 +469,7 @@ class Logger(object): # Misc # ---------------------------------------- - def _do_log(self, args): + def _do_log(self, args) -> None: """ log to the requested format outputs @@ -476,15 +480,17 @@ class Logger(object): fmt.writeseq(map(str, args)) +# Initialize logger Logger.DEFAULT = Logger.CURRENT = Logger(folder=None, output_formats=[HumanOutputFormat(sys.stdout)]) -def configure(folder=None, format_strs=None): +def configure(folder: Optional[str] = None, format_strs: Optional[List[str]] = None) -> None: """ configure the current logger - :param folder: (str) the save location (if None, $BASELINES_LOGDIR, if still None, tempdir/baselines-[date & time]) - :param format_strs: (list) the output logging format + :param folder: (Optional[str]) the save location + (if None, $BASELINES_LOGDIR, if still None, tempdir/baselines-[date & time]) + :param format_strs: (Optional[List[str]]) the output logging format (if None, $BASELINES_LOG_FORMAT, if still None, ['stdout', 'log', 'csv']) """ if folder is None: @@ -502,10 +508,10 @@ def configure(folder=None, format_strs=None): output_formats = [make_output_format(f, folder, log_suffix) for f in format_strs] Logger.CURRENT = Logger(folder=folder, output_formats=output_formats) - log('Logging to %s' % folder) + log(f'Logging to {folder}') -def reset(): +def reset() -> None: """ reset the current logger """ @@ -516,7 +522,7 @@ def reset(): class ScopedConfigure(object): - def __init__(self, folder=None, format_strs=None): + def __init__(self, folder: Optional[str] = None, format_strs: Optional[List[str]] = None): """ Class for using context manager while logging @@ -531,11 +537,11 @@ class ScopedConfigure(object): self.format_strs = format_strs self.prevlogger = None - def __enter__(self): + def __enter__(self) -> None: self.prevlogger = Logger.CURRENT configure(folder=self.dir, format_strs=self.format_strs) - def __exit__(self, *args): + def __exit__(self, *args) -> None: Logger.CURRENT.close() Logger.CURRENT = self.prevlogger @@ -544,27 +550,25 @@ class ScopedConfigure(object): # Readers # ================================================================ -def read_json(fname): +def read_json(filename: str) -> pandas.DataFrame: """ read a json file using pandas - :param fname: (str) the file path to read - :return: (pandas DataFrame) the data in the json + :param filename: (str) the file path to read + :return: (pandas.DataFrame) the data in the json """ - import pandas data = [] - with open(fname, 'rt') as file_handler: + with open(filename, 'rt') as file_handler: for line in file_handler: data.append(json.loads(line)) return pandas.DataFrame(data) -def read_csv(fname): +def read_csv(filename: str) -> pandas.DataFrame: """ read a csv file using pandas - :param fname: (str) the file path to read - :return: (pandas DataFrame) the data in the csv + :param filename: (str) the file path to read + :return: (pandas.DataFrame) the data in the csv """ - import pandas - return pandas.read_csv(fname, index_col=None, comment='#') + return pandas.read_csv(filename, index_col=None, comment='#') diff --git a/torchy_baselines/version.txt b/torchy_baselines/version.txt index 8413ee4..8f0916f 100644 --- a/torchy_baselines/version.txt +++ b/torchy_baselines/version.txt @@ -1 +1 @@ -0.5.0a2 +0.5.0