.. _integrations:

============
Integrations
============

Weights & Biases
================

Weights & Biases provides a callback for experiment tracking that allows you to visualize and share results.

The full documentation is available here: https://docs.wandb.ai/guides/integrations/other/stable-baselines-3

.. code-block:: python

  import gymnasium as gym
  import wandb
  from wandb.integration.sb3 import WandbCallback

  from stable_baselines3 import PPO

  config = {
      "policy_type": "MlpPolicy",
      "total_timesteps": 25000,
      "env_id": "CartPole-v1",
  }
  run = wandb.init(
      project="sb3",
      config=config,
      sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
      # monitor_gym=True,  # auto-upload the videos of agents playing the game
      # save_code=True,  # optional
  )

  model = PPO(config["policy_type"], config["env_id"], verbose=1, tensorboard_log=f"runs/{run.id}")
  model.learn(
      total_timesteps=config["total_timesteps"],
      callback=WandbCallback(
          model_save_path=f"models/{run.id}",
          verbose=2,
      ),
  )
  run.finish()


Hugging Face 🤗
===============

The Hugging Face Hub 🤗 is a central place where anyone can share and explore models. It allows you to host your saved models 💾.

You can see the list of stable-baselines3 saved models here: https://huggingface.co/models?library=stable-baselines3
Most of them are available via the RL Zoo.

Official pre-trained models are saved in the SB3 organization on the hub: https://huggingface.co/sb3

We wrote a tutorial on how to use 🤗 Hub and Stable-Baselines3
`here <https://colab.research.google.com/github/huggingface/huggingface_sb3/blob/main/notebooks/sb3_huggingface.ipynb>`_.


Installation
-------------

.. code-block:: bash

  pip install huggingface_sb3


.. note::

  If you use the `RL Zoo <https://github.com/DLR-RM/rl-baselines3-zoo>`_, pushing and loading models from the Hub is already integrated:

  .. code-block:: bash

    # Download model and save it into the logs/ folder
    python -m rl_zoo3.load_from_hub --algo a2c --env LunarLander-v2 -orga sb3 -f logs/
    # Test the agent
    python -m rl_zoo3.enjoy --algo a2c --env LunarLander-v2 -f logs/
    # Push model, config and hyperparameters to the hub
    python -m rl_zoo3.push_to_hub --algo a2c --env LunarLander-v2 -f logs/ -orga sb3 -m "Initial commit"


Download a model from the Hub
-----------------------------

You need the id of the repo that contains your saved model,
for instance ``sb3/demo-hf-CartPole-v1``:

.. code-block:: python

  import gymnasium as gym

  from huggingface_sb3 import load_from_hub
  from stable_baselines3 import PPO
  from stable_baselines3.common.evaluation import evaluate_policy

  # Retrieve the model from the hub
  ## repo_id = id of the model repository from the Hugging Face Hub (repo_id = {organization}/{repo_name})
  ## filename = name of the model zip file from the repository
  checkpoint = load_from_hub(
      repo_id="sb3/demo-hf-CartPole-v1",
      filename="ppo-CartPole-v1.zip",
  )
  model = PPO.load(checkpoint)

  # Evaluate the agent and watch it
  # With Gymnasium, the render mode must be set at creation for render=True to display the episodes
  eval_env = gym.make("CartPole-v1", render_mode="human")
  mean_reward, std_reward = evaluate_policy(
      model, eval_env, render=True, n_eval_episodes=5, deterministic=True, warn=False
  )
  print(f"mean_reward={mean_reward:.2f} +/- {std_reward}")

You need to define two parameters:

- ``repo_id``: the name of the Hugging Face repo you want to download from.
- ``filename``: the file you want to download.

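
The ``repo_id`` used in this example has the form ``{organization}/{repo_name}``. As a minimal sketch, the hypothetical helper ``parse_repo_id()`` below (it is not part of ``huggingface_sb3``) splits a repo id into its two parts and rejects strings that do not follow that form, which can catch typos before calling ``load_from_hub()``. It is deliberately strict and may be stricter than the Hub itself:

.. code-block:: python

  from typing import Tuple


  def parse_repo_id(repo_id: str) -> Tuple[str, str]:
      """Split "{organization}/{repo_name}" into its two parts.

      Hypothetical helper for illustration, not part of huggingface_sb3.
      """
      parts = repo_id.split("/")
      # Expect exactly one "/" separating two non-empty components
      if len(parts) != 2 or not all(parts):
          raise ValueError(f"Expected '<organization>/<repo_name>', got {repo_id!r}")
      return parts[0], parts[1]


  organization, repo_name = parse_repo_id("sb3/demo-hf-CartPole-v1")
  print(organization, repo_name)  # sb3 demo-hf-CartPole-v1
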
Upload a model to the Hub
-------------------------

You can easily upload your models using two different functions:

1. ``package_to_hub()``: save the model, evaluate it, generate a model card and record a replay video of your agent before pushing the complete repo to the Hub.

2. ``push_to_hub()``: simply push a file to the Hub.


First, you need to be logged in to Hugging Face to upload a model:

- If you're using Colab/Jupyter Notebooks:

.. code-block:: python

  from huggingface_hub import notebook_login
  notebook_login()


- Otherwise:

.. code-block:: bash

  huggingface-cli login


Then, in this example, we train a PPO agent to play CartPole-v1 and push it to a new repo ``sb3/demo-hf-CartPole-v1``.

With ``package_to_hub()``
^^^^^^^^^^^^^^^^^^^^^^^^^

.. code-block:: python

  from stable_baselines3 import PPO
  from stable_baselines3.common.env_util import make_vec_env

  from huggingface_sb3 import package_to_hub

  # Create the environment
  env_id = "CartPole-v1"
  env = make_vec_env(env_id, n_envs=1)

  # Create the evaluation environment
  eval_env = make_vec_env(env_id, n_envs=1)

  # Instantiate the agent
  model = PPO("MlpPolicy", env, verbose=1)

  # Train the agent
  model.learn(total_timesteps=5000)

  # This method saves and evaluates the model, generates a model card
  # and records a replay video of your agent before pushing the repo to the Hub
  package_to_hub(
      model=model,
      model_name="ppo-CartPole-v1",
      model_architecture="PPO",
      env_id=env_id,
      eval_env=eval_env,
      repo_id="sb3/demo-hf-CartPole-v1",
      commit_message="Test commit",
  )

You need to define seven parameters:

- ``model``: your trained model.
- ``model_name``: the name under which the trained model is saved (here ``ppo-CartPole-v1``).
- ``model_architecture``: name of the architecture of your model (DQN, PPO, A2C, SAC…).
- ``env_id``: name of the environment.
- ``eval_env``: environment used to evaluate the agent.
- ``repo_id``: the name of the Hugging Face repo you want to create or update. It’s <your huggingface username>/<the repo name>.
- ``commit_message``: the message of the commit pushed to the repo.

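
Several of these parameters are derived from the same two pieces of information: the algorithm and the environment. A minimal sketch for keeping them consistent is the hypothetical helper ``hub_naming()`` below. The ``<algo>-<env_id>`` pattern is an assumption based on the model name used in this example (the example itself uses a custom repo name); it is not a rule enforced by ``huggingface_sb3``:

.. code-block:: python

  from typing import Dict


  def hub_naming(username: str, algo: str, env_id: str) -> Dict[str, str]:
      """Derive consistent naming arguments for package_to_hub().

      Hypothetical convention, not enforced by huggingface_sb3:
      model_name = "<algo>-<env_id>", repo_id = "<username>/<model_name>".
      """
      model_name = f"{algo.lower()}-{env_id}"
      return {
          "model_name": model_name,
          "model_architecture": algo.upper(),
          "repo_id": f"{username}/{model_name}",
      }


  names = hub_naming("sb3", "PPO", "CartPole-v1")
  print(names["model_name"], names["repo_id"])  # ppo-CartPole-v1 sb3/ppo-CartPole-v1

Since the keys match the keyword arguments of ``package_to_hub()``, the dictionary could then be unpacked, for example ``package_to_hub(model=model, env_id=env_id, eval_env=eval_env, commit_message="Test commit", **names)``.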
With ``push_to_hub()``
^^^^^^^^^^^^^^^^^^^^^^

.. code-block:: python

  from stable_baselines3 import PPO
  from stable_baselines3.common.env_util import make_vec_env

  from huggingface_sb3 import push_to_hub

  # Create the environment
  env_id = "CartPole-v1"
  env = make_vec_env(env_id, n_envs=1)

  # Instantiate the agent
  model = PPO("MlpPolicy", env, verbose=1)

  # Train the agent
  model.learn(total_timesteps=5000)

  # Save the model
  model.save("ppo-CartPole-v1")

  # Push this saved model .zip file to the hf repo
  # If this repo does not exist, it will be created
  ## repo_id = id of the model repository from the Hugging Face Hub (repo_id = {organization}/{repo_name})
  ## filename = the file written by model.save("ppo-CartPole-v1"), i.e. with the ".zip" extension
  push_to_hub(
      repo_id="sb3/demo-hf-CartPole-v1",
      filename="ppo-CartPole-v1.zip",
      commit_message="Added CartPole-v1 model trained with PPO",
  )

You need to define three parameters:

- ``repo_id``: the name of the Hugging Face repo you want to create or update. It’s <your huggingface username>/<the repo name>.
- ``filename``: the file you want to push to the Hub.
- ``commit_message``: the message of the commit pushed to the repo.

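
The ``filename`` passed to ``push_to_hub()`` has to match the file that ``model.save()`` actually wrote. To our understanding, SB3 appends ``.zip`` to the save path only when that path has no extension, which is why ``model.save("ppo-CartPole-v1")`` produces ``ppo-CartPole-v1.zip``. The hypothetical helper ``saved_zip_name()`` below mimics that assumed rule with ``pathlib``; it is a sketch, not SB3's own code:

.. code-block:: python

  from pathlib import Path


  def saved_zip_name(save_path: str) -> str:
      """Return the file name that model.save(save_path) is assumed to write.

      Assumption: ".zip" is appended only when save_path has no suffix,
      mirroring (not calling) SB3's behavior.
      """
      path = Path(save_path)
      if path.suffix == "":
          path = Path(f"{path}.zip")
      return str(path)


  print(saved_zip_name("ppo-CartPole-v1"))  # ppo-CartPole-v1.zip

Under this assumption, a name that already contains a dot, such as ``ppo.v1``, is kept unchanged, so the file to push would be ``ppo.v1`` rather than ``ppo.v1.zip``.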
MLflow
======

If you want to use `MLflow <https://github.com/mlflow/mlflow>`_ to track your SB3 experiments,
you can adapt the following code, which defines a custom logger output:

.. code-block:: python

  import sys
  from typing import Any, Dict, Tuple, Union

  import mlflow
  import numpy as np

  from stable_baselines3 import SAC
  from stable_baselines3.common.logger import HumanOutputFormat, KVWriter, Logger


  class MLflowOutputFormat(KVWriter):
      """
      Dumps key/value pairs into MLflow's numeric format.
      """

      def write(
          self,
          key_values: Dict[str, Any],
          key_excluded: Dict[str, Union[str, Tuple[str, ...]]],
          step: int = 0,
      ) -> None:

          for (key, value), (_, excluded) in zip(
              sorted(key_values.items()), sorted(key_excluded.items())
          ):

              if excluded is not None and "mlflow" in excluded:
                  continue

              if isinstance(value, np.ScalarType):
                  if not isinstance(value, str):
                      mlflow.log_metric(key, value, step)


  loggers = Logger(
      folder=None,
      output_formats=[HumanOutputFormat(sys.stdout), MLflowOutputFormat()],
  )

  with mlflow.start_run():
      model = SAC("MlpPolicy", "Pendulum-v1", verbose=2)
      # Set custom logger
      model.set_logger(loggers)
      model.learn(total_timesteps=10000, log_interval=1)
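
To check the filtering logic of ``write()`` without an MLflow server or SB3 installed, here is a minimal sketch of a hypothetical stand-in writer that records metrics in a list instead of calling ``mlflow.log_metric()``. It approximates ``np.ScalarType`` with ``(int, float)`` and does not subclass ``KVWriter``, so it only illustrates the skip rules (excluded keys and non-numeric values):

.. code-block:: python

  from typing import Any, Dict, List, Optional, Tuple, Union


  class RecordingOutputFormat:
      """Hypothetical stand-in for MLflowOutputFormat that records metrics in a list."""

      def __init__(self) -> None:
          self.logged: List[Tuple[str, float, int]] = []

      def write(
          self,
          key_values: Dict[str, Any],
          key_excluded: Dict[str, Optional[Union[str, Tuple[str, ...]]]],
          step: int = 0,
      ) -> None:
          for (key, value), (_, excluded) in zip(
              sorted(key_values.items()), sorted(key_excluded.items())
          ):
              # Skip keys recorded with exclude="mlflow"
              if excluded is not None and "mlflow" in excluded:
                  continue
              # Only numeric scalars become metrics; strings are skipped
              if isinstance(value, (int, float)):
                  self.logged.append((key, value, step))


  writer = RecordingOutputFormat()
  writer.write(
      {"train/loss": 0.5, "time/fps": 120, "rollout/note": "text"},
      {"train/loss": None, "time/fps": ("mlflow",), "rollout/note": None},
      step=10,
  )
  print(writer.logged)  # [('train/loss', 0.5, 10)]

With this logic, calling ``logger.record(key, value, exclude="mlflow")`` keeps ``key`` out of MLflow while the other output formats, such as ``HumanOutputFormat``, still receive it.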