mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-03 23:49:57 +00:00
Update export doc (fixes + add torch jit) (#1074)
* Update export doc (fixes + add torch jit) * Fix conflicts * Update according to code review comments * fix torch -> th Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
This commit is contained in:
parent
21300c9aaf
commit
537a82a7fd
2 changed files with 98 additions and 38 deletions
|
|
@ -46,29 +46,40 @@ For PPO, assuming a shared feature extactor.
|
|||
|
||||
.. code-block:: python
|
||||
|
||||
import torch as th
|
||||
|
||||
from stable_baselines3 import PPO
|
||||
import torch
|
||||
|
||||
class OnnxablePolicy(torch.nn.Module):
|
||||
def __init__(self, extractor, action_net, value_net):
|
||||
super(OnnxablePolicy, self).__init__()
|
||||
self.extractor = extractor
|
||||
self.action_net = action_net
|
||||
self.value_net = value_net
|
||||
|
||||
def forward(self, observation):
|
||||
# NOTE: You may have to process (normalize) observation in the correct
|
||||
# way before using this. See `common.preprocessing.preprocess_obs`
|
||||
action_hidden, value_hidden = self.extractor(observation)
|
||||
return self.action_net(action_hidden), self.value_net(value_hidden)
|
||||
class OnnxablePolicy(th.nn.Module):
|
||||
def __init__(self, extractor, action_net, value_net):
|
||||
super().__init__()
|
||||
self.extractor = extractor
|
||||
self.action_net = action_net
|
||||
self.value_net = value_net
|
||||
|
||||
def forward(self, observation):
|
||||
# NOTE: You may have to process (normalize) observation in the correct
|
||||
# way before using this. See `common.preprocessing.preprocess_obs`
|
||||
action_hidden, value_hidden = self.extractor(observation)
|
||||
return self.action_net(action_hidden), self.value_net(value_hidden)
|
||||
|
||||
|
||||
# Example: model = PPO("MlpPolicy", "Pendulum-v1")
|
||||
model = PPO.load("PathToTrainedModel.zip")
|
||||
model.policy.to("cpu")
|
||||
onnxable_model = OnnxablePolicy(model.policy.mlp_extractor, model.policy.action_net, model.policy.value_net)
|
||||
model = PPO.load("PathToTrainedModel.zip", device="cpu")
|
||||
onnxable_model = OnnxablePolicy(
|
||||
model.policy.mlp_extractor, model.policy.action_net, model.policy.value_net
|
||||
)
|
||||
|
||||
dummy_input = torch.randn(1, observation_size)
|
||||
torch.onnx.export(onnxable_model, dummy_input, "my_ppo_model.onnx", opset_version=9)
|
||||
observation_size = model.observation_space.shape
|
||||
dummy_input = th.randn(1, *observation_size)
|
||||
th.onnx.export(
|
||||
onnxable_model,
|
||||
dummy_input,
|
||||
"my_ppo_model.onnx",
|
||||
opset_version=9,
|
||||
input_names=["input"],
|
||||
)
|
||||
|
||||
##### Load and test with onnx
|
||||
|
||||
|
|
@ -76,48 +87,97 @@ For PPO, assuming a shared feature extactor.
|
|||
import onnxruntime as ort
|
||||
import numpy as np
|
||||
|
||||
onnx_path = "my_ppo_model.onnx"
|
||||
onnx_model = onnx.load(onnx_path)
|
||||
onnx.checker.check_model(onnx_model)
|
||||
|
||||
observation = np.zeros((1, observation_size)).astype(np.float32)
|
||||
observation = np.zeros((1, *observation_size)).astype(np.float32)
|
||||
ort_sess = ort.InferenceSession(onnx_path)
|
||||
action, value = ort_sess.run(None, {'input.1': observation})
|
||||
action, value = ort_sess.run(None, {"input": observation})
|
||||
|
||||
|
||||
For SAC the procedure is similar. The example shown only exports the actor network as the actor is sufficient to roll out the trained policies.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import torch as th
|
||||
|
||||
from stable_baselines3 import SAC
|
||||
import torch
|
||||
|
||||
class OnnxablePolicy(torch.nn.Module):
|
||||
def __init__(self, actor):
|
||||
super(OnnxablePolicy, self).__init__()
|
||||
|
||||
# Removing the flatten layer because it can't be onnxed
|
||||
self.actor = torch.nn.Sequential(actor.latent_pi, actor.mu)
|
||||
class OnnxablePolicy(th.nn.Module):
|
||||
def __init__(self, actor: th.nn.Module):
|
||||
super().__init__()
|
||||
# Removing the flatten layer because it can't be onnxed
|
||||
self.actor = th.nn.Sequential(
|
||||
actor.latent_pi,
|
||||
actor.mu,
|
||||
# For gSDE
|
||||
# th.nn.Hardtanh(min_val=-actor.clip_mean, max_val=actor.clip_mean),
|
||||
# Squash the output
|
||||
th.nn.Tanh(),
|
||||
)
|
||||
|
||||
def forward(self, observation):
|
||||
# NOTE: You may have to process (normalize) observation in the correct
|
||||
# way before using this. See `common.preprocessing.preprocess_obs`
|
||||
return self.actor(observation)
|
||||
def forward(self, observation: th.Tensor) -> th.Tensor:
|
||||
# NOTE: You may have to process (normalize) observation in the correct
|
||||
# way before using this. See `common.preprocessing.preprocess_obs`
|
||||
return self.actor(observation)
|
||||
|
||||
model = SAC.load("PathToTrainedModel.zip")
|
||||
|
||||
# Example: model = SAC("MlpPolicy", "Pendulum-v1")
|
||||
model = SAC.load("PathToTrainedModel.zip", device="cpu")
|
||||
onnxable_model = OnnxablePolicy(model.policy.actor)
|
||||
|
||||
dummy_input = torch.randn(1, observation_size)
|
||||
onnxable_model.policy.to("cpu")
|
||||
torch.onnx.export(onnxable_model, dummy_input, "my_sac_actor.onnx", opset_version=9)
|
||||
observation_size = model.observation_space.shape
|
||||
dummy_input = th.randn(1, *observation_size)
|
||||
th.onnx.export(
|
||||
onnxable_model,
|
||||
dummy_input,
|
||||
"my_sac_actor.onnx",
|
||||
opset_version=9,
|
||||
input_names=["input"],
|
||||
)
|
||||
|
||||
##### Load and test with onnx
|
||||
|
||||
import onnxruntime as ort
|
||||
import numpy as np
|
||||
|
||||
onnx_path = "my_sac_actor.onnx"
|
||||
|
||||
observation = np.zeros((1, *observation_size)).astype(np.float32)
|
||||
ort_sess = ort.InferenceSession(onnx_path)
|
||||
action = ort_sess.run(None, {"input": observation})
|
||||
|
||||
|
||||
For more discussion around the topic refer to this `issue. <https://github.com/DLR-RM/stable-baselines3/issues/383>`_
|
||||
|
||||
Export to C++
|
||||
-----------------
|
||||
Trace/Export to C++
|
||||
-------------------
|
||||
|
||||
(using PyTorch JIT)
|
||||
TODO: help is welcomed!
|
||||
You can use PyTorch JIT to trace and save a trained model that can be re-used in other applications
|
||||
(for instance inference code written in C++).
|
||||
|
||||
There is a draft PR in the RL Zoo about C++ export: https://github.com/DLR-RM/rl-baselines3-zoo/pull/228
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# See "ONNX export" for imports and OnnxablePolicy
|
||||
jit_path = "sac_traced.pt"
|
||||
|
||||
# Trace and optimize the module
|
||||
traced_module = th.jit.trace(onnxable_model.eval(), dummy_input)
|
||||
frozen_module = th.jit.freeze(traced_module)
|
||||
frozen_module = th.jit.optimize_for_inference(frozen_module)
|
||||
th.jit.save(frozen_module, jit_path)
|
||||
|
||||
##### Load and test with torch
|
||||
|
||||
import torch as th
|
||||
|
||||
dummy_input = th.randn(1, *observation_size)
|
||||
loaded_module = th.jit.load(jit_path)
|
||||
action_jit = loaded_module(dummy_input)
|
||||
|
||||
|
||||
Export to tensorflowjs / ONNX-JS
|
||||
|
|
|
|||
|
|
@ -55,9 +55,9 @@ Documentation:
|
|||
- Fixed typo in install doc(@jlp-ue)
|
||||
- Clarified and standardized verbosity documentation
|
||||
- Added link to a GitHub issue in the custom policy documentation (@AlexPasqua)
|
||||
- Update doc on exporting models (fixes and added torch jit)
|
||||
- Fixed typos (@Akhilez)
|
||||
|
||||
|
||||
Release 1.6.0 (2022-07-11)
|
||||
---------------------------
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue