Add Support for Text Records to Logger, Add Hint on How To Access SummaryWriter in Docs. (#303)

* add support for text records to logger

* add note on how to access summary writer directly

* escape unicode chars for HumanOutputFormat

* update changelog

* fix formatting

* fix docs

* add tests

* fix formatting

* fix example, link to pytorch docs, update changelog

* move unicode escaping to own function, properly escape quotechars in csv formatter

* switch from n_calls to num_timesteps in example

* make step coherent in example

* use n_calls to check when to log in example

* add small hint about log frequency

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>

* add comment about str is scalar type, improve test input

* Update tests

* Update test_logger.py

* use repr to handle strings in logger

* remove repr from text log output

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
Lorenz Hetzel 2021-02-01 11:56:33 +01:00 committed by GitHub
parent c722c4f5bd
commit b01bde3e2d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 76 additions and 9 deletions

View file

@ -225,3 +225,42 @@ Here is an example of how to render an episode and log the resulting video to Te
model = A2C("MlpPolicy", "CartPole-v1", tensorboard_log="runs/", verbose=1)
video_recorder = VideoRecorderCallback(gym.make("CartPole-v1"), render_freq=5000)
model.learn(total_timesteps=int(5e4), callback=video_recorder)
Directly Accessing The Summary Writer
-------------------------------------
If you would like to log arbitrary data (in one of the formats supported by `pytorch <https://pytorch.org/docs/stable/tensorboard.html>`_), you
can get direct access to the underlying SummaryWriter in a callback:
.. warning::
This method is not recommended and should only be used by advanced users.
.. code-block:: python
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import TensorBoardOutputFormat
model = SAC("MlpPolicy", "Pendulum-v0", tensorboard_log="/tmp/sac/", verbose=1)
class SummaryWriterCallback(BaseCallback):
def _on_training_start(self):
self._log_freq = 1000 # log every 1000 calls
output_formats = self.logger.Logger.CURRENT.output_formats
# Save reference to tensorboard formatter object
# note: the failure case (no formatter found) is not handled here, should be done with try/except.
self.tb_formatter = next(formatter for formatter in output_formats if isinstance(formatter, TensorBoardOutputFormat))
def _on_step(self) -> bool:
if self.n_calls % self._log_freq == 0:
self.tb_formatter.writer.add_text("direct_access", "this is a value", self.num_timesteps)
self.tb_formatter.writer.flush()
model.learn(50000, callback=SummaryWriterCallback())

View file

@ -27,6 +27,7 @@ New Features:
- Wrap the environments automatically with a ``Monitor`` wrapper when possible.
- ``EvalCallback`` now logs the success rate when available (``is_success`` must be present in the info dict)
- Added new wrappers to log images and matplotlib figures to tensorboard. (@zampanteymedio)
- Add support for text records to ``Logger``. (@lorenz-h)
Bug Fixes:
^^^^^^^^^^
@ -67,8 +68,10 @@ Documentation:
- Fix docstring of classes in atari_wrappers.py which were inside the constructor (@LucasAlegre)
- Added SB3-Contrib page
- Fix bug in the example code of DQN (@AptX395)
- Add example on how to access the tensorboard summary writer directly. (@lorenz-h)
- Updated migration guide
Pre-Release 0.10.0 (2020-10-28)
-------------------------------
@ -542,4 +545,4 @@ And all the contributors:
@flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3
@tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag @xicocaio
@diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 @SwamyDev @wmmc88 @megan-klaiber @thisray
@tfederico @hn2 @LucasAlegre @AptX395 @zampanteymedio @decodyng @ardabbour
@tfederico @hn2 @LucasAlegre @AptX395 @zampanteymedio @decodyng @ardabbour @lorenz-h

View file

@ -137,16 +137,16 @@ class HumanOutputFormat(KVWriter, SeqWriter):
if excluded is not None and ("stdout" in excluded or "log" in excluded):
continue
if isinstance(value, Video):
elif isinstance(value, Video):
raise FormatUnsupportedError(["stdout", "log"], "video")
if isinstance(value, Figure):
elif isinstance(value, Figure):
raise FormatUnsupportedError(["stdout", "log"], "figure")
if isinstance(value, Image):
elif isinstance(value, Image):
raise FormatUnsupportedError(["stdout", "log"], "image")
if isinstance(value, float):
elif isinstance(value, float):
# Align left
value_str = f"{value:<8.3g}"
else:
@ -273,6 +273,7 @@ class CSVOutputFormat(KVWriter):
self.file = open(filename, "w+t")
self.keys = []
self.separator = ","
self.quotechar = '"'
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None:
# Add our current row to the history
@ -300,13 +301,20 @@ class CSVOutputFormat(KVWriter):
if isinstance(value, Video):
raise FormatUnsupportedError(["csv"], "video")
if isinstance(value, Figure):
elif isinstance(value, Figure):
raise FormatUnsupportedError(["csv"], "figure")
if isinstance(value, Image):
elif isinstance(value, Image):
raise FormatUnsupportedError(["csv"], "image")
if value is not None:
elif isinstance(value, str):
# escape quotechars by prepending them with another quotechar
value = value.replace(self.quotechar, self.quotechar + self.quotechar)
# additionally wrap text with quotechars so that any delimiters in the text are ignored by csv readers
self.file.write(self.quotechar + value + self.quotechar)
elif value is not None:
self.file.write(str(value))
self.file.write("\n")
self.file.flush()
@ -336,7 +344,11 @@ class TensorBoardOutputFormat(KVWriter):
continue
if isinstance(value, np.ScalarType):
self.writer.add_scalar(key, value, step)
if isinstance(value, str):
# str is considered a np.ScalarType
self.writer.add_text(key, value, step)
else:
self.writer.add_scalar(key, value, step)
if isinstance(value, th.Tensor):
self.writer.add_histogram(key, value, step)

View file

@ -8,6 +8,7 @@ from pandas.errors import EmptyDataError
from stable_baselines3.common.logger import (
DEBUG,
INFO,
Figure,
FormatUnsupportedError,
Image,
@ -17,6 +18,9 @@ from stable_baselines3.common.logger import (
debug,
dump,
error,
get_dir,
get_level,
get_log_dict,
info,
make_output_format,
read_csv,
@ -37,6 +41,7 @@ KEY_VALUES = {
"a": np.array([1, 2, 3]),
"f": np.array(1),
"g": np.array([[[1]]]),
"h": 'this ", ;is a \n tes:,t',
}
KEY_EXCLUDED = {}
@ -104,9 +109,12 @@ def test_main(tmp_path):
"""
info("hi")
debug("shouldn't appear")
assert get_level() == INFO
set_level(DEBUG)
assert get_level() == DEBUG
debug("should appear")
configure(folder=str(tmp_path))
assert get_dir() == str(tmp_path)
record("a", 3)
record("b", 2.5)
dump()
@ -114,6 +122,9 @@ def test_main(tmp_path):
record("a", 5.5)
dump()
info("^^^ should see a = 5.5")
record("f", "this text \n \r should appear in one line")
dump()
info('^^^ should see f = "this text \n \r should appear in one line"')
record_mean("b", -22.5)
record_mean("b", -44.4)
record("a", 5.5)
@ -131,6 +142,7 @@ def test_main(tmp_path):
warn("hey")
error("oh")
record_dict({"test": 1})
assert isinstance(get_log_dict(), dict) and set(get_log_dict().keys()) == {"test"}
@pytest.mark.parametrize("_format", ["stdout", "log", "json", "csv", "tensorboard"])
@ -159,6 +171,7 @@ def test_make_output_fail(tmp_path):
@pytest.mark.parametrize("_format", ["stdout", "log", "json", "csv", "tensorboard"])
@pytest.mark.filterwarnings("ignore:Tried to write empty key-value dict")
def test_exclude_keys(tmp_path, read_log, _format):
if _format == "tensorboard":
# Skip if no tensorboard installed