From 537a82a7fdadc710d1ca0e98dddc967e3870f19d Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 30 Sep 2022 14:30:40 +0200 Subject: [PATCH] Update export doc (fixes + add torch jit) (#1074) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Update export doc (fixes + add torch jit) * Fix conflicts * Update according to code review comments * fix torch -> th Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> --- docs/guide/export.rst | 134 +++++++++++++++++++++++++++++----------- docs/misc/changelog.rst | 2 +- 2 files changed, 98 insertions(+), 38 deletions(-) diff --git a/docs/guide/export.rst b/docs/guide/export.rst index b6884c1..3a21749 100644 --- a/docs/guide/export.rst +++ b/docs/guide/export.rst @@ -46,29 +46,40 @@ For PPO, assuming a shared feature extactor. .. code-block:: python + import torch as th + from stable_baselines3 import PPO - import torch - class OnnxablePolicy(torch.nn.Module): - def __init__(self, extractor, action_net, value_net): - super(OnnxablePolicy, self).__init__() - self.extractor = extractor - self.action_net = action_net - self.value_net = value_net - def forward(self, observation): - # NOTE: You may have to process (normalize) observation in the correct - # way before using this. See `common.preprocessing.preprocess_obs` - action_hidden, value_hidden = self.extractor(observation) - return self.action_net(action_hidden), self.value_net(value_hidden) + class OnnxablePolicy(th.nn.Module): + def __init__(self, extractor, action_net, value_net): + super().__init__() + self.extractor = extractor + self.action_net = action_net + self.value_net = value_net + + def forward(self, observation): + # NOTE: You may have to process (normalize) observation in the correct + # way before using this. 
See `common.preprocessing.preprocess_obs` + action_hidden, value_hidden = self.extractor(observation) + return self.action_net(action_hidden), self.value_net(value_hidden) + # Example: model = PPO("MlpPolicy", "Pendulum-v1") - model = PPO.load("PathToTrainedModel.zip") - model.policy.to("cpu") - onnxable_model = OnnxablePolicy(model.policy.mlp_extractor, model.policy.action_net, model.policy.value_net) + model = PPO.load("PathToTrainedModel.zip", device="cpu") + onnxable_model = OnnxablePolicy( + model.policy.mlp_extractor, model.policy.action_net, model.policy.value_net + ) - dummy_input = torch.randn(1, observation_size) - torch.onnx.export(onnxable_model, dummy_input, "my_ppo_model.onnx", opset_version=9) + observation_size = model.observation_space.shape + dummy_input = th.randn(1, *observation_size) + th.onnx.export( + onnxable_model, + dummy_input, + "my_ppo_model.onnx", + opset_version=9, + input_names=["input"], + ) ##### Load and test with onnx @@ -76,48 +87,97 @@ For PPO, assuming a shared feature extactor. import onnxruntime as ort import numpy as np + onnx_path = "my_ppo_model.onnx" onnx_model = onnx.load(onnx_path) onnx.checker.check_model(onnx_model) - observation = np.zeros((1, observation_size)).astype(np.float32) + observation = np.zeros((1, *observation_size)).astype(np.float32) ort_sess = ort.InferenceSession(onnx_path) - action, value = ort_sess.run(None, {'input.1': observation}) + action, value = ort_sess.run(None, {"input": observation}) For SAC the procedure is similar. The example shown only exports the actor network as the actor is sufficient to roll out the trained policies. .. 
code-block:: python + import torch as th + from stable_baselines3 import SAC - import torch - class OnnxablePolicy(torch.nn.Module): - def __init__(self, actor): - super(OnnxablePolicy, self).__init__() - # Removing the flatten layer because it can't be onnxed - self.actor = torch.nn.Sequential(actor.latent_pi, actor.mu) + class OnnxablePolicy(th.nn.Module): + def __init__(self, actor: th.nn.Module): + super().__init__() + # Removing the flatten layer because it can't be onnxed + self.actor = th.nn.Sequential( + actor.latent_pi, + actor.mu, + # For gSDE + # th.nn.Hardtanh(min_val=-actor.clip_mean, max_val=actor.clip_mean), + # Squash the output + th.nn.Tanh(), + ) - def forward(self, observation): - # NOTE: You may have to process (normalize) observation in the correct - # way before using this. See `common.preprocessing.preprocess_obs` - return self.actor(observation) + def forward(self, observation: th.Tensor) -> th.Tensor: + # NOTE: You may have to process (normalize) observation in the correct + # way before using this. 
See `common.preprocessing.preprocess_obs` + return self.actor(observation) - model = SAC.load("PathToTrainedModel.zip") + + # Example: model = SAC("MlpPolicy", "Pendulum-v1") + model = SAC.load("PathToTrainedModel.zip", device="cpu") onnxable_model = OnnxablePolicy(model.policy.actor) - dummy_input = torch.randn(1, observation_size) - onnxable_model.policy.to("cpu") - torch.onnx.export(onnxable_model, dummy_input, "my_sac_actor.onnx", opset_version=9) + observation_size = model.observation_space.shape + dummy_input = th.randn(1, *observation_size) + th.onnx.export( + onnxable_model, + dummy_input, + "my_sac_actor.onnx", + opset_version=9, + input_names=["input"], + ) + + ##### Load and test with onnx + + import onnxruntime as ort + import numpy as np + + onnx_path = "my_sac_actor.onnx" + + observation = np.zeros((1, *observation_size)).astype(np.float32) + ort_sess = ort.InferenceSession(onnx_path) + action = ort_sess.run(None, {"input": observation}) For more discussion around the topic refer to this `issue. `_ -Export to C++ ------------------ +Trace/Export to C++ +------------------- -(using PyTorch JIT) -TODO: help is welcomed! +You can use PyTorch JIT to trace and save a trained model that can be re-used in other applications +(for instance inference code written in C++). + +There is a draft PR in the RL Zoo about C++ export: https://github.com/DLR-RM/rl-baselines3-zoo/pull/228 + +.. 
code-block:: python + + # See "ONNX export" for imports and OnnxablePolicy + jit_path = "sac_traced.pt" + + # Trace and optimize the module + traced_module = th.jit.trace(onnxable_model.eval(), dummy_input) + frozen_module = th.jit.freeze(traced_module) + frozen_module = th.jit.optimize_for_inference(frozen_module) + th.jit.save(frozen_module, jit_path) + + ##### Load and test with torch + + import torch as th + + dummy_input = th.randn(1, *observation_size) + loaded_module = th.jit.load(jit_path) + action_jit = loaded_module(dummy_input) Export to tensorflowjs / ONNX-JS diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 64be0e4..5f98147 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -55,9 +55,9 @@ Documentation: - Fixed typo in install doc(@jlp-ue) - Clarified and standardized verbosity documentation - Added link to a GitHub issue in the custom policy documentation (@AlexPasqua) +- Updated doc on exporting models (fixes and added torch jit) - Fixed typos (@Akhilez) - Release 1.6.0 (2022-07-11) ---------------------------