mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-05 00:00:04 +00:00
Add Hugging Face integration to SB3 doc (#733)
* Add Hugging Face to SB3 doc
* Update doc + fixes
* Use SB3 model from the hub
* Bump version
* Fixes

Co-authored-by: simoninithomas <simonini_thomas@outlook.fr>
parent fc41600225
commit 54bcfa4544
3 changed files with 124 additions and 5 deletions
@@ -43,7 +43,97 @@ The full documentation is available here: https://docs.wandb.ai/guides/integrati
   run.finish()


-Hugging Face
-============
+Hugging Face 🤗
+===============
+The Hugging Face Hub 🤗 is a central place where anyone can share and explore models. It allows you to host your saved models 💾.

-To be added.
+You can see the list of stable-baselines3 saved models here: https://huggingface.co/models?other=stable-baselines3
+
+Official pre-trained models are saved in the SB3 organization on the hub: https://huggingface.co/sb3
+
+We wrote a tutorial on how to use 🤗 Hub and Stable-Baselines3 here: https://colab.research.google.com/drive/1GI0WpThwRHbl-Fu2RHfczq6dci5GBDVE#scrollTo=q4cz-w9MdO7T
+
+Installation
+-------------
+
+.. code-block:: bash
+
+  pip install huggingface_hub
+  pip install huggingface_sb3
+
+
+Download a model from the Hub
+-----------------------------
+You need to copy the repo-id that contains your saved model.
+For instance ``sb3/demo-hf-CartPole-v1``:
+
+.. code-block:: python
+
+  import gym
+
+  from huggingface_sb3 import load_from_hub
+  from stable_baselines3 import PPO
+  from stable_baselines3.common.evaluation import evaluate_policy
+
+  # Retrieve the model from the hub
+  ## repo_id = id of the model repository from the Hugging Face Hub (repo_id = {organization}/{repo_name})
+  ## filename = name of the model zip file from the repository
+  checkpoint = load_from_hub(
+      repo_id="sb3/demo-hf-CartPole-v1",
+      filename="ppo-CartPole-v1",
+  )
+  model = PPO.load(checkpoint)
+
+  # Evaluate the agent and watch it
+  eval_env = gym.make("CartPole-v1")
+  mean_reward, std_reward = evaluate_policy(
+      model, eval_env, render=True, n_eval_episodes=5, deterministic=True, warn=False
+  )
+  print(f"mean_reward={mean_reward:.2f} +/- {std_reward}")
+
+
+
+Upload a model to the Hub
+-------------------------
+
+First, you need to be logged in to Hugging Face to upload a model:
+
+- If you're using Colab/Jupyter Notebooks:
+
+.. code-block:: python
+
+  from huggingface_hub import notebook_login
+  notebook_login()
+
+
+- Otherwise:
+
+.. code-block:: bash
+
+  huggingface-cli login
+
+Then, in this example, we train a PPO agent to play CartPole-v1 and push it to a new repo ``sb3/demo-hf-CartPole-v1``:
+
+.. code-block:: python
+
+  from huggingface_sb3 import push_to_hub
+  from stable_baselines3 import PPO
+
+  # Define a PPO model with MLP policy network
+  model = PPO("MlpPolicy", "CartPole-v1", verbose=1)
+
+  # Train it for 10000 timesteps
+  model.learn(total_timesteps=10_000)
+
+  # Save the model
+  model.save("ppo-CartPole-v1")
+
+  # Push this saved model to the hf repo
+  # If this repo does not exist, it will be created
+  ## repo_id = id of the model repository from the Hugging Face Hub (repo_id = {organization}/{repo_name})
+  ## filename: the name of the file == "name" inside model.save("ppo-CartPole-v1")
+  push_to_hub(
+      repo_id="sb3/demo-hf-CartPole-v1",
+      filename="ppo-CartPole-v1",
+      commit_message="Added Cartpole-v1 model trained with PPO",
+  )
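
Both snippets in this diff pass a ``repo_id`` of the form ``{organization}/{repo_name}``. As a rough illustration of that convention only, the sketch below splits such an id and builds the matching model page URL. The helper names ``split_repo_id`` and ``hub_url`` are hypothetical and are not part of ``huggingface_sb3`` or ``huggingface_hub``; the URL pattern is assumed from the Hub links quoted in the doc.

```python
from typing import Tuple


def split_repo_id(repo_id: str) -> Tuple[str, str]:
    """Split '{organization}/{repo_name}' into (organization, repo_name).

    Hypothetical helper for illustration; not a huggingface_sb3 function.
    """
    organization, sep, repo_name = repo_id.partition("/")
    # Reject ids with no slash, an empty part, or more than one slash.
    if not sep or not organization or not repo_name or "/" in repo_name:
        raise ValueError(
            f"expected '{{organization}}/{{repo_name}}', got {repo_id!r}"
        )
    return organization, repo_name


def hub_url(repo_id: str) -> str:
    """Build the model page URL for a repo id (assumed URL pattern)."""
    organization, repo_name = split_repo_id(repo_id)
    return f"https://huggingface.co/{organization}/{repo_name}"


organization, repo_name = split_repo_id("sb3/demo-hf-CartPole-v1")
print(organization, repo_name)
print(hub_url("sb3/demo-hf-CartPole-v1"))
```

Validating the id locally like this is optional; the Hub itself remains the authority on which repository names are accepted.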
@@ -4,6 +4,35 @@ Changelog
 ==========


+Release 1.4.1a0 (WIP)
+---------------------------
+
+
+Breaking Changes:
+^^^^^^^^^^^^^^^^^
+
+
+New Features:
+^^^^^^^^^^^^^
+
+SB3-Contrib
+^^^^^^^^^^^
+
+Bug Fixes:
+^^^^^^^^^^
+- Fixed a bug in ``VecMonitor``. The monitor did not consider the ``info_keywords`` during stepping (@ScheiklP)
+
+Deprecations:
+^^^^^^^^^^^^^
+
+Others:
+^^^^^^^
+
+Documentation:
+^^^^^^^^^^^^^^
+- Added doc on Hugging Face integration (@simoninithomas)
+
+
 Release 1.4.0 (2022-01-18)
 ---------------------------

@@ -48,7 +77,6 @@ Bug Fixes:
 - Fixed a bug where the observation would be incorrectly detected as non-vectorized instead of throwing an error
 - The env checker now properly checks and warns about potential issues for continuous action spaces when the boundaries are too small or when the dtype is not float32
 - Fixed a bug in ``VecFrameStack`` with channel first image envs, where the terminal observation would be wrongly created.
-- Fixed a bug in ``VecMonitor``. The monitor did not consider the ``info_keywords`` during stepping (@ScheiklP)

 Deprecations:
 ^^^^^^^^^^^^^
@@ -881,3 +909,4 @@ And all the contributors:
 @benblack769 @bstee615 @c-rizz @skandermoalla @MihaiAnca13 @davidblom603 @ayeright @cyprienc
 @wkirgsn @AechPro @CUN-bjy @batu @IljaAvadiev @timokau @kachayev @cleversonahum
 @eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP
+@simoninithomas

@@ -1 +1 @@
-1.4.0
+1.4.1a0