Add Hugging Face integration to SB3 doc (#733)

* Add Hugging Face to SB3 doc

* Update doc + fixes

* Use SB3 model from the hub

* Bump version

* Fixes

Co-authored-by: simoninithomas <simonini_thomas@outlook.fr>
Antonin RAFFIN 2022-01-20 10:04:12 +01:00 committed by GitHub
parent fc41600225
commit 54bcfa4544
3 changed files with 124 additions and 5 deletions


@@ -43,7 +43,97 @@ The full documentation is available here: https://docs.wandb.ai/guides/integrati
run.finish()
Hugging Face
============
Hugging Face 🤗
===============
The Hugging Face Hub 🤗 is a central place where anyone can share and explore models. It allows you to host your saved models 💾.
To be added.
You can see the list of stable-baselines3 saved models here: https://huggingface.co/models?other=stable-baselines3
Official pre-trained models are saved in the SB3 organization on the hub: https://huggingface.co/sb3
We wrote a tutorial on how to use 🤗 Hub and Stable-Baselines3 here: https://colab.research.google.com/drive/1GI0WpThwRHbl-Fu2RHfczq6dci5GBDVE#scrollTo=q4cz-w9MdO7T
Installation
-------------
.. code-block:: bash

    pip install huggingface_hub
    pip install huggingface_sb3

Download a model from the Hub
-----------------------------
You need to copy the repo-id that contains your saved model.
For instance ``sb3/demo-hf-CartPole-v1``:

.. code-block:: python

    import gym

    from huggingface_sb3 import load_from_hub
    from stable_baselines3 import PPO
    from stable_baselines3.common.evaluation import evaluate_policy

    # Retrieve the model from the hub
    ## repo_id = id of the model repository from the Hugging Face Hub (repo_id = {organization}/{repo_name})
    ## filename = name of the model zip file from the repository
    checkpoint = load_from_hub(
        repo_id="sb3/demo-hf-CartPole-v1",
        filename="ppo-CartPole-v1",
    )
    model = PPO.load(checkpoint)

    # Evaluate the agent and watch it
    eval_env = gym.make("CartPole-v1")
    mean_reward, std_reward = evaluate_policy(
        model, eval_env, render=True, n_eval_episodes=5, deterministic=True, warn=False
    )
    print(f"mean_reward={mean_reward:.2f} +/- {std_reward}")

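The ``repo_id`` format mentioned in the comments above (``{organization}/{repo_name}``) can be validated locally before any network call. The following is only a minimal sketch: ``split_repo_id`` and ``model_page_url`` are hypothetical helper names that are not part of ``huggingface_sb3``, and the sketch assumes the two-part form described above and that a model page lives at ``https://huggingface.co/{repo_id}``.

.. code-block:: python

    from typing import Tuple

    # Assumption: model pages are served under this base URL.
    HUB_URL = "https://huggingface.co"


    def split_repo_id(repo_id: str) -> Tuple[str, str]:
        """Split '{organization}/{repo_name}' into its two parts (hypothetical helper)."""
        organization, sep, repo_name = repo_id.partition("/")
        # Reject a missing slash, an empty part, or more than one slash.
        if not sep or not organization or not repo_name or "/" in repo_name:
            raise ValueError(f"Expected '{{organization}}/{{repo_name}}', got {repo_id!r}")
        return organization, repo_name


    def model_page_url(repo_id: str) -> str:
        """Build the web page URL for a repo id (hypothetical helper)."""
        organization, repo_name = split_repo_id(repo_id)
        return f"{HUB_URL}/{organization}/{repo_name}"


    print(model_page_url("sb3/demo-hf-CartPole-v1"))

Failing early on a malformed id gives a clearer error message than a failed download would.
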
Upload a model to the Hub
-------------------------
First, you need to be logged in to Hugging Face to upload a model:
- If you're using Colab/Jupyter Notebooks:

.. code-block:: python

    from huggingface_hub import notebook_login

    notebook_login()

- Otherwise:

.. code-block:: bash

    huggingface-cli login

Then, in this example, we train a PPO agent to play CartPole-v1 and push it to a new repo ``sb3/demo-hf-CartPole-v1``:

.. code-block:: python

    from huggingface_sb3 import push_to_hub
    from stable_baselines3 import PPO

    # Define a PPO model with MLP policy network
    model = PPO("MlpPolicy", "CartPole-v1", verbose=1)

    # Train it for 10000 timesteps
    model.learn(total_timesteps=10_000)

    # Save the model
    model.save("ppo-CartPole-v1")

    # Push this saved model to the hf repo
    # If this repo does not exist, it will be created
    ## repo_id = id of the model repository from the Hugging Face Hub (repo_id = {organization}/{repo_name})
    ## filename: the name of the file == "name" inside model.save("ppo-CartPole-v1")
    push_to_hub(
        repo_id="sb3/demo-hf-CartPole-v1",
        filename="ppo-CartPole-v1",
        commit_message="Added Cartpole-v1 model trained with PPO",
    )
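Before pushing, it can help to confirm that the saved archive actually exists on disk. To our understanding, ``model.save("ppo-CartPole-v1")`` writes a zip archive and appends ``.zip`` when the given name has no extension. The sketch below mirrors that rule under this assumption. ``expected_archive`` and ``require_archive`` are hypothetical helper names, not part of Stable-Baselines3 or ``huggingface_sb3``.

.. code-block:: python

    from pathlib import Path


    def expected_archive(save_name: str) -> Path:
        """Return the archive path we assume model.save(save_name) writes."""
        path = Path(save_name)
        # Assumption: ".zip" is appended only when the name has no extension.
        if path.suffix == "":
            path = Path(f"{path}.zip")
        return path


    def require_archive(save_name: str) -> Path:
        """Raise early if the expected archive is missing (hypothetical helper)."""
        archive = expected_archive(save_name)
        if not archive.is_file():
            raise FileNotFoundError(f"No saved model found at {archive}")
        return archive


    print(expected_archive("ppo-CartPole-v1"))

Calling ``require_archive("ppo-CartPole-v1")`` right after ``model.save(...)`` would catch a mistyped name before any upload is attempted.
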


@@ -4,6 +4,35 @@ Changelog
==========
Release 1.4.1a0 (WIP)
---------------------------
Breaking Changes:
^^^^^^^^^^^^^^^^^
New Features:
^^^^^^^^^^^^^
SB3-Contrib
^^^^^^^^^^^
Bug Fixes:
^^^^^^^^^^
- Fixed a bug in ``VecMonitor``. The monitor did not consider the ``info_keywords`` during stepping (@ScheiklP)
Deprecations:
^^^^^^^^^^^^^
Others:
^^^^^^^
Documentation:
^^^^^^^^^^^^^^
- Added doc on Hugging Face integration (@simoninithomas)
Release 1.4.0 (2022-01-18)
---------------------------
@@ -48,7 +77,6 @@ Bug Fixes:
- Fixed a bug where the observation would be incorrectly detected as non-vectorized instead of throwing an error
- The env checker now properly checks and warns about potential issues for continuous action spaces when the boundaries are too small or when the dtype is not float32
- Fixed a bug in ``VecFrameStack`` with channel first image envs, where the terminal observation would be wrongly created.
- Fixed a bug in ``VecMonitor``. The monitor did not consider the ``info_keywords`` during stepping (@ScheiklP)
Deprecations:
^^^^^^^^^^^^^
@@ -881,3 +909,4 @@ And all the contributors:
@benblack769 @bstee615 @c-rizz @skandermoalla @MihaiAnca13 @davidblom603 @ayeright @cyprienc
@wkirgsn @AechPro @CUN-bjy @batu @IljaAvadiev @timokau @kachayev @cleversonahum
@eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP
@simoninithomas


@@ -1 +1 @@
1.4.0
1.4.1a0