ONNX Documentation Update (#464)

* Updated ONNX documentation

First draft on the documentation explaining how to export SB3 models in the ONNX format

* Updated changelog with ONNX documentation fix

* Address comments

* Update changelog.rst

* Update rtd env

* Fixes + add test example

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
Co-authored-by: Anssi Kanervisto <anssk@Anssis-MacBook-Air.local>
Co-authored-by: Anssi Kanervisto <kaneran21@hotmail.com>
This commit is contained in:
batu 2021-09-26 08:40:35 -07:00 committed by GitHub
parent 914bc10a0d
commit 89af49ca91
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 86 additions and 6 deletions

View file

@ -6,7 +6,7 @@ dependencies:
- cpuonly=1.0=0
- pip=20.2
- python=3.6
- pytorch=1.5.0=py3.6_cpu_0
- pytorch=1.8.1=py3.6_cpu_0
- pip:
- gym>=0.17.2
- cloudpickle
@ -15,6 +15,6 @@ dependencies:
- numpy
- matplotlib
- sphinx_autodoc_typehints
# Tmp fix, docutils==0.17 breaks rtd theme
- sphinx>=4.2
# See https://github.com/readthedocs/sphinx_rtd_theme/issues/1115
- docutils==0.16
- sphinx_rtd_theme>=1.0

View file

@ -31,8 +31,87 @@ to do inference in another framework.
Export to ONNX
-----------------
TODO: help is welcomed!
As of June 2021, ONNX format `doesn't support <https://github.com/onnx/onnx/issues/3033>`_ exporting models that use the ``broadcast_tensors`` functionality of pytorch. So in order to export the trained stable-baseline3 models in the ONNX format, we need to first remove the layers that use broadcasting. This can be done by creating a class that removes the unsupported layers.
The following examples are for ``MlpPolicy`` only, and are general examples. Note that you have to preprocess the observation the same way stable-baselines3 agent does (see ``common.preprocessing.preprocess_obs``)
For PPO, assuming a shared feature extactor.
.. warning::
The following example is for continuous actions only.
When using discrete or binary actions, you must do some `post-processing <https://github.com/DLR-RM/stable-baselines3/blob/f3a35aa786ee41ffff599b99fa1607c067e89074/stable_baselines3/common/policies.py#L621-L637>`_
to obtain the action (e.g., convert action logits to action).
.. code-block:: python
from stable_baselines3 import PPO
import torch
class OnnxablePolicy(torch.nn.Module):
def __init__(self, extractor, action_net, value_net):
super(OnnxablePolicy, self).__init__()
self.extractor = extractor
self.action_net = action_net
self.value_net = value_net
def forward(self, observation):
# NOTE: You may have to process (normalize) observation in the correct
# way before using this. See `common.preprocessing.preprocess_obs`
action_hidden, value_hidden = self.extractor(observation)
return self.action_net(action_hidden), self.value_net(value_hidden)
# Example: model = PPO("MlpPolicy", "Pendulum-v0")
model = PPO.load("PathToTrainedModel.zip")
model.policy.to("cpu")
onnxable_model = OnnxablePolicy(model.policy.mlp_extractor, model.policy.action_net, model.policy.value_net)
dummy_input = torch.randn(1, observation_size)
torch.onnx.export(onnxable_model, dummy_input, "my_ppo_model.onnx", opset_version=9)
##### Load and test with onnx
import onnx
import onnxruntime as ort
import numpy as np
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
observation = np.zeros((1, observation_size)).astype(np.float32)
ort_sess = ort.InferenceSession(onnx_path)
action, value = ort_sess.run(None, {'input.1': observation})
For SAC the procedure is similar. The example shown only exports the actor network as the actor is sufficient to roll out the trained policies.
.. code-block:: python
from stable_baselines3 import SAC
import torch
class OnnxablePolicy(torch.nn.Module):
def __init__(self, actor):
super(OnnxablePolicy, self).__init__()
# Removing the flatten layer because it can't be onnxed
self.actor = torch.nn.Sequential(actor.latent_pi, actor.mu)
def forward(self, observation):
# NOTE: You may have to process (normalize) observation in the correct
# way before using this. See `common.preprocessing.preprocess_obs`
return self.actor(observation)
model = SAC.load("PathToTrainedModel.zip")
onnxable_model = OnnxablePolicy(model.policy.actor)
dummy_input = torch.randn(1, observation_size)
onnxable_model.policy.to("cpu")
torch.onnx.export(onnxable_model, dummy_input, "my_sac_actor.onnx", opset_version=9)
For more discussion around the topic refer to this `issue. <https://github.com/DLR-RM/stable-baselines3/issues/383>`_
Export to C++
-----------------

View file

@ -33,6 +33,8 @@ Documentation:
- Add Rocket League Gym to list of supported projects (@AechPro)
- Added gym-electric-motor to project page (@wkirgsn)
- Added policy-distillation-baselines to project page (@CUN-bjy)
- Added ONNX export instructions (@batu)
- Update read the doc env (fixed ``docutils`` issue)
Release 1.2.0 (2021-09-03)
@ -776,5 +778,4 @@ And all the contributors:
@tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag @xicocaio
@diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 @SwamyDev @wmmc88 @megan-klaiber @thisray
@tfederico @hn2 @LucasAlegre @AptX395 @zampanteymedio @JadenTravnik @decodyng @ardabbour @lorenz-h @mschweizer @lorepieri8 @vwxyzjn
@ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr @Atlis @liusida @09tangriro @amy12xx @juancroldan @benblack769 @bstee615
@c-rizz @skandermoalla @MihaiAnca13 @davidblom603 @ayeright @cyprienc @wkirgsn @AechPro @CUN-bjy
@ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr @Atlis @liusida @09tangriro @amy12xx @juancroldan @benblack769 @bstee615 @c-rizz @skandermoalla @MihaiAnca13 @davidblom603 @ayeright @cyprienc @wkirgsn @AechPro @CUN-bjy @batu