diff --git a/docs/guide/integrations.rst b/docs/guide/integrations.rst index 9007ade..7f21bd3 100644 --- a/docs/guide/integrations.rst +++ b/docs/guide/integrations.rst @@ -137,3 +137,56 @@ Then, in this example, we train a PPO agent to play CartPole-v1 and push it to a filename="ppo-CartPole-v1", commit_message="Added Cartpole-v1 model trained with PPO", ) + +MLflow +====== + +If you want to use `MLflow <https://github.com/mlflow/mlflow>`_ to track your SB3 experiments, +you can adapt the following code which defines a custom logger output: + +.. code-block:: python + + import sys + from typing import Any, Dict, Tuple, Union + + import mlflow + import numpy as np + + from stable_baselines3 import SAC + from stable_baselines3.common.logger import HumanOutputFormat, KVWriter, Logger + + + class MLflowOutputFormat(KVWriter): + """ + Dumps key/value pairs into MLflow's numeric format. + """ + + def write( + self, + key_values: Dict[str, Any], + key_excluded: Dict[str, Union[str, Tuple[str, ...]]], + step: int = 0, + ) -> None: + + for (key, value), (_, excluded) in zip( + sorted(key_values.items()), sorted(key_excluded.items()) + ): + + if excluded is not None and "mlflow" in excluded: + continue + + if isinstance(value, np.ScalarType): + if not isinstance(value, str): + mlflow.log_metric(key, value, step) + + + loggers = Logger( + folder=None, + output_formats=[HumanOutputFormat(sys.stdout), MLflowOutputFormat()], + ) + + with mlflow.start_run(): + model = SAC("MlpPolicy", "Pendulum-v1", verbose=2) + # Set custom logger + model.set_logger(loggers) + model.learn(total_timesteps=10000, log_interval=1) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 42a1d5a..cb3a17f 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -39,6 +39,7 @@ Documentation: - Fix typo in PPO doc (@bcollazo) - Added link to PPO ICLR blog post - Added remark about breaking Markov assumption and timeout handling +- Added doc about MLflow integration via custom logger (@git-thor) Release 1.5.0 
(2022-03-25) @@ -968,4 +969,4 @@ And all the contributors: @wkirgsn @AechPro @CUN-bjy @batu @IljaAvadiev @timokau @kachayev @cleversonahum @eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP @simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485 -@Gregwar @ycheng517 @quantitative-technologies @bcollazo +@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor diff --git a/stable_baselines3/common/logger.py b/stable_baselines3/common/logger.py index 7cc3d0a..1295e5b 100644 --- a/stable_baselines3/common/logger.py +++ b/stable_baselines3/common/logger.py @@ -17,6 +17,7 @@ try: except ImportError: SummaryWriter = None + DEBUG = 10 INFO = 20 WARN = 30 @@ -246,12 +247,13 @@ def filter_excluded_keys( class JSONOutputFormat(KVWriter): - def __init__(self, filename: str): - """ - log to a file, in the JSON format + """ + Log to a file, in the JSON format - :param filename: the file to write the log to - """ + :param filename: the file to write the log to + """ + + def __init__(self, filename: str): self.file = open(filename, "wt") def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None: @@ -287,13 +289,13 @@ class JSONOutputFormat(KVWriter): class CSVOutputFormat(KVWriter): + """ + Log to a file, in a CSV format + + :param filename: the file to write the log to + """ + def __init__(self, filename: str): - """ - log to a file, in a CSV format - - :param filename: the file to write the log to - """ - self.file = open(filename, "w+t") self.keys = [] self.separator = "," @@ -351,12 +353,13 @@ class CSVOutputFormat(KVWriter): class TensorBoardOutputFormat(KVWriter): - def __init__(self, folder: str): - """ - Dumps key/value pairs into TensorBoard's numeric format. + """ + Dumps key/value pairs into TensorBoard's numeric format. 
- :param folder: the folder to write the log to - """ + :param folder: the folder to write the log to + """ + + def __init__(self, folder: str): assert SummaryWriter is not None, "tensorboard is not installed, you can use " "pip install tensorboard to do so" self.writer = SummaryWriter(log_dir=folder)