Update export doc (fixes + add torch jit) (#1074)

* Update export doc (fixes + add torch jit)

* Fix conflicts

* Update according to code review comments

* fix torch -> th

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
This commit is contained in:
Antonin RAFFIN 2022-09-30 14:30:40 +02:00 committed by GitHub
parent 21300c9aaf
commit 537a82a7fd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 98 additions and 38 deletions

View file

@ -46,29 +46,40 @@ For PPO, assuming a shared feature extactor.
.. code-block:: python
import torch as th
from stable_baselines3 import PPO
import torch
class OnnxablePolicy(torch.nn.Module):
def __init__(self, extractor, action_net, value_net):
super(OnnxablePolicy, self).__init__()
self.extractor = extractor
self.action_net = action_net
self.value_net = value_net
def forward(self, observation):
# NOTE: You may have to process (normalize) observation in the correct
# way before using this. See `common.preprocessing.preprocess_obs`
action_hidden, value_hidden = self.extractor(observation)
return self.action_net(action_hidden), self.value_net(value_hidden)
class OnnxablePolicy(th.nn.Module):
def __init__(self, extractor, action_net, value_net):
super().__init__()
self.extractor = extractor
self.action_net = action_net
self.value_net = value_net
def forward(self, observation):
# NOTE: You may have to process (normalize) observation in the correct
# way before using this. See `common.preprocessing.preprocess_obs`
action_hidden, value_hidden = self.extractor(observation)
return self.action_net(action_hidden), self.value_net(value_hidden)
# Example: model = PPO("MlpPolicy", "Pendulum-v1")
model = PPO.load("PathToTrainedModel.zip")
model.policy.to("cpu")
onnxable_model = OnnxablePolicy(model.policy.mlp_extractor, model.policy.action_net, model.policy.value_net)
model = PPO.load("PathToTrainedModel.zip", device="cpu")
onnxable_model = OnnxablePolicy(
model.policy.mlp_extractor, model.policy.action_net, model.policy.value_net
)
dummy_input = torch.randn(1, observation_size)
torch.onnx.export(onnxable_model, dummy_input, "my_ppo_model.onnx", opset_version=9)
observation_size = model.observation_space.shape
dummy_input = th.randn(1, *observation_size)
th.onnx.export(
onnxable_model,
dummy_input,
"my_ppo_model.onnx",
opset_version=9,
input_names=["input"],
)
##### Load and test with onnx
@ -76,48 +87,97 @@ For PPO, assuming a shared feature extactor.
import onnxruntime as ort
import numpy as np
onnx_path = "my_ppo_model.onnx"
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
observation = np.zeros((1, observation_size)).astype(np.float32)
observation = np.zeros((1, *observation_size)).astype(np.float32)
ort_sess = ort.InferenceSession(onnx_path)
action, value = ort_sess.run(None, {'input.1': observation})
action, value = ort_sess.run(None, {"input": observation})
For SAC the procedure is similar. The example shown only exports the actor network as the actor is sufficient to roll out the trained policies.
.. code-block:: python
import torch as th
from stable_baselines3 import SAC
import torch
class OnnxablePolicy(torch.nn.Module):
def __init__(self, actor):
super(OnnxablePolicy, self).__init__()
# Removing the flatten layer because it can't be onnxed
self.actor = torch.nn.Sequential(actor.latent_pi, actor.mu)
class OnnxablePolicy(th.nn.Module):
def __init__(self, actor: th.nn.Module):
super().__init__()
# Removing the flatten layer because it can't be onnxed
self.actor = th.nn.Sequential(
actor.latent_pi,
actor.mu,
# For gSDE
# th.nn.Hardtanh(min_val=-actor.clip_mean, max_val=actor.clip_mean),
# Squash the output
th.nn.Tanh(),
)
def forward(self, observation):
# NOTE: You may have to process (normalize) observation in the correct
# way before using this. See `common.preprocessing.preprocess_obs`
return self.actor(observation)
def forward(self, observation: th.Tensor) -> th.Tensor:
# NOTE: You may have to process (normalize) observation in the correct
# way before using this. See `common.preprocessing.preprocess_obs`
return self.actor(observation)
model = SAC.load("PathToTrainedModel.zip")
# Example: model = SAC("MlpPolicy", "Pendulum-v1")
model = SAC.load("PathToTrainedModel.zip", device="cpu")
onnxable_model = OnnxablePolicy(model.policy.actor)
dummy_input = torch.randn(1, observation_size)
onnxable_model.policy.to("cpu")
torch.onnx.export(onnxable_model, dummy_input, "my_sac_actor.onnx", opset_version=9)
observation_size = model.observation_space.shape
dummy_input = th.randn(1, *observation_size)
th.onnx.export(
onnxable_model,
dummy_input,
"my_sac_actor.onnx",
opset_version=9,
input_names=["input"],
)
##### Load and test with onnx
import onnxruntime as ort
import numpy as np
onnx_path = "my_sac_actor.onnx"
observation = np.zeros((1, *observation_size)).astype(np.float32)
ort_sess = ort.InferenceSession(onnx_path)
action = ort_sess.run(None, {"input": observation})
For more discussion around the topic refer to this `issue. <https://github.com/DLR-RM/stable-baselines3/issues/383>`_
Export to C++
-----------------
Trace/Export to C++
-------------------
(using PyTorch JIT)
TODO: help is welcomed!
You can use PyTorch JIT to trace and save a trained model that can be re-used in other applications
(for instance inference code written in C++).
There is a draft PR in the RL Zoo about C++ export: https://github.com/DLR-RM/rl-baselines3-zoo/pull/228
.. code-block:: python
# See "ONNX export" for imports and OnnxablePolicy
jit_path = "sac_traced.pt"
# Trace and optimize the module
traced_module = th.jit.trace(onnxable_model.eval(), dummy_input)
frozen_module = th.jit.freeze(traced_module)
frozen_module = th.jit.optimize_for_inference(frozen_module)
th.jit.save(frozen_module, jit_path)
##### Load and test with torch
import torch as th
dummy_input = th.randn(1, *observation_size)
loaded_module = th.jit.load(jit_path)
action_jit = loaded_module(dummy_input)
Export to tensorflowjs / ONNX-JS

View file

@ -55,9 +55,9 @@ Documentation:
- Fixed typo in install doc(@jlp-ue)
- Clarified and standardized verbosity documentation
- Added link to a GitHub issue in the custom policy documentation (@AlexPasqua)
- Update doc on exporting models (fixes and added torch jit)
- Fixed typos (@Akhilez)
Release 1.6.0 (2022-07-11)
---------------------------