ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
Wei-Sheng Chin faef9c32fa
ONNX-Native Tensor Parallel: Using Distributed MatMul as Example (#17695)
This PR introduces
- A new data structure to represent kernel-level (aka node-level or
op-level) tensor sharding information. I consider it the
foundation of ONNX distributed inference.
- Building blocks for implementing distributed kernels, especially
stateless implementations of communication ops.
- Implementation of DistributedMatMul and its tests.

Code structure:
- sharding.h/.cc: Functions to shard and reshard tensors (calling into
NCCL).
- sharding_spec.h/.cc: Representation of how a tensor is sharded.
- distributed_matmul.h/.cc: Implementation of tensor parallel MatMul.
Inputs and outputs are sharded across devices.
- onnxruntime_test_distributed.py: distributed operator tests.

Example of specifying sharding information:
```python
@onnxscript.script()
def matmul_rs_sr_rr(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT:
    # Run MatMul by sharding x along column axis and w along row axis on
    # 2 GPUs.
    return MICROSOFT_OPSET.DistributedMatMul(
        tensor_x,
        tensor_w,
        device_mesh_shape=[2],
        device_mesh_elements=[0, 1],
        input_shard_specs=["RS[0]", "S[0]R"],
        output_shard_specs=["RR"],
    )

onnx_model = matmul_rs_sr_rr.to_model_proto(
    input_types=[FLOAT[2, "s"], FLOAT["s", 2]],
    output_types=[FLOAT[2, 2]],
)
```

In this example, the device mesh can be visualized as a 1-D tensor, `[0,
1]`. The 2nd axis of `tensor_x` is sharded across `[0, 1]` (i.e., axis 0
of the device mesh). Similarly, the 1st axis of `tensor_w` is sharded
across `[0, 1]`. The output spec `RR` means the result is replicated on
both devices.
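
To make the arithmetic behind `RS[0]` x `S[0]R` -> `RR` concrete, here is a minimal pure-Python sketch that simulates the two devices as list entries. It is not the ONNX Runtime kernel: the helper names (`shard_columns`, `shard_rows`, `all_reduce_sum`) are hypothetical, the element-wise sum only stands in for whatever collective the real implementation uses, and the shared dimension `s` is assumed to be evenly divisible by the number of devices.

```python
def matmul(a, b):
    """Plain dense matrix product of two row-major nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


def shard_columns(x, num_devices):
    """'RS[0]': keep all rows, give each device a contiguous block of columns."""
    width = len(x[0]) // num_devices
    return [[row[d * width:(d + 1) * width] for row in x]
            for d in range(num_devices)]


def shard_rows(w, num_devices):
    """'S[0]R': give each device a contiguous block of rows, keep all columns."""
    height = len(w) // num_devices
    return [w[d * height:(d + 1) * height] for d in range(num_devices)]


def all_reduce_sum(partials):
    """Element-wise sum of per-device partial results (simulated all-reduce)."""
    rows, cols = len(partials[0]), len(partials[0][0])
    return [[sum(p[i][j] for p in partials) for j in range(cols)]
            for i in range(rows)]


def distributed_matmul_rs_sr_rr(x, w, num_devices=2):
    """Each simulated device multiplies its local shards; the sum is 'RR'."""
    x_shards = shard_columns(x, num_devices)
    w_shards = shard_rows(w, num_devices)
    partials = [matmul(xs, ws) for xs, ws in zip(x_shards, w_shards)]
    return all_reduce_sum(partials)


x = [[1.0, 2.0, 3.0, 4.0],
     [5.0, 6.0, 7.0, 8.0]]   # shape [2, s] with s = 4
w = [[1.0, 0.0],
     [0.0, 1.0],
     [2.0, 0.0],
     [0.0, 2.0]]             # shape [s, 2]

result = distributed_matmul_rs_sr_rr(x, w)
```

Because the contraction dimension is the one being split, each device produces a full-size `[2, 2]` partial product, and summing the partials reproduces the unsharded `matmul(x, w)`.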

C++ classes to represent tensor sharding (copied from sharding_spec.h):
```cpp
class DeviceMesh {
 public:
  // [Device Mesh and Tensor Sharding for Tensor Parallel]
  // Device mesh is a tensor of device indices.
  // A tensor can then be partitioned along specific mesh axes.
  //
  // Assume we have 4 GPUs indexed by 0, 1, 2, and 3.
  // Let's consider some examples.
  //  1. 1D device mesh [0, 1, 2, 3]. In this case,
  //     device_mesh_shape is [4] and device_mesh_elements
  //     is [0, 1, 2, 3].
  //     If we want to shard a 2-D tensor along its axis 1, the
  //     corresponding sharding spec is a string "RS[0]".
  //  2. 2D device mesh [[0, 1], [2, 3]]. In this case,
  //     device_mesh_shape is [2, 2] and device_mesh_elements
  //     is [0, 1, 2, 3].
  //     If we want to shard a 2-D tensor's
  //     rows along mesh axis 1 and
  //     columns along mesh axis 0, the
  //     corresponding sharding spec is a string "S[1]S[0]".
  //     If that 2-D tensor's value is np.array([[5, 6], [7, 8]]),
  //     GPU 0/1/2/3 owns 5/7/6/8.  Below is a visualization of the sharding
  //     process.
  //     - Start with a 2-D device mesh [[0, 1], [2, 3]] and
  //       a 2-D tensor [[5, 6], [7, 8]]
  //       - GPU: [[0, 1], [2, 3]], Tensor: [[5, 6], [7, 8]]
  //     - Split GPU mesh along axis 1 and tensor along
  //       axis 0 for "S[1]" in "S[1]S[0]"
  //       - GPU: [[0], [2]], Tensor: [[5, 6]]
  //         GPU: [[1], [3]], Tensor: [[7, 8]]
  //     - Split GPU mesh along axis 0 and tensor along
  //       axis 1 for "S[0]" in "S[1]S[0]"
  //       - GPU: [[0]], Tensor: [[5]]
  //       - GPU: [[2]], Tensor: [[6]]
  //       - GPU: [[1]], Tensor: [[7]]
  //       - GPU: [[3]], Tensor: [[8]]

  // Actual shape of device mesh represented by `device_mesh_elements`.
  std::vector<int64_t> device_mesh_shape;

  // Flattened device mesh.
  std::vector<int64_t> device_mesh_elements;
};

class AxisPartitionSpec {
  // [Device Mesh and Tensor Sharding for Tensor Parallel]
  // This class is the in-memory representation of
  //  1. if a tensor is sharded or not (aka replica), and
  //  2. which tensor axis is sharded by which device mesh axis.
  // Let's consider sharding 2-D tensor along column axis on
  // device mesh [0, 1] as an example.
  // The required sharding spec RS[0] can be represented by
  // - AxisPartitionSpec(Condition::Replica, -1)
  // - AxisPartitionSpec(Condition::Shard, 0)
 public:
  // Status of a tensor axis.
  // A tensor axis can be either sharded or replicated
  // along a device mesh axis.
  enum class Condition { Replica,
                         Shard };

  // This field tells if a tensor axis is sharded or not.
  Condition cond;

  // If a tensor axis is sharded, this field tells which device
  // mesh axis to distribute the shards along.
  // If a tensor axis is not sharded, this field is ignored.
  int device_mesh_axis;

  // A helper to construct a replica spec for a tensor axis.
  static AxisPartitionSpec CreateReplica() {
    return AxisPartitionSpec(Condition::Replica, -1);
  }

  // A helper to construct a sharding spec for a tensor axis.
  // This tensor axis is sharded along `device_mesh_axis` in device mesh.
  static AxisPartitionSpec CreateShard(int device_mesh_axis) {
    return AxisPartitionSpec(Condition::Shard, device_mesh_axis);
  }
};

class TensorPartitionSpec {
  // [Device Mesh and Tensor Sharding for Tensor Parallel]
  // TensorPartitionSpec holds a collection of AxisPartitionSpec and an
  // associated DeviceMesh. It is responsible for determining how a tensor
  // should be partitioned across a device mesh.
  //
  // Example 1: RS[0]
  // In this scenario, `axis_specs` would contain two `AxisPartitionSpec` objects.
  // - The first object is a Replica, denoting that the first axis of the tensor is
  //   not sharded but is instead replicated.
  // - The second object is a Shard along the 0-th axis of the device mesh. It denotes
  //   that the second axis of the tensor is sharded along the first axis of the
  //   device mesh.
  //
  // Example 2: S[0]RR
  // In this scenario, `axis_specs` would contain three `AxisPartitionSpec` objects.
  // - The first object is a Shard along the 0-th axis of the device mesh, indicating
  //   that the first axis of the tensor is sharded along the first axis of the
  //   device mesh.
  // - The second and third objects are Replicas, indicating that the second and third
  //   axes of the tensor are not sharded but are instead replicated.
 public:
  // axis_specs[i]: AxisPartitionSpec for tensor axis i. For a 2-D tensor,
  //                axis_specs[0] is for row axis and axis_specs[1] is for
  //                column axis. axis_specs[i].device_mesh_axis = j means that
  //                tensor axis i is sharded along device mesh axis j.
  std::vector<AxisPartitionSpec> axis_specs;

  // device_mesh: DeviceMesh for sharding the associated tensor.
  // Read [Device Mesh and Tensor Sharding for Tensor Parallel] in DeviceMesh's comment.
  DeviceMesh device_mesh;
};
```
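
The spec strings and the `S[1]S[0]` walkthrough in the comments above can be checked with a small Python sketch. It is only an illustration under stated assumptions: `parse_spec` and `shard_for_device` are hypothetical helpers (not functions from sharding_spec.h), the mesh is assumed to be flattened in row-major order as in the `[[0, 1], [2, 3]]` example, tensors are 2-D nested lists, and every sharded axis is assumed to divide evenly.

```python
import re


def parse_spec(spec):
    """Parse a spec such as 'RS[0]' into one entry per tensor axis.

    An entry is None for a replicated axis ('R') or the device mesh axis
    index for a sharded axis ('S[j]').
    """
    if not re.fullmatch(r"(?:R|S\[\d+\])+", spec):
        raise ValueError(f"malformed sharding spec: {spec!r}")
    return [None if m.group(1) is None else int(m.group(1))
            for m in re.finditer(r"R|S\[(\d+)\]", spec)]


def shard_for_device(tensor, spec, mesh_shape, mesh_elements, device_id):
    """Return the block of a 2-D tensor that device_id owns under spec."""
    axes = parse_spec(spec)
    if len(axes) != 2:
        raise ValueError("this sketch only handles 2-D tensors")

    # Recover the device's mesh coordinate from its row-major flat position.
    flat = mesh_elements.index(device_id)
    coord = []
    for size in reversed(mesh_shape):
        coord.append(flat % size)
        flat //= size
    coord.reverse()

    # For each tensor axis, compute the [start, stop) range this device keeps.
    dims = [len(tensor), len(tensor[0])]
    ranges = []
    for dim, mesh_axis in zip(dims, axes):
        if mesh_axis is None:
            ranges.append((0, dim))  # replicated: keep the whole axis
        else:
            chunk = dim // mesh_shape[mesh_axis]
            start = coord[mesh_axis] * chunk
            ranges.append((start, start + chunk))

    (r0, r1), (c0, c1) = ranges
    return [row[c0:c1] for row in tensor[r0:r1]]


tensor = [[5, 6], [7, 8]]
owners = [shard_for_device(tensor, "S[1]S[0]", [2, 2], [0, 1, 2, 3], d)
          for d in range(4)]
```

Here `owners` holds one block per GPU in the order 0, 1, 2, 3, which lets the "GPU 0/1/2/3 owns 5/7/6/8" statement be verified directly.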
2023-10-05 14:22:25 -07:00
.config
.devcontainer
.gdn Update win-ci-pipeline.yml: enable xnnpack tests (#16244) 2023-06-14 19:12:42 -07:00
.github Bump actions/checkout from 3 to 4 (#17487) 2023-09-13 09:22:21 -07:00
.pipelines Bump DirectML version from 1.12.0 to 1.12.1 (#17225) 2023-08-20 09:55:38 -07:00
.vscode Close the JSON object in settings.json (#17583) 2023-09-26 09:51:13 -07:00
cgmanifests ONNX 1.15 integration (#17125) 2023-09-26 14:44:48 -07:00
cmake ONNX-Native Tensor Parallel: Using Distributed MatMul as Example (#17695) 2023-10-05 14:22:25 -07:00
csharp [On-Device Training] Expose Parameters through the Training API (#17364) 2023-09-25 20:03:24 -07:00
dockerfiles Update cmake to 3.27 and upgrade Linux CUDA docker files from CentOS7 to UBI8 (#16856) 2023-09-05 18:12:10 -07:00
docs ONNX 1.15 integration (#17125) 2023-09-26 14:44:48 -07:00
include/onnxruntime/core [QNN EP] Enable QNN Saver for debugging issues (#17747) 2023-10-03 16:24:33 -07:00
java [java] Filling out the javadoc for the float8 types (#17694) 2023-09-27 10:52:11 -07:00
js [js/webgpu] Enable the NCHW ConvMatMul path (#17717) 2023-10-05 00:26:01 -07:00
objectivec Objective-C Add Support to Create and Query String ORTValues (#16764) 2023-07-20 17:39:29 -07:00
onnxruntime ONNX-Native Tensor Parallel: Using Distributed MatMul as Example (#17695) 2023-10-05 14:22:25 -07:00
orttraining Python API to check whether collective ops are available or not (#17730) 2023-09-29 14:11:05 -07:00
rust rust bindings: Do not unnecessarily re-run build.rs (#17018) 2023-09-05 19:42:06 -07:00
samples
swift/OnnxRuntimeBindingsTests Add iOS Swift Package Manager support (#15297) 2023-04-20 16:18:35 +10:00
tools ONNX-Native Tensor Parallel: Using Distributed MatMul as Example (#17695) 2023-10-05 14:22:25 -07:00
winml Add support for specifying a custom logging function per session. (#17727) 2023-09-29 19:46:55 -07:00
.clang-format Prevent GSL_SUPPRESS arguments from being modified by clang-format (#17242) 2023-08-22 18:26:53 -07:00
.clang-tidy
.dockerignore
.gitattributes
.gitignore remove 'lib/' from .gitignore (#15613) 2023-04-24 18:43:32 -07:00
.gitmodules Remove onnxruntime extensions from list of gitmodules (#17615) 2023-09-19 17:12:14 -07:00
.lintrunner.toml Format c++ code under winml/ (#16660) 2023-07-25 21:56:50 -07:00
build.bat try to find patch.exe in git default installation folder (#17106) 2023-08-10 21:48:13 -07:00
build.sh Upgrade old Python version in packaging pipeline (#16667) 2023-07-17 08:24:47 -07:00
CITATION.cff
CODEOWNERS
CONTRIBUTING.md
lgtm.yml
LICENSE
NuGet.config
ort.wprp
ORT_icon_for_light_bg.png
Package.swift Objective-C Add Support to Create and Query String ORTValues (#16764) 2023-07-20 17:39:29 -07:00
packages.config Bump DirectML version from 1.12.0 to 1.12.1 (#17225) 2023-08-20 09:55:38 -07:00
pyproject.toml Updating QDQ to support Float8E4M3FN (#16550) 2023-08-08 12:18:48 +02:00
README.md add third-party pipeline status to README.md (#16155) 2023-05-31 22:14:39 -07:00
requirements-dev.txt ONNX 1.15 integration (#17125) 2023-09-26 14:44:48 -07:00
requirements-doc.txt
requirements-lintrunner.txt Bump clang-format to 16.0.6 in CI (#17099) 2023-08-10 13:53:04 -07:00
requirements-training.txt ONNX 1.15 integration (#17125) 2023-09-26 14:44:48 -07:00
requirements.txt.in
SECURITY.md
setup.py Update tensorrt_dependencies in setup.py (#17562) 2023-09-15 08:20:47 -07:00
ThirdPartyNotices.txt Flash Attention v2 MHA (#17227) 2023-08-31 13:52:21 -07:00
VERSION_NUMBER Bump Up Version to 1.17.0 (#17587) 2023-09-20 11:02:58 +08:00

ONNX Runtime is a cross-platform inference and training machine-learning accelerator.

ONNX Runtime inference can enable faster customer experiences and lower costs, supporting models from deep learning frameworks such as PyTorch and TensorFlow/Keras as well as classical machine learning libraries such as scikit-learn, LightGBM, and XGBoost. ONNX Runtime is compatible with different hardware, drivers, and operating systems, and provides optimal performance by leveraging hardware accelerators where applicable alongside graph optimizations and transforms.

ONNX Runtime training can accelerate model training on multi-node NVIDIA GPUs for transformer models with a one-line addition to existing PyTorch training scripts.

Get Started & Resources

Builtin Pipeline Status

[Table of build status badges for the Inference and Training pipelines on Windows, Linux, Mac, Android, iOS, Web, and Other; badge images not preserved.]

Third-party Pipeline Status

[Table of build status badges for the Inference and Training pipelines on Linux; badge image not preserved.]

Data/Telemetry

Windows distributions of this project may collect usage data and send it to Microsoft to help improve our products and services. See the privacy statement for more details.

Contributions and Feedback

We welcome contributions! Please see the contribution guidelines.

For feature requests or bug reports, please file a GitHub Issue.

For general discussion or questions, please use GitHub Discussions.

Code of Conduct

This project has adopted the Microsoft Open Source Code of Conduct. For more information see the Code of Conduct FAQ or contact opencode@microsoft.com with any additional questions or comments.

License

This project is licensed under the MIT License.