ONNX Runtime: cross-platform, high-performance ML inferencing and training accelerator
pengwa 4932e04053
ORTModule GraphTransitionManager (#19007)
### Problem

Currently, the codebase contains logic for model re-export checks and graph_builder reinitialization checks. Ideally, these operations should behave like a state machine. However, the implementation checks or sets certain states in various scattered locations, which makes it hard to tell when a re-export or re-initialization will be triggered. For clarity and maintainability, these states should be consolidated into a single cohesive component rather than dispersed across the current graph execution manager.

Furthermore, exporting the model and post-processing the export for stage 3 support or memory-efficient gradient management introduce considerable complexity. To improve the structure of the codebase, these functionalities should be extracted into a dedicated component, decoupled from the current graph execution manager.

As part of this effort, it is also essential to address inconsistencies in how inputs and outputs are flattened and unflattened. Currently, several functions perform these operations recursively, each with a slightly different implementation, so different parts of the code support different input/output data types and structures. To fix this, this pull request reduces these operations to a set of primitive functions, ensuring uniform behavior. This not only streamlines the code but also makes it easier to keep behavior consistent when fixing bugs or adding support for new data types. Note that input/output handling is tightly coupled to the graph transition logic described above, so it is difficult to make this change separately.

Although this logic is complex, the codebase has an extensive suite of unit tests that cover all possible branches. Getting all of them to pass was time-intensive, but a necessary part of this work.

### Design


Introduce `GraphTransitionManager` and move all model export and post-export processing logic into it:
1. Re-export check
2. Export
3. Re-post-export-process check
4. Post-export processing
5. Return `PostExportProcessedModelInfo`, which contains all the information ORT needs to build the gradient graph. (Currently this is done for both training and evaluation. Ideally it should be skipped for evaluation, but this PR keeps the current behavior and defers that change.)
```python
# Attributes initialized in PostExportProcessedModelInfo.__init__ (excerpt).

# Input names for the pre-gradient-build graph.
# This may differ from the one in ExportedGraph, since we may modify the graph inputs as needed,
# for example when memory-efficient gradient management is enabled.
self.onnx_graph_input_names: list[str] = onnx_graph_input_names

# A subset of onnx_graph_input_names.
# Input names that require gradients for the pre-gradient-build graph.
self.onnx_graph_input_names_require_grad: list[str] = onnx_graph_input_names_require_grad

# Create symbolic names for each dimension of the graph inputs (e.g. onnx_graph_input_names).
# The key is the input name; the value is a dict of {dim_index: symbolic_dim_name},
# e.g. {"input1": {0: "input1_dim0", 1: "input1_dim1"}, "input2": {0: "input2_dim0"}}
self.onnx_graph_input_dynamic_axes_map: dict[str, dict[int, str]] = onnx_graph_input_dynamic_axes_map

self.buffer_for_ort_runs: dict[str, torch.Tensor] = OrderedDict()

# The ONNX graph input names excluding the parameters and buffers.
self.onnx_graph_input_names_user_defined = onnx_graph_input_names_user_defined

# The ONNX graph input names that require gradients, excluding the parameters and buffers.
self.onnx_graph_input_names_require_grad_user_defined = onnx_graph_input_names_require_grad_user_defined

self._post_export_processed_model: onnx.ModelProto | None = post_export_processed_model

# Functions to access the input data from args and kwargs.
# If not None, its length equals the length of onnx_graph_input_names:
# the i-th function gets the data for the i-th input name from args and kwargs.
self.data_accessor: list[callable] | None = data_accessor

# Used for unflattening the outputs from the ORT forward run.
self.module_forward_output_schema: ORTModelInputOutputSchemaType | None = module_forward_output_schema
```
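To make the shape of `onnx_graph_input_dynamic_axes_map` in the attribute listing concrete, the hypothetical helper below builds it from each input's rank, naming dimension `i` of input `name` as `f"{name}_dim{i}"`, which reproduces the example in that comment. This nested `{input_name: {dim_index: name}}` layout is the same one `torch.onnx.export` accepts for its `dynamic_axes` argument; whether ORTModule builds the map exactly this way is an assumption.

```python
from __future__ import annotations


def build_dynamic_axes_map(input_ranks: dict[str, int]) -> dict[str, dict[int, str]]:
    """Give every dimension of every graph input a symbolic name (hypothetical helper).

    input_ranks maps each graph input name to its number of dimensions.
    """
    return {
        name: {dim: f"{name}_dim{dim}" for dim in range(rank)}
        for name, rank in input_ranks.items()
    }


example = build_dynamic_axes_map({"input1": 2, "input2": 1})
# example == {"input1": {0: "input1_dim0", 1: "input1_dim1"}, "input2": {0: "input2_dim0"}}
```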




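The five design steps listed above can be sketched as a small caching pipeline. This is a minimal illustration, not the ORTModule implementation: `TransitionManagerSketch`, `ProcessedModelInfo`, `export_fn`, `post_process_fn` and the two cache keys are hypothetical. It assumes `export_key` captures whatever determines the exported graph (for example the input schema) and `post_process_key` captures whatever only affects post-export processing; which real conditions fall into each bucket is an assumption here.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable, Hashable


@dataclass
class ProcessedModelInfo:
    """Hypothetical stand-in for PostExportProcessedModelInfo."""

    model: Any
    input_names: list[str]


class TransitionManagerSketch:
    """Keeps every re-export / re-post-process decision in one place (hypothetical)."""

    def __init__(
        self,
        export_fn: Callable[[], Any],
        post_process_fn: Callable[[Any], ProcessedModelInfo],
    ) -> None:
        self._export_fn = export_fn
        self._post_process_fn = post_process_fn
        self._exported: Any = None
        self._export_key: Hashable | None = None
        self._post_process_key: Hashable | None = None
        self.info: ProcessedModelInfo | None = None

    def use_cache_or_reconstruct(
        self, export_key: Hashable, post_process_key: Hashable
    ) -> tuple[bool, ProcessedModelInfo]:
        """Return (rebuilt, info); rebuilt is True when the processed model changed."""
        # Step 1: re-export check.
        need_export = self._exported is None or export_key != self._export_key
        if need_export:
            # Step 2: export.
            self._exported = self._export_fn()
            self._export_key = export_key
        # Step 3: re-post-export-process check; a fresh export always invalidates it.
        need_post = need_export or self.info is None or post_process_key != self._post_process_key
        if need_post:
            # Step 4: post-export processing.
            self.info = self._post_process_fn(self._exported)
            self._post_process_key = post_process_key
        # Step 5: return the cached or rebuilt info.
        return need_post, self.info
```

Keeping both checks in one method means the transition rules can be read in a single place instead of being inferred from scattered flags.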
The `GraphTransitionManager` instance is an attribute of `GraphExecutionManager` (e.g. `TrainingManager` or `InferenceManager`), which uses it as follows:
1. Call `self._graph_transition_manager.use_cache_or_reconstruct_post_processed_model(inputs, kwargs)` to check whether the PyTorch module needs a re-export or a re-run of post-export processing.
2. Use `self._graph_transition_manager._post_export_processed_model_info.construct_inputs` to construct the list of inputs used for ORT runs.
3. Use `self._graph_transition_manager._post_export_processed_model_info.restore_outputs(user_outputs)` to restore the outputs to the original PyTorch output structure.
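The `data_accessor` list in the attribute listing is what lets input construction pull each ORT input out of the original `(args, kwargs)`. A minimal sketch of that idea, assuming each accessor is a plain callable taking `(args, kwargs)`; `from_arg`, `from_kwarg`, `from_item` and this `construct_inputs` signature are hypothetical, not the ORTModule API.

```python
from __future__ import annotations

from typing import Any, Callable, Mapping, Sequence

# An accessor pulls one graph input out of the module call's (args, kwargs).
Accessor = Callable[[Sequence[Any], Mapping[str, Any]], Any]


def from_arg(index: int) -> Accessor:
    """Accessor for the index-th positional argument."""
    return lambda args, kwargs: args[index]


def from_kwarg(name: str) -> Accessor:
    """Accessor for a keyword argument."""
    return lambda args, kwargs: kwargs[name]


def from_item(outer: Accessor, key: Any) -> Accessor:
    """Accessor for an element nested inside another input (list index or dict key)."""
    return lambda args, kwargs: outer(args, kwargs)[key]


def construct_inputs(
    input_names: Sequence[str],
    accessors: Sequence[Accessor],
    args: Sequence[Any],
    kwargs: Mapping[str, Any],
) -> list[Any]:
    """Build the ORT input list: the i-th accessor supplies the i-th input name."""
    if len(accessors) != len(input_names):
        raise ValueError("expected one accessor per graph input name")
    return [get(args, kwargs) for get in accessors]
```

Because the accessors are resolved once, building inputs for each run in this sketch is a flat list comprehension rather than a recursive walk over the call arguments.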



### Motivation and Context
2024-07-03 10:53:31 +08:00
.config
.devcontainer
.gdn
.github CoreML: Disable 1D ML Program matmul due to bug in coreml (#21186) 2024-06-29 12:19:51 -07:00
.pipelines Fix onebranch exception in code signing (#21088) 2024-06-19 12:07:17 +08:00
.vscode disable gemm f16 on CPU (#19744) 2024-03-01 13:44:29 -08:00
cgmanifests [ROCm] Update ck to use ck_tile (#21030) 2024-06-19 14:06:10 +08:00
cmake onnxruntime shared lib inside python package (#21223) 2024-07-02 15:37:50 -07:00
csharp Update to the net8 MAUI targets. Remove Xamarin. (#21062) 2024-06-19 16:20:58 +10:00
dockerfiles Update Dockerfile.cuda (#21042) 2024-06-13 23:50:03 -07:00
docs Rename a mispelled filename in the documentation (#21066) 2024-06-17 18:18:41 +02:00
include/onnxruntime/core Initial PR for VSINPU execution provider (#20903) 2024-06-28 21:48:34 -07:00
java Remove warning suppression from Java Packaging pipeline. (#21010) 2024-06-24 16:46:21 -07:00
js support for layernorm in webgpu pre opset-17 (#21121) 2024-06-27 10:20:48 -07:00
objectivec Fix Objective-C static analysis warnings. (#20417) 2024-04-24 11:48:29 -07:00
onnxruntime Add debugging helper to dump string, vector and thread id (#21224) 2024-07-02 11:24:04 -07:00
orttraining ORTModule GraphTransitionManager (#19007) 2024-07-03 10:53:31 +08:00
rust
samples
tools onnxruntime shared lib inside python package (#21223) 2024-07-02 15:37:50 -07:00
winml [DML EP] Add GroupQueryAttention (#20327) 2024-04-19 10:25:29 -07:00
.clang-format
.clang-tidy
.dockerignore
.gitattributes
.gitignore Build onnxruntime.dll as arm64x (#18633) 2023-12-06 16:49:00 -08:00
.gitmodules [js/web] optimize module export and deployment (#20165) 2024-05-20 09:51:16 -07:00
.lintrunner.toml Make Flash Attention work on Windows (#21015) 2024-06-24 09:43:49 -07:00
build.bat
build.sh
build_arm64x.bat remove unnecessary environment variable (#19166) 2024-01-16 16:24:37 -08:00
CITATION.cff Fix citation author name issue (#19597) 2024-02-22 17:03:56 -08:00
CODEOWNERS
CONTRIBUTING.md
lgtm.yml
LICENSE
NuGet.config
ort.wprp Fully dynamic ETW controlled logging for ORT and QNN logs (#20537) 2024-06-06 21:11:14 -07:00
ORT_icon_for_light_bg.png
packages.config Update DML to 1.14.1 (#20380) 2024-04-18 22:43:41 -07:00
pyproject.toml [CUDA] Add SparseAttention operator for Phi-3-small (#20216) 2024-04-30 09:06:29 -07:00
README.md Update README.md (#18963) 2024-01-03 17:26:25 -08:00
requirements-dev.txt
requirements-doc.txt
requirements-lintrunner.txt Bump ruff to 0.3.2 and black to 24 (#19878) 2024-03-13 10:00:32 -07:00
requirements-training.txt
requirements.txt Add compatibility for NumPy 2.0 (#21085) 2024-06-27 13:50:53 -07:00
SECURITY.md
setup.py onnxruntime shared lib inside python package (#21223) 2024-07-02 15:37:50 -07:00
ThirdPartyNotices.txt Fix HalideIR title in third party notices reference (#20190) 2024-04-05 11:12:43 -07:00
VERSION_NUMBER Bump up version in main from 1.18.0 to 1.19.0 (#20489) 2024-04-29 20:21:41 -07:00

ONNX Runtime is a cross-platform inference and training machine-learning accelerator.

ONNX Runtime inference can enable faster customer experiences and lower costs, supporting models from deep learning frameworks such as PyTorch and TensorFlow/Keras as well as classical machine learning libraries such as scikit-learn, LightGBM, and XGBoost. ONNX Runtime is compatible with different hardware, drivers, and operating systems, and provides optimal performance by leveraging hardware accelerators where applicable alongside graph optimizations and transforms.

ONNX Runtime training can reduce model training time on multi-node NVIDIA GPUs for transformer models with a one-line addition to existing PyTorch training scripts.

Get Started & Resources

Data/Telemetry

Windows distributions of this project may collect usage data and send it to Microsoft to help improve our products and services. See the privacy statement for more details.

Contributions and Feedback

We welcome contributions! Please see the contribution guidelines.

For feature requests or bug reports, please file a GitHub Issue.

For general discussion or questions, please use GitHub Discussions.

Code of Conduct

This project has adopted the Microsoft Open Source Code of Conduct. For more information see the Code of Conduct FAQ or contact opencode@microsoft.com with any additional questions or comments.

License

This project is licensed under the MIT License.