mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-07-04 04:07:27 +00:00
ONNX Documentation Update (#464)
* Updated ONNX documentation

  First draft on the documentation explaining how to export SB3 models in the ONNX format

* Updated changelog with ONNX documentation fix
* Address comments
* Update changelog.rst
* Update rtd env
* Fixes + add test example

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
Co-authored-by: Anssi Kanervisto <anssk@Anssis-MacBook-Air.local>
Co-authored-by: Anssi Kanervisto <kaneran21@hotmail.com>
parent 914bc10a0d
commit 89af49ca91
3 changed files with 86 additions and 6 deletions

@@ -6,7 +6,7 @@ dependencies:
   - cpuonly=1.0=0
   - pip=20.2
   - python=3.6
-  - pytorch=1.5.0=py3.6_cpu_0
+  - pytorch=1.8.1=py3.6_cpu_0
   - pip:
     - gym>=0.17.2
     - cloudpickle
@@ -15,6 +15,6 @@ dependencies:
     - numpy
     - matplotlib
     - sphinx_autodoc_typehints
-    # Tmp fix, docutils==0.17 breaks rtd theme
+    - sphinx>=4.2
     # See https://github.com/readthedocs/sphinx_rtd_theme/issues/1115
-    - docutils==0.16
+    - sphinx_rtd_theme>=1.0

@@ -31,8 +31,87 @@ to do inference in another framework.
Export to ONNX
-----------------

TODO: help is welcomed!

As of June 2021, the ONNX format `doesn't support <https://github.com/onnx/onnx/issues/3033>`_ exporting models that use the ``broadcast_tensors`` functionality of PyTorch. To export a trained Stable-Baselines3 model to ONNX, we therefore first need to remove the layers that use broadcasting. This can be done by wrapping the policy in a ``torch.nn.Module`` that only contains the supported layers.

The following examples are general examples for ``MlpPolicy`` only. Note that you have to preprocess the observation the same way the Stable-Baselines3 agent does (see ``common.preprocessing.preprocess_obs``).

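As a rough illustration of the preprocessing note above, the sketch below mimics in NumPy what ``preprocess_obs`` is understood to do for three common observation types: cast ``Box`` observations to float, scale image observations by 1/255, and one-hot encode ``Discrete`` observations. The helper names are hypothetical and not part of Stable-Baselines3; check the behaviour against the library version you trained with.

```python
import numpy as np


def preprocess_box(obs):
    # Hypothetical helper: plain Box observations are only cast to float32
    return np.asarray(obs, dtype=np.float32)


def preprocess_image(obs):
    # Hypothetical helper: uint8 image observations are scaled to [0, 1]
    return np.asarray(obs, dtype=np.float32) / 255.0


def preprocess_discrete(obs, n):
    # Hypothetical helper: Discrete(n) observations are one-hot encoded
    indices = np.asarray(obs, dtype=np.int64)
    return np.eye(n, dtype=np.float32)[indices]


# Batch of two observations from a Discrete(4) space
one_hot_batch = preprocess_discrete([2, 0], n=4)
```

The resulting array is what would be fed to the exported network in place of the raw observation.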
For PPO, assuming a shared feature extractor:

.. warning::

  The following example is for continuous actions only.
  When using discrete or binary actions, you must do some `post-processing <https://github.com/DLR-RM/stable-baselines3/blob/f3a35aa786ee41ffff599b99fa1607c067e89074/stable_baselines3/common/policies.py#L621-L637>`_
  to obtain the action (e.g., convert action logits to action).


.. code-block:: python

  from stable_baselines3 import PPO
  import torch

  class OnnxablePolicy(torch.nn.Module):
      def __init__(self, extractor, action_net, value_net):
          super().__init__()
          self.extractor = extractor
          self.action_net = action_net
          self.value_net = value_net

      def forward(self, observation):
          # NOTE: You may have to process (normalize) observation in the correct
          #       way before using this. See `common.preprocessing.preprocess_obs`
          action_hidden, value_hidden = self.extractor(observation)
          return self.action_net(action_hidden), self.value_net(value_hidden)

  # Example: model = PPO("MlpPolicy", "Pendulum-v0")
  model = PPO.load("PathToTrainedModel.zip")
  model.policy.to("cpu")
  onnxable_model = OnnxablePolicy(model.policy.mlp_extractor, model.policy.action_net, model.policy.value_net)

  # Size of the flat observation vector (assumes a 1D Box observation space)
  observation_size = model.observation_space.shape[0]
  onnx_path = "my_ppo_model.onnx"
  dummy_input = torch.randn(1, observation_size)
  # Name the input explicitly so it can be referenced when running the model
  torch.onnx.export(onnxable_model, dummy_input, onnx_path, opset_version=9, input_names=["input"])

  ##### Load and test with onnx

  import onnx
  import onnxruntime as ort
  import numpy as np

  onnx_model = onnx.load(onnx_path)
  onnx.checker.check_model(onnx_model)

  observation = np.zeros((1, observation_size)).astype(np.float32)
  ort_sess = ort.InferenceSession(onnx_path)
  action, value = ort_sess.run(None, {"input": observation})


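For the discrete and binary cases mentioned in the warning above, the exported ``action_net`` returns raw logits rather than actions. A minimal sketch of the deterministic post-processing, assuming the first ONNX output holds logits of shape ``(batch, n)``; the function names are hypothetical and not part of Stable-Baselines3:

```python
import numpy as np


def discrete_action_from_logits(logits):
    # Discrete action space: pick the index of the largest logit
    # (same result as taking the argmax of the softmax probabilities)
    return np.argmax(logits, axis=1)


def binary_action_from_logits(logits):
    # MultiBinary action space: a probability above 0.5 corresponds to logit > 0
    return (logits > 0).astype(np.int64)


# Made-up logits standing in for the first output of the ONNX session
example_logits = np.array([[0.1, 2.0, -1.0], [3.0, -0.5, 0.2]], dtype=np.float32)
discrete_actions = discrete_action_from_logits(example_logits)
binary_actions = binary_action_from_logits(example_logits)
```

Sampling a stochastic action would instead require drawing from the softmax (or sigmoid) of the logits; the sketch only covers the deterministic case.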
For SAC the procedure is similar. The example shown only exports the actor network, as the actor is sufficient to roll out the trained policies.

.. code-block:: python

  from stable_baselines3 import SAC
  import torch

  class OnnxablePolicy(torch.nn.Module):
      def __init__(self, actor):
          super().__init__()

          # Removing the flatten layer because it can't be onnxed
          self.actor = torch.nn.Sequential(actor.latent_pi, actor.mu)

      def forward(self, observation):
          # NOTE: You may have to process (normalize) observation in the correct
          #       way before using this. See `common.preprocessing.preprocess_obs`
          return self.actor(observation)

  model = SAC.load("PathToTrainedModel.zip")
  # Move the policy to CPU before wrapping it for export
  model.policy.to("cpu")
  onnxable_model = OnnxablePolicy(model.policy.actor)

  # Size of the flat observation vector (assumes a 1D Box observation space)
  observation_size = model.observation_space.shape[0]
  dummy_input = torch.randn(1, observation_size)
  torch.onnx.export(onnxable_model, dummy_input, "my_sac_actor.onnx", opset_version=9)


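Because the wrapper above keeps only ``latent_pi`` and ``mu``, the exported network is understood to return the mean of the Gaussian before SAC's ``tanh`` squashing. A hedged sketch of the remaining steps for a deterministic action, assuming a ``Box`` action space with bounds ``low`` and ``high``; the function name and the example bounds are illustrative only:

```python
import numpy as np


def sac_action_from_mean(mean_actions, low, high):
    # Squash the unbounded mean into [-1, 1], as the SAC actor does
    squashed = np.tanh(mean_actions)
    # Rescale from [-1, 1] to the bounds of the action space
    return low + 0.5 * (squashed + 1.0) * (high - low)


# Illustrative bounds (a one-dimensional action in [-2, 2])
low = np.array([-2.0], dtype=np.float32)
high = np.array([2.0], dtype=np.float32)

# Made-up network outputs standing in for the ONNX model output
mean_actions = np.array([[0.0], [100.0]], dtype=np.float32)
actions = sac_action_from_mean(mean_actions, low, high)
```

A zero mean maps to the middle of the action range, and very large means saturate at the upper bound.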
For more discussion around the topic, refer to `this issue <https://github.com/DLR-RM/stable-baselines3/issues/383>`_.

Export to C++
-----------------

@@ -33,6 +33,8 @@ Documentation:
 - Add Rocket League Gym to list of supported projects (@AechPro)
 - Added gym-electric-motor to project page (@wkirgsn)
 - Added policy-distillation-baselines to project page (@CUN-bjy)
+- Added ONNX export instructions (@batu)
+- Update read the doc env (fixed ``docutils`` issue)


 Release 1.2.0 (2021-09-03)
@@ -776,5 +778,4 @@ And all the contributors:
 @tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag @xicocaio
 @diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 @SwamyDev @wmmc88 @megan-klaiber @thisray
 @tfederico @hn2 @LucasAlegre @AptX395 @zampanteymedio @JadenTravnik @decodyng @ardabbour @lorenz-h @mschweizer @lorepieri8 @vwxyzjn
-@ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr @Atlis @liusida @09tangriro @amy12xx @juancroldan @benblack769 @bstee615
-@c-rizz @skandermoalla @MihaiAnca13 @davidblom603 @ayeright @cyprienc @wkirgsn @AechPro @CUN-bjy
+@ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr @Atlis @liusida @09tangriro @amy12xx @juancroldan @benblack769 @bstee615 @c-rizz @skandermoalla @MihaiAnca13 @davidblom603 @ayeright @cyprienc @wkirgsn @AechPro @CUN-bjy @batu