stable-baselines3/docs/guide/integrations.rst
Antonin RAFFIN 8b3723c6d8
Update ruff and documentation for hf sb3 (#1866)
* Update ruff

* Only load weights with `torch.load()` to avoid security issues

* Update doc about HF integration and remote code execution

* Fix doc build

* Revert weight_only=True for policies
2024-03-11 13:53:06 +01:00


.. _integrations:
============
Integrations
============
Weights & Biases
================
Weights & Biases provides a callback for experiment tracking that lets you visualize and share results.
The full documentation is available here: https://docs.wandb.ai/guides/integrations/other/stable-baselines-3
.. code-block:: python

  import gymnasium as gym
  import wandb
  from wandb.integration.sb3 import WandbCallback

  from stable_baselines3 import PPO

  config = {
      "policy_type": "MlpPolicy",
      "total_timesteps": 25000,
      "env_id": "CartPole-v1",
  }
  run = wandb.init(
      project="sb3",
      config=config,
      sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
      # monitor_gym=True,  # auto-upload the videos of agents playing the game
      # save_code=True,  # optional
  )

  model = PPO(config["policy_type"], config["env_id"], verbose=1, tensorboard_log=f"runs/{run.id}")
  model.learn(
      total_timesteps=config["total_timesteps"],
      callback=WandbCallback(
          model_save_path=f"models/{run.id}",
          verbose=2,
      ),
  )
  run.finish()

Hugging Face 🤗
===============
The Hugging Face Hub 🤗 is a central place where anyone can share and explore models. It allows you to host your saved models 💾.
You can see the list of stable-baselines3 saved models here: https://huggingface.co/models?library=stable-baselines3
Most of them are available via the RL Zoo.
Official pre-trained models are saved in the SB3 organization on the hub: https://huggingface.co/sb3
We wrote a tutorial on how to use 🤗 Hub and Stable-Baselines3
`here <https://colab.research.google.com/github/huggingface/huggingface_sb3/blob/main/notebooks/sb3_huggingface.ipynb>`_.
Installation
-------------
.. code-block:: bash

  pip install huggingface_sb3

.. note::

  If you use the `RL Zoo <https://github.com/DLR-RM/rl-baselines3-zoo>`_, pushing models to and loading models from the hub is already integrated:

  .. code-block:: bash

    # Download model and save it into the logs/ folder
    # Only use TRUST_REMOTE_CODE=True with HF models that can be trusted (here the SB3 organization)
    TRUST_REMOTE_CODE=True python -m rl_zoo3.load_from_hub --algo a2c --env LunarLander-v2 -orga sb3 -f logs/
    # Test the agent
    python -m rl_zoo3.enjoy --algo a2c --env LunarLander-v2 -f logs/
    # Push model, config and hyperparameters to the hub
    python -m rl_zoo3.push_to_hub --algo a2c --env LunarLander-v2 -f logs/ -orga sb3 -m "Initial commit"

Download a model from the Hub
-----------------------------
You need to copy the repo-id that contains your saved model.
For instance ``sb3/demo-hf-CartPole-v1``:
.. code-block:: python

  import os

  import gymnasium as gym

  from huggingface_sb3 import load_from_hub
  from stable_baselines3 import PPO
  from stable_baselines3.common.evaluation import evaluate_policy

  # Allow the use of `pickle.load()` when downloading model from the hub
  # Please make sure that the organization from which you download can be trusted
  os.environ["TRUST_REMOTE_CODE"] = "True"

  # Retrieve the model from the hub
  ## repo_id = id of the model repository from the Hugging Face Hub (repo_id = {organization}/{repo_name})
  ## filename = name of the model zip file from the repository
  checkpoint = load_from_hub(
      repo_id="sb3/demo-hf-CartPole-v1",
      filename="ppo-CartPole-v1.zip",
  )
  model = PPO.load(checkpoint)

  # Evaluate the agent and watch it
  eval_env = gym.make("CartPole-v1")
  mean_reward, std_reward = evaluate_policy(
      model, eval_env, render=True, n_eval_episodes=5, deterministic=True, warn=False
  )
  print(f"mean_reward={mean_reward:.2f} +/- {std_reward}")

You need to define two parameters:

- ``repo_id``: the id of the Hugging Face repo you want to download from, in the form ``{organization}/{repo_name}``.
- ``filename``: the file you want to download.
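
Since ``TRUST_REMOTE_CODE`` allows ``pickle.load()`` to run on downloaded files, you may prefer to set it only for organizations you have explicitly decided to trust. The following is a minimal sketch of that idea using only the standard library. The helper ``allow_pickle_if_trusted`` and the ``TRUSTED_ORGANIZATIONS`` allowlist are hypothetical names introduced for illustration, not part of ``huggingface_sb3``, and the sketch assumes the environment variable is read at the time ``load_from_hub()`` is called.

.. code-block:: python

  import os

  # Hypothetical allowlist: organizations you have decided to trust
  TRUSTED_ORGANIZATIONS = {"sb3"}


  def allow_pickle_if_trusted(repo_id: str) -> bool:
      """Set TRUST_REMOTE_CODE only if repo_id belongs to a trusted organization."""
      # repo_id is expected to look like "{organization}/{repo_name}"
      organization, _, repo_name = repo_id.partition("/")
      trusted = bool(repo_name) and organization in TRUSTED_ORGANIZATIONS
      if trusted:
          os.environ["TRUST_REMOTE_CODE"] = "True"
      else:
          # Remove the flag so that an earlier trusted call does not carry over
          os.environ.pop("TRUST_REMOTE_CODE", None)
      return trusted

You would call ``allow_pickle_if_trusted("sb3/demo-hf-CartPole-v1")`` just before ``load_from_hub()`` and skip the download when it returns ``False``.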
Upload a model to the Hub
-------------------------
You can easily upload your models using two different functions:
1. ``package_to_hub()``: save the model, evaluate it, generate a model card and record a replay video of your agent before pushing the complete repo to the Hub.
2. ``push_to_hub()``: simply push a file to the Hub.
First, you need to be logged in to Hugging Face to upload a model:
- If you're using Colab/Jupyter Notebooks:

.. code-block:: python

  from huggingface_hub import notebook_login

  notebook_login()

- Otherwise:

.. code-block:: bash

  huggingface-cli login

In the following examples, we train a PPO agent to play CartPole-v1 and push it to a new repo, ``sb3/demo-hf-CartPole-v1``.
With ``package_to_hub()``
^^^^^^^^^^^^^^^^^^^^^^^^^
.. code-block:: python

  from stable_baselines3 import PPO
  from stable_baselines3.common.env_util import make_vec_env

  from huggingface_sb3 import package_to_hub

  # Create the environment
  env_id = "CartPole-v1"
  env = make_vec_env(env_id, n_envs=1)

  # Create the evaluation environment
  eval_env = make_vec_env(env_id, n_envs=1)

  # Instantiate the agent
  model = PPO("MlpPolicy", env, verbose=1)

  # Train the agent
  model.learn(total_timesteps=int(5000))

  # This function saves the model, evaluates it, generates a model card and
  # records a replay video of your agent before pushing the repo to the hub
  package_to_hub(model=model,
                 model_name="ppo-CartPole-v1",
                 model_architecture="PPO",
                 env_id=env_id,
                 eval_env=eval_env,
                 repo_id="sb3/demo-hf-CartPole-v1",
                 commit_message="Test commit")

You need to define seven parameters:

- ``model``: your trained model.
- ``model_name``: the name under which the model is saved.
- ``model_architecture``: name of the architecture of your model (DQN, PPO, A2C, SAC…).
- ``env_id``: name of the environment.
- ``eval_env``: environment used to evaluate the agent.
- ``repo_id``: the id of the Hugging Face repo you want to create or update, in the form ``<your huggingface username>/<the repo name>``.
- ``commit_message``: the message of the commit.
With ``push_to_hub()``
^^^^^^^^^^^^^^^^^^^^^^
.. code-block:: python

  from stable_baselines3 import PPO
  from stable_baselines3.common.env_util import make_vec_env

  from huggingface_sb3 import push_to_hub

  # Create the environment
  env_id = "CartPole-v1"
  env = make_vec_env(env_id, n_envs=1)

  # Instantiate the agent
  model = PPO("MlpPolicy", env, verbose=1)

  # Train the agent
  model.learn(total_timesteps=int(5000))

  # Save the model
  model.save("ppo-CartPole-v1")

  # Push this saved model .zip file to the hf repo
  # If this repo does not exist, it will be created
  ## repo_id = id of the model repository from the Hugging Face Hub (repo_id = {organization}/{repo_name})
  ## filename: the name of the file == "name" inside model.save("ppo-CartPole-v1")
  push_to_hub(
      repo_id="sb3/demo-hf-CartPole-v1",
      filename="ppo-CartPole-v1.zip",
      commit_message="Added CartPole-v1 model trained with PPO",
  )

You need to define three parameters:

- ``repo_id``: the id of the Hugging Face repo you want to create or update, in the form ``<your huggingface username>/<the repo name>``.
- ``filename``: the file you want to push to the Hub.
- ``commit_message``: the message of the commit.
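
The relationship between the name passed to ``model.save()`` and the ``filename`` pushed to the Hub is easy to get wrong, so here is a small standard-library sketch that derives both ``push_to_hub()`` arguments from their parts. The helper ``hub_upload_arguments`` is a hypothetical name introduced for illustration. It assumes, as the example above suggests, that ``model.save()`` appends ``.zip`` when the given name has no file suffix.

.. code-block:: python

  from pathlib import Path
  from typing import Tuple


  def hub_upload_arguments(username: str, repo_name: str, save_name: str) -> Tuple[str, str]:
      """Return (repo_id, filename) for a model saved with model.save(save_name)."""
      if not username or not repo_name or "/" in username or "/" in repo_name:
          raise ValueError("username and repo_name must be non-empty and must not contain '/'")
      # repo_id has the form <your huggingface username>/<the repo name>
      repo_id = f"{username}/{repo_name}"
      # Assumption: ".zip" is appended only when the saved name has no suffix
      filename = save_name if Path(save_name).suffix else f"{save_name}.zip"
      return repo_id, filename

For the example above, ``hub_upload_arguments("sb3", "demo-hf-CartPole-v1", "ppo-CartPole-v1")`` yields the ``repo_id`` and ``filename`` values passed to ``push_to_hub()``.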
MLflow
======
If you want to use `MLflow <https://github.com/mlflow/mlflow>`_ to track your SB3 experiments,
you can adapt the following code, which defines a custom logger output:
.. code-block:: python

  import sys
  from typing import Any, Dict, Tuple, Union

  import mlflow
  import numpy as np

  from stable_baselines3 import SAC
  from stable_baselines3.common.logger import HumanOutputFormat, KVWriter, Logger


  class MLflowOutputFormat(KVWriter):
      """
      Dumps key/value pairs into MLflow's numeric format.
      """

      def write(
          self,
          key_values: Dict[str, Any],
          key_excluded: Dict[str, Union[str, Tuple[str, ...]]],
          step: int = 0,
      ) -> None:

          for (key, value), (_, excluded) in zip(
              sorted(key_values.items()), sorted(key_excluded.items())
          ):

              if excluded is not None and "mlflow" in excluded:
                  continue

              if isinstance(value, np.ScalarType):
                  if not isinstance(value, str):
                      mlflow.log_metric(key, value, step)


  loggers = Logger(
      folder=None,
      output_formats=[HumanOutputFormat(sys.stdout), MLflowOutputFormat()],
  )

  with mlflow.start_run():
      model = SAC("MlpPolicy", "Pendulum-v1", verbose=2)
      # Set custom logger
      model.set_logger(loggers)
      model.learn(total_timesteps=10000, log_interval=1)

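
To see which entries the ``write()`` method above would forward to MLflow, the selection logic can be tried in isolation, without MLflow or SB3 installed. This is a rough sketch: ``select_metrics`` is a hypothetical helper, and it checks for ``int`` and ``float`` instead of ``np.ScalarType`` to stay dependency-free, so it is narrower than the original check.

.. code-block:: python

  from typing import Any, Dict, List, Optional, Tuple, Union

  Excluded = Optional[Union[str, Tuple[str, ...]]]


  def select_metrics(
      key_values: Dict[str, Any],
      key_excluded: Dict[str, Excluded],
      output_name: str = "mlflow",
  ) -> List[Tuple[str, float]]:
      """Return the (key, value) pairs that would be logged for output_name."""
      selected = []
      # Both dicts share the same keys, so sorting keeps the pairs aligned
      for (key, value), (_, excluded) in zip(
          sorted(key_values.items()), sorted(key_excluded.items())
      ):
          # Skip entries explicitly excluded from this output
          if excluded is not None and output_name in excluded:
              continue
          # Keep numeric values only (strings and other objects are dropped)
          if isinstance(value, (int, float)):
              selected.append((key, float(value)))
      return selected


  key_values = {
      "train/loss": 0.5,
      "rollout/ep_rew_mean": 21.0,
      "time/note": "warmup",
      "train/lr": 0.0003,
  }
  key_excluded = {
      "train/loss": None,
      "rollout/ep_rew_mean": ("mlflow",),
      "time/note": None,
      "train/lr": ("stdout",),
  }
  metrics = select_metrics(key_values, key_excluded)

Here ``rollout/ep_rew_mean`` is dropped because it is excluded from the ``mlflow`` output, and ``time/note`` is dropped because its value is a string, which leaves ``train/loss`` and ``train/lr``.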