diff --git a/docs/guide/integrations.rst b/docs/guide/integrations.rst index 98cf84a..81d19ca 100644 --- a/docs/guide/integrations.rst +++ b/docs/guide/integrations.rst @@ -47,23 +47,35 @@ Hugging Face 🤗 =============== The Hugging Face Hub 🤗 is a central place where anyone can share and explore models. It allows you to host your saved models 💾. -You can see the list of stable-baselines3 saved models here: https://huggingface.co/models?other=stable-baselines3 +You can see the list of stable-baselines3 saved models here: https://huggingface.co/models?library=stable-baselines3 Most of them are available via the RL Zoo. Official pre-trained models are saved in the SB3 organization on the hub: https://huggingface.co/sb3 -We wrote a tutorial on how to use 🤗 Hub and Stable-Baselines3 here: https://colab.research.google.com/drive/1GI0WpThwRHbl-Fu2RHfczq6dci5GBDVE#scrollTo=q4cz-w9MdO7T +We wrote a tutorial on how to use 🤗 Hub and Stable-Baselines3 +`here `_. -For up to date instructions (for instance for using ``package_to_hub()``), please take a look at the Huggingface SB3 package README: https://github.com/huggingface/huggingface_sb3 Installation ------------- .. code-block:: bash - pip install huggingface_hub pip install huggingface_sb3 + .. note:: + + If you use the `RL Zoo `_, pushing/loading models from the hub is integrated in the RL Zoo: + + .. code-block:: bash + # Download model and save it into the logs/ folder + python -m rl_zoo3.load_from_hub --algo a2c --env LunarLander-v2 -orga sb3 -f logs/ + # Test the agent + python -m rl_zoo3.enjoy --algo a2c --env LunarLander-v2 -f logs/ + # push model, config and hyperparameters to the hub + python -m rl_zoo3.push_to_hub --algo a2c --env LunarLander-v2 -f logs/ -orga sb3 -m "Initial commit" + + Download a model from the Hub ----------------------------- @@ -83,7 +95,7 @@ For instance ``sb3/demo-hf-CartPole-v1``: ## filename = name of the model zip file from the repository checkpoint = load_from_hub( repo_id="sb3/demo-hf-CartPole-v1", - filename="ppo-CartPole-v1", + filename="ppo-CartPole-v1.zip", ) model = PPO.load(checkpoint) @@ -94,11 +106,22 @@ For instance ``sb3/demo-hf-CartPole-v1``: ) print(f"mean_reward={mean_reward:.2f} +/- {std_reward}") +You need to define two parameters: + +- ``repo-id``: the name of the Hugging Face repo you want to download. +- ``filename``: the file you want to download. Upload a model to the Hub ------------------------- +You can easily upload your models using two different functions: + +1. ``package_to_hub()``: save the model, evaluate it, generate a model card and record a replay video of your agent before pushing the complete repo to the Hub. + +2. ``push_to_hub()``: simply push a file to the Hub. + + First, you need to be logged in to Hugging Face to upload a model: - If you're using Colab/Jupyter Notebooks: @@ -109,38 +132,97 @@ First, you need to be logged in to Hugging Face to upload a model: notebook_login() -- Otheriwse: +- Otherwise: .. code-block:: bash huggingface-cli login + Then, in this example, we train a PPO agent to play CartPole-v1 and push it to a new repo ``sb3/demo-hf-CartPole-v1`` +With ``package_to_hub()`` +^^^^^^^^^^^^^^^^^^^^^^^^^ + .. code-block:: python - from huggingface_sb3 import push_to_hub from stable_baselines3 import PPO + from stable_baselines3.common.env_util import make_vec_env - # Define a PPO model with MLP policy network - model = PPO("MlpPolicy", "CartPole-v1", verbose=1) + from huggingface_sb3 import package_to_hub - # Train it for 10000 timesteps - model.learn(total_timesteps=10_000) + # Create the environment + env_id = "CartPole-v1" + env = make_vec_env(env_id, n_envs=1) + + # Create the evaluation environment + eval_env = make_vec_env(env_id, n_envs=1) + + # Instantiate the agent + model = PPO("MlpPolicy", env, verbose=1) + + # Train the agent + model.learn(total_timesteps=int(5000)) + + # This method save, evaluate, generate a model card and record a replay video of your agent before pushing the repo to the hub + package_to_hub(model=model, + model_name="ppo-CartPole-v1", + model_architecture="PPO", + env_id=env_id, + eval_env=eval_env, + repo_id="sb3/demo-hf-CartPole-v1", + commit_message="Test commit") + +You need to define seven parameters: + +- ``model``: your trained model. +- ``model_architecture``: name of the architecture of your model (DQN, PPO, A2C, SAC…). +- ``env_id``: name of the environment. +- ``eval_env``: environment used to evaluate the agent. +- ``repo-id``: the name of the Hugging Face repo you want to create or update. It’s /. +- ``commit-message``. +- ``filename``: the file you want to push to the Hub. + +With ``push_to_hub()`` +^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: python + + + from stable_baselines3 import PPO + from stable_baselines3.common.env_util import make_vec_env + + from huggingface_sb3 import push_to_hub + + # Create the environment + env_id = "CartPole-v1" + env = make_vec_env(env_id, n_envs=1) + + # Instantiate the agent + model = PPO("MlpPolicy", env, verbose=1) + + # Train the agent + model.learn(total_timesteps=int(5000)) # Save the model model.save("ppo-CartPole-v1") - # Push this saved model to the hf repo + # Push this saved model .zip file to the hf repo # If this repo does not exists it will be created ## repo_id = id of the model repository from the Hugging Face Hub (repo_id = {organization}/{repo_name}) ## filename: the name of the file == "name" inside model.save("ppo-CartPole-v1") push_to_hub( - repo_id="sb3/demo-hf-CartPole-v1", - filename="ppo-CartPole-v1", - commit_message="Added Cartpole-v1 model trained with PPO", + repo_id="sb3/demo-hf-CartPole-v1", + filename="ppo-CartPole-v1.zip", + commit_message="Added CartPole-v1 model trained with PPO", ) +You need to define three parameters: + +- ``repo-id``: the name of the Hugging Face repo you want to create or update. It’s /. +- ``filename``: the file you want to push to the Hub. +- ``commit-message``. + MLFLow ====== diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index e11c5c9..42ca89d 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -35,7 +35,7 @@ Others: Documentation: ^^^^^^^^^^^^^^ - +- Updated Hugging Face Integration page (@simoninithomas) Release 1.6.2 (2022-10-10)