mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-06 00:03:28 +00:00
Add a section on exporting to TFLite/Coral with demonstration (#679)
* Add a section on exporting to TFLite/Coral with demonstration * Changelog to reflect new export documentation * Update docs/guide/export.rst Fingers on autopilot make word wrong Co-authored-by: Anssi <kaneran21@hotmail.com> * Update docs/guide/export.rst Better wording clarity Co-authored-by: Anssi <kaneran21@hotmail.com> * Update docs/guide/export.rst Better wording clarity Co-authored-by: Anssi <kaneran21@hotmail.com> * Clarify motivations and hardware * Update docs/misc/changelog.rst Make consistent with other changelog entries Co-authored-by: Anssi <kaneran21@hotmail.com> * Sphinx wants the section underline to be at least this long * Remove first-person voice * Typos Co-authored-by: Anssi <kaneran21@hotmail.com>
This commit is contained in:
parent
3b68dc7312
commit
8e5ede783f
2 changed files with 34 additions and 1 deletions
|
|
@ -33,7 +33,7 @@ Export to ONNX
|
|||
|
||||
As of June 2021, ONNX format `doesn't support <https://github.com/onnx/onnx/issues/3033>`_ exporting models that use the ``broadcast_tensors`` functionality of pytorch. So in order to export the trained stable-baseline3 models in the ONNX format, we need to first remove the layers that use broadcasting. This can be done by creating a class that removes the unsupported layers.
|
||||
|
||||
The following examples are for ``MlpPolicy`` only, and are general examples. Note that you have to preprocess the observation the same way stable-baselines3 agent does (see ``common.preprocessing.preprocess_obs``)
|
||||
The following examples are for ``MlpPolicy`` only, and are general examples. Note that you have to preprocess the observation the same way stable-baselines3 agent does (see ``common.preprocessing.preprocess_obs``).
|
||||
|
||||
For PPO, assuming a shared feature extactor.
|
||||
|
||||
|
|
@ -127,6 +127,38 @@ TODO: contributors help is welcomed!
|
|||
Probably a good starting point: https://github.com/elliotwaite/pytorch-to-javascript-with-onnx-js
|
||||
|
||||
|
||||
Export to TFLite / Coral (Edge TPU)
|
||||
-----------------------------------
|
||||
|
||||
Full example code: https://github.com/chunky/sb3_to_coral
|
||||
|
||||
Google created a chip called the "Coral" for deploying AI to the
|
||||
edge. It's available in a variety of form factors, including USB (using
|
||||
the Coral on a Rasbperry pi, with a SB3-developed model, was the original
|
||||
motivation for the code example above).
|
||||
|
||||
The Coral chip is fast, with very low power consumption, but only has limited
|
||||
on-device training abilities. More information is on the webpage here:
|
||||
https://coral.ai.
|
||||
|
||||
To deploy to a Coral, one must work via TFLite, and quantise the
|
||||
network to reflect the Coral's capabilities. The full chain to go from
|
||||
SB3 to Coral is: SB3 (Torch) => ONNX => TensorFlow => TFLite => Coral.
|
||||
|
||||
The code linked above is a complete, minimal, example that:
|
||||
|
||||
1. Creates a model using SB3
|
||||
2. Follows the path of exports all the way to TFLite and Google Coral
|
||||
3. Demonstrates the forward pass for most exported variants
|
||||
|
||||
There are a number of pitfalls along the way to the complete conversion
|
||||
that this example covers, including:
|
||||
|
||||
- Making the Gym's observation work with ONNX properly
|
||||
- Quantising the TFLite model appropriately to align with Gym
|
||||
while still taking advantage of Coral
|
||||
- Using OnnxablePolicy described as described in the above example
|
||||
|
||||
|
||||
Manual export
|
||||
-------------
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ Documentation:
|
|||
- Add tactile-gym to projects page (@ac-93)
|
||||
- Fix indentation in the RL tips page (@cove9988)
|
||||
- Update GAE computation docstring
|
||||
- Add documentation on exporting to TFLite/Coral
|
||||
|
||||
|
||||
Release 1.3.0 (2021-10-23)
|
||||
|
|
|
|||
Loading…
Reference in a new issue