
.. _tensorboard:

Tensorboard Integration
=======================

Basic Usage
------------

To use Tensorboard with stable baselines3, you simply need to pass the location of the log folder to the RL agent:

.. code-block:: python

  from stable_baselines3 import A2C

  model = A2C("MlpPolicy", "CartPole-v1", verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/")
  model.learn(total_timesteps=10_000)

You can also define a custom logging name when training (by default, it is the algorithm name):

.. code-block:: python

  from stable_baselines3 import A2C

  model = A2C("MlpPolicy", "CartPole-v1", verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/")
  model.learn(total_timesteps=10_000, tb_log_name="first_run")
  # Pass reset_num_timesteps=False to continue the training curve in tensorboard
  # By default, it will create a new curve
  # Keep tb_log_name constant to have a continuous curve (see note below)
  model.learn(total_timesteps=10_000, tb_log_name="second_run", reset_num_timesteps=False)
  model.learn(total_timesteps=10_000, tb_log_name="third_run", reset_num_timesteps=False)

.. note::

  If you specify a different ``tb_log_name`` in subsequent runs, you will get split graphs, like in the figure below.
  If you want them to be continuous, you must keep the same ``tb_log_name`` (see `issue #975 <https://github.com/DLR-RM/stable-baselines3/issues/975#issuecomment-1198992211>`_).
  If your graphs still end up split for another reason, put the TensorBoard log files into the same folder.

.. image:: ../_static/img/split_graph.png
  :width: 330
  :alt: split_graph

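
Whether curves are split or continuous follows from how run folders are named. The sketch below assumes that each ``learn()`` call writes into ``<tensorboard_log>/<tb_log_name>_<N>``, where ``N`` is incremented for every new run and reused when ``reset_num_timesteps=False``. ``latest_run_id`` and ``next_run_dir`` are hypothetical helpers written for illustration; they are not part of the Stable-Baselines3 API.

.. code-block:: python

  import os
  import re
  from typing import Iterable


  def latest_run_id(existing_dirs: Iterable[str], log_name: str) -> int:
      """Largest N among folder names of the form ``<log_name>_<N>`` (0 if there is none)."""
      pattern = re.compile(rf"{re.escape(log_name)}_(\d+)")
      run_ids = [int(match.group(1)) for match in map(pattern.fullmatch, existing_dirs) if match]
      return max(run_ids, default=0)


  def next_run_dir(log_path: str, existing_dirs: Iterable[str], log_name: str, reset_num_timesteps: bool = True) -> str:
      """Hypothetical: folder a new ``learn()`` call would log to, under the assumed naming scheme."""
      run_id = latest_run_id(existing_dirs, log_name)
      if reset_num_timesteps:
          # New curve: move on to the next free index
          run_id += 1
      # reset_num_timesteps=False: reuse the latest folder so the curve continues
      return os.path.join(log_path, f"{log_name}_{run_id}")

In practice, ``existing_dirs`` would come from ``os.listdir(tensorboard_log)``. Under this scheme, every distinct ``tb_log_name`` gets its own folder (hence separate curves), while repeating the same name with ``reset_num_timesteps=False`` keeps writing to the same folder.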
Once the ``learn`` function is called, you can monitor the RL agent during or after training with the following bash command:

.. code-block:: bash

  tensorboard --logdir ./a2c_cartpole_tensorboard/

.. note::

  You can find explanations about the logger output and names in the :ref:`Logger <logger>` section.

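You can also display past logging folders alongside the current one, either by pointing ``--logdir`` to a common parent folder or by listing them with ``--logdir_spec`` (commas separate the folders, an optional ``name:`` prefix sets the run name):

.. code-block:: bash

  tensorboard --logdir_spec a2c:./a2c_cartpole_tensorboard/,ppo2:./ppo2_cartpole_tensorboard/
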
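
When comparing many runs, building the ``--logdir_spec`` value programmatically avoids shell-quoting mistakes (in a shell, an unquoted ``;`` ends the command). A minimal sketch, assuming TensorBoard's ``--logdir_spec`` convention where commas separate runs and ``name:path`` names a run; ``tensorboard_command`` is a hypothetical helper, and the returned list can be passed to ``subprocess.run``.

.. code-block:: python

  import shlex
  from typing import Dict, List


  def tensorboard_command(runs: Dict[str, str]) -> List[str]:
      """Build a ``tensorboard`` command that shows several log folders under custom names.

      ``runs`` maps a display name to a log folder. Commas separate runs and a colon
      separates a name from its path, so neither may appear where it would be misread.
      """
      for name, path in runs.items():
          if "," in name or "," in path or ":" in name:
              raise ValueError(f"Unsupported character in run {name!r} -> {path!r}")
      spec = ",".join(f"{name}:{path}" for name, path in runs.items())
      return ["tensorboard", "--logdir_spec", spec]


  cmd = tensorboard_command({"a2c": "./a2c_cartpole_tensorboard/", "ppo2": "./ppo2_cartpole_tensorboard/"})
  # Printable form for a terminal; pass `cmd` itself to subprocess.run() to launch TensorBoard
  print(shlex.join(cmd))  # tensorboard --logdir_spec a2c:./a2c_cartpole_tensorboard/,ppo2:./ppo2_cartpole_tensorboard/

Passing a list (rather than a single string) to ``subprocess.run`` sidesteps the shell entirely, so separators such as ``,`` and ``:`` reach TensorBoard unchanged.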
It will display information such as the episode reward (when using a ``Monitor`` wrapper), the model losses and other parameters specific to some models.

.. image:: ../_static/img/Tensorboard_example.png
  :width: 600
  :alt: plotting

Logging More Values
-------------------

Using a callback, you can easily log more values with TensorBoard.
Here is a simple example of how to log an arbitrary scalar value:

.. code-block:: python

  import numpy as np

  from stable_baselines3 import SAC
  from stable_baselines3.common.callbacks import BaseCallback

  model = SAC("MlpPolicy", "Pendulum-v1", tensorboard_log="/tmp/sac/", verbose=1)


  class TensorboardCallback(BaseCallback):
      """
      Custom callback for plotting additional values in tensorboard.
      """

      def __init__(self, verbose=0):
          super().__init__(verbose)

      def _on_step(self) -> bool:
          # Log scalar value (here a random variable)
          value = np.random.random()
          self.logger.record("random_value", value)
          return True


  model.learn(50000, callback=TensorboardCallback())

.. note::

  If you want to log values to TensorBoard more often than the default, you can manually call ``self.logger.dump(self.num_timesteps)`` in a callback
  (see `issue #506 <https://github.com/DLR-RM/stable-baselines3/issues/506>`_).

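
Values passed to ``record`` are buffered and only written out when the logger dumps, either at the algorithm's logging interval or when you call ``dump`` yourself. According to the Stable-Baselines3 ``Logger`` docstrings, calling ``record`` several times before a dump keeps only the last value, while ``record_mean`` averages them. The ``ToyLogger`` class below is a hypothetical, dependency-free stand-in that mimics this behaviour to make the difference concrete; it is not the library's implementation.

.. code-block:: python

  from collections import defaultdict
  from typing import Dict, List, Tuple


  class ToyLogger:
      """Hypothetical stand-in mimicking record / record_mean / dump semantics."""

      def __init__(self) -> None:
          self.name_to_value: Dict[str, float] = {}
          self._sums: Dict[str, float] = defaultdict(float)
          self._counts: Dict[str, int] = defaultdict(int)
          # (step, key, value) rows that an output format would receive on dump
          self.written: List[Tuple[int, str, float]] = []

      def record(self, key: str, value: float) -> None:
          # Called several times before a dump: only the last value survives
          self.name_to_value[key] = value

      def record_mean(self, key: str, value: float) -> None:
          # Called several times before a dump: the values are averaged
          self._sums[key] += value
          self._counts[key] += 1
          self.name_to_value[key] = self._sums[key] / self._counts[key]

      def dump(self, step: int) -> None:
          # Write every buffered value at the given step, then start a fresh interval
          for key, value in sorted(self.name_to_value.items()):
              self.written.append((step, key, value))
          self.name_to_value.clear()
          self._sums.clear()
          self._counts.clear()


  logger = ToyLogger()
  for value in [1.0, 2.0, 3.0]:  # three environment steps between two dumps
      logger.record("random_value", value)
      logger.record_mean("random_value_mean", value)
  logger.dump(step=3)
  print(logger.written)  # [(3, 'random_value', 3.0), (3, 'random_value_mean', 2.0)]

In a real callback, this means a per-step ``self.logger.record(...)`` only shows the value from the last step before each dump, whereas ``self.logger.record_mean(...)`` reports the average over the interval.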
Logging Images
--------------

TensorBoard supports periodic logging of image data, which helps evaluate agents at various stages during training.

.. warning::

  To support image logging, `pillow <https://github.com/python-pillow/Pillow>`_ must be installed; otherwise, TensorBoard ignores the image and logs a warning.

Here is an example of how to render an image to TensorBoard at regular intervals:

.. code-block:: python

  from stable_baselines3 import SAC
  from stable_baselines3.common.callbacks import BaseCallback
  from stable_baselines3.common.logger import Image

  model = SAC("MlpPolicy", "Pendulum-v1", tensorboard_log="/tmp/sac/", verbose=1)


  class ImageRecorderCallback(BaseCallback):
      def __init__(self, verbose=0):
          super().__init__(verbose)

      def _on_step(self):
          image = self.training_env.render(mode="rgb_array")
          # "HWC" specifies the dataformat of the image, here channel last
          # (H for height, W for width, C for channel)
          # See https://pytorch.org/docs/stable/tensorboard.html
          # for supported formats
          self.logger.record("trajectory/image", Image(image, "HWC"), exclude=("stdout", "log", "json", "csv"))
          return True


  model.learn(50000, callback=ImageRecorderCallback())

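
The ``"HWC"`` argument must match the array you pass. If your own rendering code produces float images or grayscale ``HxW`` arrays, a small conversion step keeps the data format consistent. A minimal NumPy sketch; ``to_uint8_hwc`` is a hypothetical helper, and the ``[0, 1]`` range for float images is an assumption about your renderer's output.

.. code-block:: python

  import numpy as np


  def to_uint8_hwc(frame: np.ndarray) -> np.ndarray:
      """Return ``frame`` as a uint8 HxWxC array (hypothetical helper)."""
      if frame.dtype != np.uint8:
          # Assume floats in [0, 1]: clip, scale to [0, 255] and round
          frame = np.round(np.clip(frame, 0.0, 1.0) * 255).astype(np.uint8)
      if frame.ndim == 2:
          # Grayscale HxW image: add a trailing channel axis
          frame = frame[:, :, np.newaxis]
      return frame


  float_frame = np.zeros((4, 5, 3), dtype=np.float32)
  float_frame[1, 1, 1] = 0.5
  image = to_uint8_hwc(float_frame)
  # Channel-first copy, to be logged with Image(image_chw, "CHW") instead of "HWC"
  image_chw = image.transpose(2, 0, 1)

Converting once, right before ``self.logger.record``, keeps the rest of the callback independent of how frames are produced.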
Logging Figures/Plots
---------------------

TensorBoard supports periodic logging of figures/plots created with matplotlib, which helps evaluate agents at various stages during training.

.. warning::

  To support figure logging, `matplotlib <https://matplotlib.org/>`_ must be installed; otherwise, TensorBoard ignores the figure and logs a warning.

Here is an example of how to store a plot in TensorBoard at regular intervals:

.. code-block:: python

  import numpy as np
  import matplotlib.pyplot as plt

  from stable_baselines3 import SAC
  from stable_baselines3.common.callbacks import BaseCallback
  from stable_baselines3.common.logger import Figure

  model = SAC("MlpPolicy", "Pendulum-v1", tensorboard_log="/tmp/sac/", verbose=1)


  class FigureRecorderCallback(BaseCallback):
      def __init__(self, verbose=0):
          super().__init__(verbose)

      def _on_step(self):
          # Plot values (here a random variable)
          figure = plt.figure()
          figure.add_subplot().plot(np.random.random(3))
          # Close the figure after logging it
          self.logger.record("trajectory/figure", Figure(figure, close=True), exclude=("stdout", "log", "json", "csv"))
          plt.close()
          return True


  model.learn(50000, callback=FigureRecorderCallback())

Logging Videos
--------------

TensorBoard supports periodic logging of video data, which helps evaluate agents at various stages during training.

.. warning::

  To support video logging, `moviepy <https://zulko.github.io/moviepy/>`_ must be installed; otherwise, TensorBoard ignores the video and logs a warning.

Here is an example of how to render an episode and log the resulting video to TensorBoard at regular intervals:

.. code-block:: python

  from typing import Any, Dict

  import gymnasium as gym
  import torch as th
  import numpy as np

  from stable_baselines3 import A2C
  from stable_baselines3.common.callbacks import BaseCallback
  from stable_baselines3.common.evaluation import evaluate_policy
  from stable_baselines3.common.logger import Video


  class VideoRecorderCallback(BaseCallback):
      def __init__(self, eval_env: gym.Env, render_freq: int, n_eval_episodes: int = 1, deterministic: bool = True):
          """
          Records a video of an agent's trajectory traversing ``eval_env`` and logs it to TensorBoard

          :param eval_env: A gym environment from which the trajectory is recorded,
              created with ``render_mode="rgb_array"``
          :param render_freq: Render the agent's trajectory every ``render_freq`` calls of the callback.
          :param n_eval_episodes: Number of episodes to render
          :param deterministic: Whether to use deterministic or stochastic policy
          """
          super().__init__()
          self._eval_env = eval_env
          self._render_freq = render_freq
          self._n_eval_episodes = n_eval_episodes
          self._deterministic = deterministic

      def _on_step(self) -> bool:
          if self.n_calls % self._render_freq == 0:
              screens = []

              def grab_screens(_locals: Dict[str, Any], _globals: Dict[str, Any]) -> None:
                  """
                  Renders the environment in its current state, recording the screen in the captured `screens` list

                  :param _locals: A dictionary containing all local variables of the callback's scope
                  :param _globals: A dictionary containing all global variables of the callback's scope
                  """
                  # We expect `render()` to return a uint8 array with values in [0, 255] or a float array
                  # with values in [0, 1], as described in
                  # https://pytorch.org/docs/stable/tensorboard.html#torch.utils.tensorboard.writer.SummaryWriter.add_video
                  # With the Gymnasium API, the render mode is fixed when the env is created
                  # (render_mode="rgb_array"), so render() takes no argument
                  screen = self._eval_env.render()
                  # PyTorch uses CxHxW vs HxWxC gym (and tensorflow) image convention
                  screens.append(screen.transpose(2, 0, 1))

              evaluate_policy(
                  self.model,
                  self._eval_env,
                  callback=grab_screens,
                  n_eval_episodes=self._n_eval_episodes,
                  deterministic=self._deterministic,
              )
              self.logger.record(
                  "trajectory/video",
                  Video(th.from_numpy(np.asarray([screens])), fps=40),
                  exclude=("stdout", "log", "json", "csv"),
              )
          return True


  model = A2C("MlpPolicy", "CartPole-v1", tensorboard_log="runs/", verbose=1)
  video_recorder = VideoRecorderCallback(gym.make("CartPole-v1", render_mode="rgb_array"), render_freq=5000)
  model.learn(total_timesteps=int(5e4), callback=video_recorder)

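
PyTorch's ``SummaryWriter.add_video`` expects a 5D tensor of shape ``(N, T, C, H, W)``: batch, time, channel, height and width. The callback builds it with ``np.asarray([screens])`` after transposing each frame. The sketch below makes the same layout explicit for a list of ``HxWxC`` frames, so shapes can be checked before logging; ``frames_to_video_array`` is a hypothetical helper, and its result would still need ``th.from_numpy`` before being wrapped in ``Video``.

.. code-block:: python

  from typing import List

  import numpy as np


  def frames_to_video_array(frames: List[np.ndarray]) -> np.ndarray:
      """Stack HxWxC frames into a (1, T, C, H, W) array (hypothetical helper)."""
      if not frames:
          # Fail early rather than logging an empty clip
          raise ValueError("No frames were recorded")
      clip = np.stack(frames)  # (T, H, W, C)
      clip = clip.transpose(0, 3, 1, 2)  # (T, C, H, W)
      return clip[np.newaxis]  # (1, T, C, H, W): a single video in the batch


  # Ten dummy 6x8 RGB frames whose pixel values encode the frame index
  frames = [np.full((6, 8, 3), t, dtype=np.uint8) for t in range(10)]
  video = frames_to_video_array(frames)

Stacking once at the end (instead of transposing inside the render callback) keeps the recorded frames in the usual channel-last layout until the moment they are logged.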
Logging Hyperparameters
-----------------------

TensorBoard supports logging of hyperparameters in its HPARAMS tab, which helps compare the training runs of different agents.

.. warning::

  To display hyperparameters in the HPARAMS section, a ``metric_dict`` must be given (as well as a ``hparam_dict``).

Here is an example of how to save hyperparameters in TensorBoard:

.. code-block:: python

  from stable_baselines3 import A2C
  from stable_baselines3.common.callbacks import BaseCallback
  from stable_baselines3.common.logger import HParam


  class HParamCallback(BaseCallback):
      """
      Saves the hyperparameters and metrics at the start of the training, and logs them to TensorBoard.
      """

      def _on_training_start(self) -> None:
          hparam_dict = {
              "algorithm": self.model.__class__.__name__,
              "learning rate": self.model.learning_rate,
              "gamma": self.model.gamma,
          }
          # define the metrics that will appear in the `HPARAMS` Tensorboard tab by referencing their tag
          # Tensorboard will find & display metrics from the `SCALARS` tab
          metric_dict = {
              "rollout/ep_len_mean": 0,
              "train/value_loss": 0.0,
          }
          self.logger.record(
              "hparams",
              HParam(hparam_dict, metric_dict),
              exclude=("stdout", "log", "json", "csv"),
          )

      def _on_step(self) -> bool:
          return True


  model = A2C("MlpPolicy", "CartPole-v1", tensorboard_log="runs/", verbose=1)
  model.learn(total_timesteps=int(5e4), callback=HParamCallback())

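
If you pass a learning-rate schedule instead of a constant, ``self.model.learning_rate`` holds a function, while PyTorch's hyperparameter summary only accepts simple values (as far as we know ``int``, ``float``, ``str``, ``bool`` or tensors). A minimal sketch that turns everything else into a string before building ``hparam_dict``; ``sanitize_hparams`` is a hypothetical helper, not part of Stable-Baselines3.

.. code-block:: python

  from typing import Any, Dict, Union

  HParamValue = Union[bool, int, float, str]


  def sanitize_hparams(hparams: Dict[str, Any]) -> Dict[str, HParamValue]:
      """Keep bool/int/float/str values and turn everything else into a readable string."""
      clean: Dict[str, HParamValue] = {}
      for key, value in hparams.items():
          if isinstance(value, (bool, int, float, str)):
              clean[key] = value
          elif callable(value):
              # e.g. a learning-rate schedule: log its name rather than the function object
              clean[key] = getattr(value, "__name__", repr(value))
          else:
              # e.g. a list such as net_arch=[64, 64]
              clean[key] = str(value)
      return clean

In the callback above, this would be used as ``HParam(sanitize_hparams(hparam_dict), metric_dict)``. Logging the schedule's name loses its parameters, so giving schedule functions descriptive names makes the HPARAMS tab more useful.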
Directly Accessing The Summary Writer
-------------------------------------

If you would like to log arbitrary data (in one of the formats supported by `pytorch <https://pytorch.org/docs/stable/tensorboard.html>`_), you
can get direct access to the underlying SummaryWriter in a callback:

.. warning::

  This method is not recommended and should only be used by advanced users.

.. note::

  If you want a concrete example, you can watch `how to log lap time with donkeycar env <https://www.youtube.com/watch?v=v8j2bpcE4Rg&t=4619s>`_,
  or read the code in the `RL Zoo <https://github.com/DLR-RM/rl-baselines3-zoo/blob/feat/gym-donkeycar/rl_zoo3/callbacks.py#L251-L270>`_.
  You might also want to take a look at `issue #1160 <https://github.com/DLR-RM/stable-baselines3/issues/1160>`_ and `issue #1219 <https://github.com/DLR-RM/stable-baselines3/issues/1219>`_.

.. code-block:: python

  from stable_baselines3 import SAC
  from stable_baselines3.common.callbacks import BaseCallback
  from stable_baselines3.common.logger import TensorBoardOutputFormat

  model = SAC("MlpPolicy", "Pendulum-v1", tensorboard_log="/tmp/sac/", verbose=1)


  class SummaryWriterCallback(BaseCallback):
      def _on_training_start(self):
          self._log_freq = 1000  # log every 1000 calls

          output_formats = self.logger.output_formats
          # Save reference to tensorboard formatter object
          # note: the failure case (no formatter found) is not handled here, should be done with try/except.
          self.tb_formatter = next(formatter for formatter in output_formats if isinstance(formatter, TensorBoardOutputFormat))

      def _on_step(self) -> bool:
          if self.n_calls % self._log_freq == 0:
              # You can have access to info from the env using self.locals.
              # for instance, when using one env (index 0 of locals["infos"]):
              # lap_count = self.locals["infos"][0]["lap_count"]
              # self.tb_formatter.writer.add_scalar("train/lap_count", lap_count, self.num_timesteps)

              self.tb_formatter.writer.add_text("direct_access", "this is a value", self.num_timesteps)
              self.tb_formatter.writer.flush()
          # Returning False (or None) would stop training
          return True


  model.learn(50000, callback=SummaryWriterCallback())

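
The comment in ``_on_training_start`` notes that the lookup is not guarded: if no TensorBoard output format is configured, ``next()`` raises a bare ``StopIteration``. A minimal sketch of a guarded lookup, written against plain Python objects so it runs without Stable-Baselines3; ``find_output_format`` is a hypothetical helper. In the callback, it would replace the ``next(...)`` call as ``self.tb_formatter = find_output_format(self.logger.output_formats, TensorBoardOutputFormat)``.

.. code-block:: python

  from typing import Iterable, Optional, Type, TypeVar

  FormatT = TypeVar("FormatT")


  def find_output_format(output_formats: Iterable[object], format_class: Type[FormatT]) -> FormatT:
      """Return the first element of ``output_formats`` that is a ``format_class`` instance.

      Raises a descriptive ValueError instead of the bare StopIteration that
      ``next()`` without a default raises when nothing matches.
      """
      found: Optional[FormatT] = next((fmt for fmt in output_formats if isinstance(fmt, format_class)), None)
      if found is None:
          raise ValueError(
              f"No {format_class.__name__} among the logger output formats; "
              "was the model created with tensorboard_log set?"
          )
      return found

Failing with a message that points at ``tensorboard_log`` makes the most common cause (a model created without a TensorBoard log folder) obvious at the start of training.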