diff --git a/docs/guide/tensorboard.rst b/docs/guide/tensorboard.rst index 788ccc1..5f7b305 100644 --- a/docs/guide/tensorboard.rst +++ b/docs/guide/tensorboard.rst @@ -225,3 +225,42 @@ Here is an example of how to render an episode and log the resulting video to Te model = A2C("MlpPolicy", "CartPole-v1", tensorboard_log="runs/", verbose=1) video_recorder = VideoRecorderCallback(gym.make("CartPole-v1"), render_freq=5000) model.learn(total_timesteps=int(5e4), callback=video_recorder) + + +Directly Accessing The Summary Writer +------------------------------------- + +If you would like to log arbitrary data (in one of the formats supported by `pytorch `_), you +can get direct access to the underlying SummaryWriter in a callback: + +.. warning:: + This is method is not recommended and should only be used by advanced users. + +.. code-block:: python + + from stable_baselines3 import SAC + from stable_baselines3.common.callbacks import BaseCallback + from stable_baselines3.common.logger import TensorBoardOutputFormat + + + + model = SAC("MlpPolicy", "Pendulum-v0", tensorboard_log="/tmp/sac/", verbose=1) + + + class SummaryWriterCallback(BaseCallback): + + def _on_training_start(self): + self._log_freq = 1000 # log every 1000 calls + + output_formats = self.logger.Logger.CURRENT.output_formats + # Save reference to tensorboard formatter object + # note: the failure case (not formatter found) is not handled here, should be done with try/except. + self.tb_formatter = next(formatter for formatter in output_formats if isinstance(formatter, TensorBoardOutputFormat)) + + def _on_step(self) -> bool: + if self.n_calls % self._log_freq == 0: + self.tb_formatter.writer.add_text("direct_access", "this is a value", self.num_timesteps) + self.tb_formatter.writer.flush() + + + model.learn(50000, callback=SummaryWriterCallback()) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 1fdee75..3ede50f 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -27,6 +27,7 @@ New Features: - Wrap the environments automatically with a ``Monitor`` wrapper when possible. - ``EvalCallback`` now logs the success rate when available (``is_success`` must be present in the info dict) - Added new wrappers to log images and matplotlib figures to tensorboard. (@zampanteymedio) +- Add support for text records to ``Logger``. (@lorenz-h) Bug Fixes: ^^^^^^^^^^ @@ -67,8 +68,10 @@ Documentation: - Fix docstring of classes in atari_wrappers.py which were inside the constructor (@LucasAlegre) - Added SB3-Contrib page - Fix bug in the example code of DQN (@AptX395) +- Add example on how to access the tensorboard summary writer directly. (@lorenz-h) - Updated migration guide + Pre-Release 0.10.0 (2020-10-28) ------------------------------- @@ -542,4 +545,4 @@ And all the contributors: @flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3 @tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag @xicocaio @diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 @SwamyDev @wmmc88 @megan-klaiber @thisray -@tfederico @hn2 @LucasAlegre @AptX395 @zampanteymedio @decodyng @ardabbour +@tfederico @hn2 @LucasAlegre @AptX395 @zampanteymedio @decodyng @ardabbour @lorenz-h diff --git a/stable_baselines3/common/logger.py b/stable_baselines3/common/logger.py index 10283ec..3d6b458 100644 --- a/stable_baselines3/common/logger.py +++ b/stable_baselines3/common/logger.py @@ -137,16 +137,16 @@ class HumanOutputFormat(KVWriter, SeqWriter): if excluded is not None and ("stdout" in excluded or "log" in excluded): continue - if isinstance(value, Video): + elif isinstance(value, Video): raise FormatUnsupportedError(["stdout", "log"], "video") - if isinstance(value, Figure): + elif isinstance(value, Figure): raise FormatUnsupportedError(["stdout", "log"], "figure") - if isinstance(value, Image): + elif isinstance(value, Image): raise FormatUnsupportedError(["stdout", "log"], "image") - if isinstance(value, float): + elif isinstance(value, float): # Align left value_str = f"{value:<8.3g}" else: @@ -273,6 +273,7 @@ class CSVOutputFormat(KVWriter): self.file = open(filename, "w+t") self.keys = [] self.separator = "," + self.quotechar = '"' def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None: # Add our current row to the history @@ -300,13 +301,20 @@ class CSVOutputFormat(KVWriter): if isinstance(value, Video): raise FormatUnsupportedError(["csv"], "video") - if isinstance(value, Figure): + elif isinstance(value, Figure): raise FormatUnsupportedError(["csv"], "figure") - if isinstance(value, Image): + elif isinstance(value, Image): raise FormatUnsupportedError(["csv"], "image") - if value is not None: + elif isinstance(value, str): + # escape quotechars by prepending them with another quotechar + value = value.replace(self.quotechar, self.quotechar + self.quotechar) + + # additionally wrap text with quotechars so that any delimiters in the text are ignored by csv readers + self.file.write(self.quotechar + value + self.quotechar) + + elif value is not None: self.file.write(str(value)) self.file.write("\n") self.file.flush() @@ -336,7 +344,11 @@ class TensorBoardOutputFormat(KVWriter): continue if isinstance(value, np.ScalarType): - self.writer.add_scalar(key, value, step) + if isinstance(value, str): + # str is considered a np.ScalarType + self.writer.add_text(key, value, step) + else: + self.writer.add_scalar(key, value, step) if isinstance(value, th.Tensor): self.writer.add_histogram(key, value, step) diff --git a/tests/test_logger.py b/tests/test_logger.py index c1cce85..74a52e5 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -8,6 +8,7 @@ from pandas.errors import EmptyDataError from stable_baselines3.common.logger import ( DEBUG, + INFO, Figure, FormatUnsupportedError, Image, @@ -17,6 +18,9 @@ from stable_baselines3.common.logger import ( debug, dump, error, + get_dir, + get_level, + get_log_dict, info, make_output_format, read_csv, @@ -37,6 +41,7 @@ KEY_VALUES = { "a": np.array([1, 2, 3]), "f": np.array(1), "g": np.array([[[1]]]), + "h": 'this ", ;is a \n tes:,t', } KEY_EXCLUDED = {} @@ -104,9 +109,12 @@ def test_main(tmp_path): """ info("hi") debug("shouldn't appear") + assert get_level() == INFO set_level(DEBUG) + assert get_level() == DEBUG debug("should appear") configure(folder=str(tmp_path)) + assert get_dir() == str(tmp_path) record("a", 3) record("b", 2.5) dump() @@ -114,6 +122,9 @@ def test_main(tmp_path): record("a", 5.5) dump() info("^^^ should see a = 5.5") + record("f", "this text \n \r should appear in one line") + dump() + info('^^^ should see f = "this text \n \r should appear in one line"') record_mean("b", -22.5) record_mean("b", -44.4) record("a", 5.5) @@ -131,6 +142,7 @@ def test_main(tmp_path): warn("hey") error("oh") record_dict({"test": 1}) + assert isinstance(get_log_dict(), dict) and set(get_log_dict().keys()) == {"test"} @pytest.mark.parametrize("_format", ["stdout", "log", "json", "csv", "tensorboard"]) @@ -159,6 +171,7 @@ def test_make_output_fail(tmp_path): @pytest.mark.parametrize("_format", ["stdout", "log", "json", "csv", "tensorboard"]) +@pytest.mark.filterwarnings("ignore:Tried to write empty key-value dict") def test_exclude_keys(tmp_path, read_log, _format): if _format == "tensorboard": # Skip if no tensorboard installed