diff --git a/docs/python/training/conf.py b/docs/python/training/conf.py index 729e61f054..9e4b34297c 100644 --- a/docs/python/training/conf.py +++ b/docs/python/training/conf.py @@ -28,4 +28,8 @@ html_static_path = ['_static'] # -- Options for intersphinx extension --------------------------------------- -intersphinx_mapping = {'https://docs.python.org/': None} +intersphinx_mapping = { + 'python': ('https://docs.python.org/3', None), + 'numpy': ('https://numpy.org/doc/stable', None), + 'torch': ('https://pytorch.org/docs/stable/', None), +} diff --git a/docs/python/training/content.rst b/docs/python/training/content.rst index eff0714012..595a4259a2 100644 --- a/docs/python/training/content.rst +++ b/docs/python/training/content.rst @@ -47,8 +47,8 @@ in the simple case where the entire model can be offloaded to ONNX Runtime: API === -.. automodule:: onnxruntime.training.ortmodule +.. automodule:: onnxruntime.training.ortmodule.ortmodule :members: :show-inheritance: :member-order: bysource -.. :inherited-members: + :inherited-members: diff --git a/orttraining/orttraining/python/training/ortmodule/ortmodule.py b/orttraining/orttraining/python/training/ortmodule/ortmodule.py index d5afb07e56..62d1c7ee46 100644 --- a/orttraining/orttraining/python/training/ortmodule/ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule/ortmodule.py @@ -16,28 +16,10 @@ from typing import Iterator, Optional, Tuple, TypeVar T = TypeVar('T', bound='Module') class ORTModule(torch.nn.Module): - """Specializes a user torch.nn.Module to leverage ONNX Runtime graph execution. + """Extends user's :class:`torch.nn.Module` model to leverage ONNX Runtime super fast training engine. - ORTModule specializes the user's torch.nn.Module and provides forward, backward - implementations be leveraging ONNX Runtime. - - ORTModule interacts with: - - GraphExecutionManagerFactory: Which returns a GraphExecutionManager based on - whether or not the user's torch module is in training mode or eval mode. - - GraphExecutionManager: Responsible for building and executing the forward and backward graphs. - - InferenceManager(GraphExecutionManager): Responsible for building, optimizing - and executing the inference onnx graph. - - TrainingManager(GraphExecutionManager): Responsible for building, optimizing - and executing the training onnx graph. - - The GraphExecutionManager first exports the user model into an onnx model. - Following that, GraphExecutionManager interacts with OrtModuleGraphBuilder to optimize the onnx graph. - Once the onnx graph has been optimized, an ExecutionAgent is instantiated that - facilitates in executing the forward and backward subgraphs of the onnx model. - - - _ortmodule_io: Provides utilities to transform the user inputs and outputs of the model. - - It facilitates in flattening the output from the user's PyTorch model (since exporting - of nested structures is not supported at the moment) + ORTModule specializes the user's :class:`torch.nn.Module` model, providing :meth:`~torch.nn.Module.forward`, + :meth:`~torch.nn.Module.backward` along with all others :class:`torch.nn.Module`'s APIs. """ def __init__(self, module): @@ -76,6 +58,13 @@ class ORTModule(torch.nn.Module): self._execution_manager = GraphExecutionManagerFactory(self._flattened_module) + # IMPORTANT: DO NOT add code here + # This declaration is for automatic document generation purposes only + # The actual forward implementation is bound during ORTModule initialization + def forward(self, *inputs, **kwargs): + '''Dummy documentation for forward method''' + ... + def _is_training(self): return self._flattened_module.training and torch.is_grad_enabled() diff --git a/tools/doc/builddoc.sh b/tools/doc/builddoc.sh old mode 100644 new mode 100755