mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-27 22:55:17 +00:00
Add Support for Text Records to Logger, Add Hint on How To Access SummaryWriter in Docs. (#303)
* add support for text records to logger * add note on how to access summary writer directly * escape unicode chars for HumanOutputFormat * update changelog * fix formatting * fix docs * add tests * fix formatting * fix example, link to pytorch docs, update changelog * move unicode escaping to own function, properly escape quotechars in csv formatter * switch from n_calls to num_timesteps in example * make step coherent in example * use n_calls to check when to login example * add small hint about log frequency Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org> * add comment about str is scalar type, improve test input * Update tests * Update test_logger.py * use repr to handle strings in logger * remove repr from text log output Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
parent
c722c4f5bd
commit
b01bde3e2d
4 changed files with 76 additions and 9 deletions
|
|
@ -225,3 +225,42 @@ Here is an example of how to render an episode and log the resulting video to Te
|
|||
model = A2C("MlpPolicy", "CartPole-v1", tensorboard_log="runs/", verbose=1)
|
||||
video_recorder = VideoRecorderCallback(gym.make("CartPole-v1"), render_freq=5000)
|
||||
model.learn(total_timesteps=int(5e4), callback=video_recorder)
|
||||
|
||||
|
||||
Directly Accessing The Summary Writer
|
||||
-------------------------------------
|
||||
|
||||
If you would like to log arbitrary data (in one of the formats supported by `pytorch <https://pytorch.org/docs/stable/tensorboard.html>`_), you
|
||||
can get direct access to the underlying SummaryWriter in a callback:
|
||||
|
||||
.. warning::
|
||||
This method is not recommended and should only be used by advanced users.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from stable_baselines3 import SAC
|
||||
from stable_baselines3.common.callbacks import BaseCallback
|
||||
from stable_baselines3.common.logger import TensorBoardOutputFormat
|
||||
|
||||
|
||||
|
||||
model = SAC("MlpPolicy", "Pendulum-v0", tensorboard_log="/tmp/sac/", verbose=1)
|
||||
|
||||
|
||||
class SummaryWriterCallback(BaseCallback):
|
||||
|
||||
def _on_training_start(self):
|
||||
self._log_freq = 1000 # log every 1000 calls
|
||||
|
||||
output_formats = self.logger.Logger.CURRENT.output_formats
|
||||
# Save reference to tensorboard formatter object
|
||||
# note: the failure case (no formatter found, so next() raises StopIteration) is not handled here, should be done with try/except.
|
||||
self.tb_formatter = next(formatter for formatter in output_formats if isinstance(formatter, TensorBoardOutputFormat))
|
||||
|
||||
def _on_step(self) -> bool:
|
||||
if self.n_calls % self._log_freq == 0:
|
||||
self.tb_formatter.writer.add_text("direct_access", "this is a value", self.num_timesteps)
|
||||
self.tb_formatter.writer.flush()
|
||||
|
||||
|
||||
model.learn(50000, callback=SummaryWriterCallback())
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ New Features:
|
|||
- Wrap the environments automatically with a ``Monitor`` wrapper when possible.
|
||||
- ``EvalCallback`` now logs the success rate when available (``is_success`` must be present in the info dict)
|
||||
- Added new wrappers to log images and matplotlib figures to tensorboard. (@zampanteymedio)
|
||||
- Add support for text records to ``Logger``. (@lorenz-h)
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
|
|
@ -67,8 +68,10 @@ Documentation:
|
|||
- Fix docstring of classes in atari_wrappers.py which were inside the constructor (@LucasAlegre)
|
||||
- Added SB3-Contrib page
|
||||
- Fix bug in the example code of DQN (@AptX395)
|
||||
- Add example on how to access the tensorboard summary writer directly. (@lorenz-h)
|
||||
- Updated migration guide
|
||||
|
||||
|
||||
Pre-Release 0.10.0 (2020-10-28)
|
||||
-------------------------------
|
||||
|
||||
|
|
@ -542,4 +545,4 @@ And all the contributors:
|
|||
@flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3
|
||||
@tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag @xicocaio
|
||||
@diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 @SwamyDev @wmmc88 @megan-klaiber @thisray
|
||||
@tfederico @hn2 @LucasAlegre @AptX395 @zampanteymedio @decodyng @ardabbour
|
||||
@tfederico @hn2 @LucasAlegre @AptX395 @zampanteymedio @decodyng @ardabbour @lorenz-h
|
||||
|
|
|
|||
|
|
@ -137,16 +137,16 @@ class HumanOutputFormat(KVWriter, SeqWriter):
|
|||
if excluded is not None and ("stdout" in excluded or "log" in excluded):
|
||||
continue
|
||||
|
||||
if isinstance(value, Video):
|
||||
elif isinstance(value, Video):
|
||||
raise FormatUnsupportedError(["stdout", "log"], "video")
|
||||
|
||||
if isinstance(value, Figure):
|
||||
elif isinstance(value, Figure):
|
||||
raise FormatUnsupportedError(["stdout", "log"], "figure")
|
||||
|
||||
if isinstance(value, Image):
|
||||
elif isinstance(value, Image):
|
||||
raise FormatUnsupportedError(["stdout", "log"], "image")
|
||||
|
||||
if isinstance(value, float):
|
||||
elif isinstance(value, float):
|
||||
# Align left
|
||||
value_str = f"{value:<8.3g}"
|
||||
else:
|
||||
|
|
@ -273,6 +273,7 @@ class CSVOutputFormat(KVWriter):
|
|||
self.file = open(filename, "w+t")
|
||||
self.keys = []
|
||||
self.separator = ","
|
||||
self.quotechar = '"'
|
||||
|
||||
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None:
|
||||
# Add our current row to the history
|
||||
|
|
@ -300,13 +301,20 @@ class CSVOutputFormat(KVWriter):
|
|||
if isinstance(value, Video):
|
||||
raise FormatUnsupportedError(["csv"], "video")
|
||||
|
||||
if isinstance(value, Figure):
|
||||
elif isinstance(value, Figure):
|
||||
raise FormatUnsupportedError(["csv"], "figure")
|
||||
|
||||
if isinstance(value, Image):
|
||||
elif isinstance(value, Image):
|
||||
raise FormatUnsupportedError(["csv"], "image")
|
||||
|
||||
if value is not None:
|
||||
elif isinstance(value, str):
|
||||
# escape quotechars by prepending them with another quotechar
|
||||
value = value.replace(self.quotechar, self.quotechar + self.quotechar)
|
||||
|
||||
# additionally wrap text with quotechars so that any delimiters in the text are ignored by csv readers
|
||||
self.file.write(self.quotechar + value + self.quotechar)
|
||||
|
||||
elif value is not None:
|
||||
self.file.write(str(value))
|
||||
self.file.write("\n")
|
||||
self.file.flush()
|
||||
|
|
@ -336,7 +344,11 @@ class TensorBoardOutputFormat(KVWriter):
|
|||
continue
|
||||
|
||||
if isinstance(value, np.ScalarType):
|
||||
self.writer.add_scalar(key, value, step)
|
||||
if isinstance(value, str):
|
||||
# str is considered a np.ScalarType
|
||||
self.writer.add_text(key, value, step)
|
||||
else:
|
||||
self.writer.add_scalar(key, value, step)
|
||||
|
||||
if isinstance(value, th.Tensor):
|
||||
self.writer.add_histogram(key, value, step)
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from pandas.errors import EmptyDataError
|
|||
|
||||
from stable_baselines3.common.logger import (
|
||||
DEBUG,
|
||||
INFO,
|
||||
Figure,
|
||||
FormatUnsupportedError,
|
||||
Image,
|
||||
|
|
@ -17,6 +18,9 @@ from stable_baselines3.common.logger import (
|
|||
debug,
|
||||
dump,
|
||||
error,
|
||||
get_dir,
|
||||
get_level,
|
||||
get_log_dict,
|
||||
info,
|
||||
make_output_format,
|
||||
read_csv,
|
||||
|
|
@ -37,6 +41,7 @@ KEY_VALUES = {
|
|||
"a": np.array([1, 2, 3]),
|
||||
"f": np.array(1),
|
||||
"g": np.array([[[1]]]),
|
||||
"h": 'this ", ;is a \n tes:,t',
|
||||
}
|
||||
|
||||
KEY_EXCLUDED = {}
|
||||
|
|
@ -104,9 +109,12 @@ def test_main(tmp_path):
|
|||
"""
|
||||
info("hi")
|
||||
debug("shouldn't appear")
|
||||
assert get_level() == INFO
|
||||
set_level(DEBUG)
|
||||
assert get_level() == DEBUG
|
||||
debug("should appear")
|
||||
configure(folder=str(tmp_path))
|
||||
assert get_dir() == str(tmp_path)
|
||||
record("a", 3)
|
||||
record("b", 2.5)
|
||||
dump()
|
||||
|
|
@ -114,6 +122,9 @@ def test_main(tmp_path):
|
|||
record("a", 5.5)
|
||||
dump()
|
||||
info("^^^ should see a = 5.5")
|
||||
record("f", "this text \n \r should appear in one line")
|
||||
dump()
|
||||
info('^^^ should see f = "this text \n \r should appear in one line"')
|
||||
record_mean("b", -22.5)
|
||||
record_mean("b", -44.4)
|
||||
record("a", 5.5)
|
||||
|
|
@ -131,6 +142,7 @@ def test_main(tmp_path):
|
|||
warn("hey")
|
||||
error("oh")
|
||||
record_dict({"test": 1})
|
||||
assert isinstance(get_log_dict(), dict) and set(get_log_dict().keys()) == {"test"}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("_format", ["stdout", "log", "json", "csv", "tensorboard"])
|
||||
|
|
@ -159,6 +171,7 @@ def test_make_output_fail(tmp_path):
|
|||
|
||||
|
||||
@pytest.mark.parametrize("_format", ["stdout", "log", "json", "csv", "tensorboard"])
|
||||
@pytest.mark.filterwarnings("ignore:Tried to write empty key-value dict")
|
||||
def test_exclude_keys(tmp_path, read_log, _format):
|
||||
if _format == "tensorboard":
|
||||
# Skip if no tensorboard installed
|
||||
|
|
|
|||
Loading…
Reference in a new issue