Add doc to use mlflow logger (#889)

* ADD feature for mlflow logger via MLflowOutputFormat.

* Move MLFlow integration to doc

Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
Thomas Rudolf 2022-05-08 15:28:31 +02:00 committed by GitHub
parent e98ae129de
commit c2518dc160
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 74 additions and 17 deletions

View file

@ -137,3 +137,56 @@ Then, in this example, we train a PPO agent to play CartPole-v1 and push it to a
filename="ppo-CartPole-v1",
commit_message="Added Cartpole-v1 model trained with PPO",
)
MLflow
======
If you want to use `MLflow <https://github.com/mlflow/mlflow>`_ to track your SB3 experiments,
you can adapt the following code, which defines a custom logger output format:
.. code-block:: python
import sys
from typing import Any, Dict, Tuple, Union
import mlflow
import numpy as np
from stable_baselines3 import SAC
from stable_baselines3.common.logger import HumanOutputFormat, KVWriter, Logger
class MLflowOutputFormat(KVWriter):
    """
    Logger output format that forwards scalar key/value pairs to MLflow as metrics.
    """

    def write(
        self,
        key_values: Dict[str, Any],
        key_excluded: Dict[str, Union[str, Tuple[str, ...]]],
        step: int = 0,
    ) -> None:
        """
        Log every eligible entry of ``key_values`` with ``mlflow.log_metric``.

        Both dictionaries are sorted by key and walked in lockstep, so the n-th
        exclusion entry is applied to the n-th value entry.

        :param key_values: the values to log, indexed by key
        :param key_excluded: for each key, the output formats it must not be written to
            (an entry is skipped when it contains ``"mlflow"``)
        :param step: the step passed along to ``mlflow.log_metric``
        """
        ordered_values = sorted(key_values.items())
        ordered_exclusions = sorted(key_excluded.items())

        for (name, metric), (_, skip_formats) in zip(ordered_values, ordered_exclusions):
            # Honour per-key exclusions targeting this output format
            if skip_formats is not None and "mlflow" in skip_formats:
                continue

            # Only scalar values are forwarded; strings are scalar types for NumPy
            # but are filtered out here before calling log_metric
            is_scalar = isinstance(metric, np.ScalarType)
            if is_scalar and not isinstance(metric, str):
                mlflow.log_metric(name, metric, step)
# Keep the human-readable output on stdout and add the MLflow output format
output_formats = [HumanOutputFormat(sys.stdout), MLflowOutputFormat()]
loggers = Logger(folder=None, output_formats=output_formats)

with mlflow.start_run():
    model = SAC("MlpPolicy", "Pendulum-v1", verbose=2)
    # Use the logger defined above instead of the model's default one
    model.set_logger(loggers)
    model.learn(total_timesteps=10000, log_interval=1)

View file

@ -39,6 +39,7 @@ Documentation:
- Fix typo in PPO doc (@bcollazo)
- Added link to PPO ICLR blog post
- Added remark about breaking Markov assumption and timeout handling
- Added doc about MLflow integration via custom logger (@git-thor)
Release 1.5.0 (2022-03-25)
@ -968,4 +969,4 @@ And all the contributors:
@wkirgsn @AechPro @CUN-bjy @batu @IljaAvadiev @timokau @kachayev @cleversonahum
@eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP
@simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485
@Gregwar @ycheng517 @quantitative-technologies @bcollazo
@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor

View file

@ -17,6 +17,7 @@ try:
except ImportError:
SummaryWriter = None
DEBUG = 10
INFO = 20
WARN = 30
@ -246,12 +247,13 @@ def filter_excluded_keys(
class JSONOutputFormat(KVWriter):
def __init__(self, filename: str):
"""
log to a file, in the JSON format
"""
Log to a file, in the JSON format
:param filename: the file to write the log to
"""
:param filename: the file to write the log to
"""
def __init__(self, filename: str):
self.file = open(filename, "wt")
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None:
@ -287,13 +289,13 @@ class JSONOutputFormat(KVWriter):
class CSVOutputFormat(KVWriter):
"""
Log to a file, in a CSV format
:param filename: the file to write the log to
"""
def __init__(self, filename: str):
"""
log to a file, in a CSV format
:param filename: the file to write the log to
"""
self.file = open(filename, "w+t")
self.keys = []
self.separator = ","
@ -351,12 +353,13 @@ class CSVOutputFormat(KVWriter):
class TensorBoardOutputFormat(KVWriter):
def __init__(self, folder: str):
"""
Dumps key/value pairs into TensorBoard's numeric format.
"""
Dumps key/value pairs into TensorBoard's numeric format.
:param folder: the folder to write the log to
"""
:param folder: the folder to write the log to
"""
def __init__(self, folder: str):
assert SummaryWriter is not None, "tensorboard is not installed, you can use " "pip install tensorboard to do so"
self.writer = SummaryWriter(log_dir=folder)