From 89af49ca91bc1fc92def6997bc05ddb9127b4647 Mon Sep 17 00:00:00 2001 From: batu Date: Sun, 26 Sep 2021 08:40:35 -0700 Subject: [PATCH] ONNX Documentation Update (#464) * Updated ONNX documentation First draft on the documentation explaining how to export SB3 models in the ONNX format * Updated changelog with ONNX documentation fix * Address comments * Update changelog.rst * Update rtd env * Fixes + add test example Co-authored-by: Antonin RAFFIN Co-authored-by: Anssi Kanervisto Co-authored-by: Anssi Kanervisto --- docs/conda_env.yml | 6 +-- docs/guide/export.rst | 81 ++++++++++++++++++++++++++++++++++++++++- docs/misc/changelog.rst | 5 ++- 3 files changed, 86 insertions(+), 6 deletions(-) diff --git a/docs/conda_env.yml b/docs/conda_env.yml index 290b499..f2f153d 100644 --- a/docs/conda_env.yml +++ b/docs/conda_env.yml @@ -6,7 +6,7 @@ dependencies: - cpuonly=1.0=0 - pip=20.2 - python=3.6 - - pytorch=1.5.0=py3.6_cpu_0 + - pytorch=1.8.1=py3.6_cpu_0 - pip: - gym>=0.17.2 - cloudpickle @@ -15,6 +15,6 @@ dependencies: - numpy - matplotlib - sphinx_autodoc_typehints - # Tmp fix, docutils==0.17 breaks rtd theme + - sphinx>=4.2 # See https://github.com/readthedocs/sphinx_rtd_theme/issues/1115 - - docutils==0.16 + - sphinx_rtd_theme>=1.0 diff --git a/docs/guide/export.rst b/docs/guide/export.rst index 8be4951..8d1ff0e 100644 --- a/docs/guide/export.rst +++ b/docs/guide/export.rst @@ -31,8 +31,87 @@ to do inference in another framework. Export to ONNX ----------------- -TODO: help is welcomed! +As of June 2021, ONNX format `doesn't support `_ exporting models that use the ``broadcast_tensors`` functionality of pytorch. So in order to export the trained stable-baseline3 models in the ONNX format, we need to first remove the layers that use broadcasting. This can be done by creating a class that removes the unsupported layers. +The following examples are for ``MlpPolicy`` only, and are general examples. Note that you have to preprocess the observation the same way stable-baselines3 agent does (see ``common.preprocessing.preprocess_obs``) + +For PPO, assuming a shared feature extactor. + +.. warning:: + + The following example is for continuous actions only. + When using discrete or binary actions, you must do some `post-processing `_ + to obtain the action (e.g., convert action logits to action). + + +.. code-block:: python + + from stable_baselines3 import PPO + import torch + + class OnnxablePolicy(torch.nn.Module): + def __init__(self, extractor, action_net, value_net): + super(OnnxablePolicy, self).__init__() + self.extractor = extractor + self.action_net = action_net + self.value_net = value_net + + def forward(self, observation): + # NOTE: You may have to process (normalize) observation in the correct + # way before using this. See `common.preprocessing.preprocess_obs` + action_hidden, value_hidden = self.extractor(observation) + return self.action_net(action_hidden), self.value_net(value_hidden) + + # Example: model = PPO("MlpPolicy", "Pendulum-v0") + model = PPO.load("PathToTrainedModel.zip") + model.policy.to("cpu") + onnxable_model = OnnxablePolicy(model.policy.mlp_extractor, model.policy.action_net, model.policy.value_net) + + dummy_input = torch.randn(1, observation_size) + torch.onnx.export(onnxable_model, dummy_input, "my_ppo_model.onnx", opset_version=9) + + ##### Load and test with onnx + + import onnx + import onnxruntime as ort + import numpy as np + + onnx_model = onnx.load(onnx_path) + onnx.checker.check_model(onnx_model) + + observation = np.zeros((1, observation_size)).astype(np.float32) + ort_sess = ort.InferenceSession(onnx_path) + action, value = ort_sess.run(None, {'input.1': observation}) + + +For SAC the procedure is similar. The example shown only exports the actor network as the actor is sufficient to roll out the trained policies. + +.. code-block:: python + + from stable_baselines3 import SAC + import torch + + class OnnxablePolicy(torch.nn.Module): + def __init__(self, actor): + super(OnnxablePolicy, self).__init__() + + # Removing the flatten layer because it can't be onnxed + self.actor = torch.nn.Sequential(actor.latent_pi, actor.mu) + + def forward(self, observation): + # NOTE: You may have to process (normalize) observation in the correct + # way before using this. See `common.preprocessing.preprocess_obs` + return self.actor(observation) + + model = SAC.load("PathToTrainedModel.zip") + onnxable_model = OnnxablePolicy(model.policy.actor) + + dummy_input = torch.randn(1, observation_size) + onnxable_model.policy.to("cpu") + torch.onnx.export(onnxable_model, dummy_input, "my_sac_actor.onnx", opset_version=9) + + +For more discussion around the topic refer to this `issue. `_ Export to C++ ----------------- diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 4046e90..620ec1c 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -33,6 +33,8 @@ Documentation: - Add Rocket League Gym to list of supported projects (@AechPro) - Added gym-electric-motor to project page (@wkirgsn) - Added policy-distillation-baselines to project page (@CUN-bjy) +- Added ONNX export instructions (@batu) +- Update read the doc env (fixed ``docutils`` issue) Release 1.2.0 (2021-09-03) @@ -776,5 +778,4 @@ And all the contributors: @tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag @xicocaio @diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 @SwamyDev @wmmc88 @megan-klaiber @thisray @tfederico @hn2 @LucasAlegre @AptX395 @zampanteymedio @JadenTravnik @decodyng @ardabbour @lorenz-h @mschweizer @lorepieri8 @vwxyzjn -@ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr @Atlis @liusida @09tangriro @amy12xx @juancroldan @benblack769 @bstee615 -@c-rizz @skandermoalla @MihaiAnca13 @davidblom603 @ayeright @cyprienc @wkirgsn @AechPro @CUN-bjy +@ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr @Atlis @liusida @09tangriro @amy12xx @juancroldan @benblack769 @bstee615 @c-rizz @skandermoalla @MihaiAnca13 @davidblom603 @ayeright @cyprienc @wkirgsn @AechPro @CUN-bjy @batu