mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Fix ORTModule python doc generation (#7704)
* Fix ORTModule python doc generation * Address comment
This commit is contained in:
parent
ebee380911
commit
4fe2ffae16
4 changed files with 17 additions and 24 deletions
|
|
@ -28,4 +28,8 @@ html_static_path = ['_static']
|
|||
|
||||
# -- Options for intersphinx extension ---------------------------------------
|
||||
|
||||
intersphinx_mapping = {'https://docs.python.org/': None}
|
||||
intersphinx_mapping = {
|
||||
'python': ('https://docs.python.org/3', None),
|
||||
'numpy': ('https://numpy.org/doc/stable', None),
|
||||
'torch': ('https://pytorch.org/docs/stable/', None),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -47,8 +47,8 @@ in the simple case where the entire model can be offloaded to ONNX Runtime:
|
|||
API
|
||||
===
|
||||
|
||||
.. automodule:: onnxruntime.training.ortmodule
|
||||
.. automodule:: onnxruntime.training.ortmodule.ortmodule
|
||||
:members:
|
||||
:show-inheritance:
|
||||
:member-order: bysource
|
||||
.. :inherited-members:
|
||||
:inherited-members:
|
||||
|
|
|
|||
|
|
@ -16,28 +16,10 @@ from typing import Iterator, Optional, Tuple, TypeVar
|
|||
T = TypeVar('T', bound='Module')
|
||||
|
||||
class ORTModule(torch.nn.Module):
|
||||
"""Specializes a user torch.nn.Module to leverage ONNX Runtime graph execution.
|
||||
"""Extends user's :class:`torch.nn.Module` model to leverage ONNX Runtime super fast training engine.
|
||||
|
||||
ORTModule specializes the user's torch.nn.Module and provides forward, backward
|
||||
implementations be leveraging ONNX Runtime.
|
||||
|
||||
ORTModule interacts with:
|
||||
- GraphExecutionManagerFactory: Which returns a GraphExecutionManager based on
|
||||
whether or not the user's torch module is in training mode or eval mode.
|
||||
- GraphExecutionManager: Responsible for building and executing the forward and backward graphs.
|
||||
- InferenceManager(GraphExecutionManager): Responsible for building, optimizing
|
||||
and executing the inference onnx graph.
|
||||
- TrainingManager(GraphExecutionManager): Responsible for building, optimizing
|
||||
and executing the training onnx graph.
|
||||
|
||||
The GraphExecutionManager first exports the user model into an onnx model.
|
||||
Following that, GraphExecutionManager interacts with OrtModuleGraphBuilder to optimize the onnx graph.
|
||||
Once the onnx graph has been optimized, an ExecutionAgent is instantiated that
|
||||
facilitates in executing the forward and backward subgraphs of the onnx model.
|
||||
|
||||
- _ortmodule_io: Provides utilities to transform the user inputs and outputs of the model.
|
||||
- It facilitates in flattening the output from the user's PyTorch model (since exporting
|
||||
of nested structures is not supported at the moment)
|
||||
ORTModule specializes the user's :class:`torch.nn.Module` model, providing :meth:`~torch.nn.Module.forward`,
|
||||
:meth:`~torch.nn.Module.backward` along with all others :class:`torch.nn.Module`'s APIs.
|
||||
"""
|
||||
|
||||
def __init__(self, module):
|
||||
|
|
@ -76,6 +58,13 @@ class ORTModule(torch.nn.Module):
|
|||
|
||||
self._execution_manager = GraphExecutionManagerFactory(self._flattened_module)
|
||||
|
||||
# IMPORTANT: DO NOT add code here
|
||||
# This declaration is for automatic document generation purposes only
|
||||
# The actual forward implementation is bound during ORTModule initialization
|
||||
def forward(self, *inputs, **kwargs):
|
||||
'''Dummy documentation for forward method'''
|
||||
...
|
||||
|
||||
def _is_training(self):
|
||||
return self._flattened_module.training and torch.is_grad_enabled()
|
||||
|
||||
|
|
|
|||
0
tools/doc/builddoc.sh
Normal file → Executable file
0
tools/doc/builddoc.sh
Normal file → Executable file
Loading…
Reference in a new issue