mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-05 00:00:04 +00:00
Add doc to use mlflow logger (#889)
* ADD feature for mlflow logger via MLflowOutputFormat. * Move MLFlow integration to doc Co-authored-by: Antonin Raffin <antonin.raffin@ensta.org>
This commit is contained in:
parent
e98ae129de
commit
c2518dc160
3 changed files with 74 additions and 17 deletions
|
|
@ -137,3 +137,56 @@ Then, in this example, we train a PPO agent to play CartPole-v1 and push it to a
|
|||
filename="ppo-CartPole-v1",
|
||||
commit_message="Added Cartpole-v1 model trained with PPO",
|
||||
)
|
||||
|
||||
MLFLow
|
||||
======
|
||||
|
||||
If you want to use `MLFLow <https://github.com/mlflow/mlflow>`_ to track your SB3 experiments,
|
||||
you can adapt the following code which defines a custom logger output:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import sys
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
import mlflow
|
||||
import numpy as np
|
||||
|
||||
from stable_baselines3 import SAC
|
||||
from stable_baselines3.common.logger import HumanOutputFormat, KVWriter, Logger
|
||||
|
||||
|
||||
class MLflowOutputFormat(KVWriter):
|
||||
"""
|
||||
Dumps key/value pairs into MLflow's numeric format.
|
||||
"""
|
||||
|
||||
def write(
|
||||
self,
|
||||
key_values: Dict[str, Any],
|
||||
key_excluded: Dict[str, Union[str, Tuple[str, ...]]],
|
||||
step: int = 0,
|
||||
) -> None:
|
||||
|
||||
for (key, value), (_, excluded) in zip(
|
||||
sorted(key_values.items()), sorted(key_excluded.items())
|
||||
):
|
||||
|
||||
if excluded is not None and "mlflow" in excluded:
|
||||
continue
|
||||
|
||||
if isinstance(value, np.ScalarType):
|
||||
if not isinstance(value, str):
|
||||
mlflow.log_metric(key, value, step)
|
||||
|
||||
|
||||
loggers = Logger(
|
||||
folder=None,
|
||||
output_formats=[HumanOutputFormat(sys.stdout), MLflowOutputFormat()],
|
||||
)
|
||||
|
||||
with mlflow.start_run():
|
||||
model = SAC("MlpPolicy", "Pendulum-v1", verbose=2)
|
||||
# Set custom logger
|
||||
model.set_logger(loggers)
|
||||
model.learn(total_timesteps=10000, log_interval=1)
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ Documentation:
|
|||
- Fix typo in PPO doc (@bcollazo)
|
||||
- Added link to PPO ICLR blog post
|
||||
- Added remark about breaking Markov assumption and timeout handling
|
||||
- Added doc about MLFlow integration via custom logger (@git-thor)
|
||||
|
||||
|
||||
Release 1.5.0 (2022-03-25)
|
||||
|
|
@ -968,4 +969,4 @@ And all the contributors:
|
|||
@wkirgsn @AechPro @CUN-bjy @batu @IljaAvadiev @timokau @kachayev @cleversonahum
|
||||
@eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP
|
||||
@simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485
|
||||
@Gregwar @ycheng517 @quantitative-technologies @bcollazo
|
||||
@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ try:
|
|||
except ImportError:
|
||||
SummaryWriter = None
|
||||
|
||||
|
||||
DEBUG = 10
|
||||
INFO = 20
|
||||
WARN = 30
|
||||
|
|
@ -246,12 +247,13 @@ def filter_excluded_keys(
|
|||
|
||||
|
||||
class JSONOutputFormat(KVWriter):
|
||||
def __init__(self, filename: str):
|
||||
"""
|
||||
log to a file, in the JSON format
|
||||
"""
|
||||
Log to a file, in the JSON format
|
||||
|
||||
:param filename: the file to write the log to
|
||||
"""
|
||||
:param filename: the file to write the log to
|
||||
"""
|
||||
|
||||
def __init__(self, filename: str):
|
||||
self.file = open(filename, "wt")
|
||||
|
||||
def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None:
|
||||
|
|
@ -287,13 +289,13 @@ class JSONOutputFormat(KVWriter):
|
|||
|
||||
|
||||
class CSVOutputFormat(KVWriter):
|
||||
"""
|
||||
Log to a file, in a CSV format
|
||||
|
||||
:param filename: the file to write the log to
|
||||
"""
|
||||
|
||||
def __init__(self, filename: str):
|
||||
"""
|
||||
log to a file, in a CSV format
|
||||
|
||||
:param filename: the file to write the log to
|
||||
"""
|
||||
|
||||
self.file = open(filename, "w+t")
|
||||
self.keys = []
|
||||
self.separator = ","
|
||||
|
|
@ -351,12 +353,13 @@ class CSVOutputFormat(KVWriter):
|
|||
|
||||
|
||||
class TensorBoardOutputFormat(KVWriter):
|
||||
def __init__(self, folder: str):
|
||||
"""
|
||||
Dumps key/value pairs into TensorBoard's numeric format.
|
||||
"""
|
||||
Dumps key/value pairs into TensorBoard's numeric format.
|
||||
|
||||
:param folder: the folder to write the log to
|
||||
"""
|
||||
:param folder: the folder to write the log to
|
||||
"""
|
||||
|
||||
def __init__(self, folder: str):
|
||||
assert SummaryWriter is not None, "tensorboard is not installed, you can use " "pip install tensorboard to do so"
|
||||
self.writer = SummaryWriter(log_dir=folder)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue