onnxruntime/tools/ci_build/github/azure-pipelines
Wei-Sheng Chin faef9c32fa
ONNX-Native Tensor Parallel: Using Distributed MatMul as Example (#17695)
This PR introduces:
- A new data structure to represent kernel-level (aka node-level or
op-level) tensor sharding information. I consider it the
foundation of ONNX distributed inference.
- Building blocks for implementing distributed kernels, especially
stateless implementations of communication ops.
- Implementation of DistributedMatMul and its tests.

Code structure:
- sharding.h/.cc: Functions to shard and reshard tensors (calling into
NCCL).
- sharding_spec.h/.cc: Representation of how a tensor is sharded.
- distributed_matmul.h/.cc: Implementation of tensor parallel MatMul.
Inputs and outputs are sharded across devices.
- onnxruntime_test_distributed.py: distributed operator tests.

Example of specifying sharding information:
```python
        @onnxscript.script()
        def matmul_rs_sr_rr(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT:
            # Run MatMul by sharding x along column axis and w along row axis on
            # 2 GPUs.
            return MICROSOFT_OPSET.DistributedMatMul(
                tensor_x,
                tensor_w,
                device_mesh_shape=[2],
                device_mesh_elements=[0, 1],
                input_shard_specs=["RS[0]", "S[0]R"],
                output_shard_specs=["RR"],
            )
        onnx_model = matmul_rs_sr_rr.to_model_proto(
            input_types=[FLOAT[2, "s"], FLOAT["s", 2]],
            output_types=[FLOAT[2, 2]],
        )
```

In this example, the device mesh can be visualized as a 1-D tensor, `[0,
1]`. The 2nd axis of `tensor_x` is sharded across `[0, 1]` (i.e., the
0-axis of the device mesh). Similarly, the 1st axis of `tensor_w` is
sharded across `[0, 1]` as well. The output spec `RR` means the result
is replicated on both devices.
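
To make the arithmetic behind `RS[0]` x `S[0]R` -> `RR` concrete, here is a minimal pure-Python sketch that simulates the two devices with plain lists. It does not use ONNX Runtime, GPUs, or NCCL. All function names (`shard_columns`, `shard_rows`, `all_reduce_sum`, `simulate_matmul_rs_sr_rr`) are hypothetical, and the element-wise summation is only a stand-in for whatever collective the real kernel uses to combine partial results.

```python
# Simulation only: each "device" is just an entry in a Python list.

def matmul(a, b):
    """Plain dense matrix product of two matrices stored as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]

def shard_columns(x, num_devices):
    """RS[0]: keep every row, split the columns evenly across devices."""
    step = len(x[0]) // num_devices
    return [[row[d * step:(d + 1) * step] for row in x]
            for d in range(num_devices)]

def shard_rows(w, num_devices):
    """S[0]R: split the rows evenly across devices, keep every column."""
    step = len(w) // num_devices
    return [w[d * step:(d + 1) * step] for d in range(num_devices)]

def all_reduce_sum(partials):
    """Element-wise sum of per-device partial results (assumed collective)."""
    rows, cols = len(partials[0]), len(partials[0][0])
    return [[sum(p[i][j] for p in partials) for j in range(cols)]
            for i in range(rows)]

def simulate_matmul_rs_sr_rr(x, w, num_devices=2):
    """Each device multiplies its local shards; the sum gives the RR output."""
    x_shards = shard_columns(x, num_devices)
    w_shards = shard_rows(w, num_devices)
    partials = [matmul(xs, ws) for xs, ws in zip(x_shards, w_shards)]
    return all_reduce_sum(partials)

x = [[1, 2, 3, 4], [5, 6, 7, 8]]        # shape [2, 4], i.e. FLOAT[2, "s"]
w = [[1, 0], [0, 1], [1, 1], [2, 0]]    # shape [4, 2], i.e. FLOAT["s", 2]
result = simulate_matmul_rs_sr_rr(x, w)
assert result == matmul(x, w)           # sharded result equals the dense one
```

Because the contraction axis `s` is the one being split, each device only produces a partial sum of the full `[2, 2]` output, which is why a reduction step is needed before the result can be treated as replicated.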

C++ classes to represent tensor sharding (copied from sharding_spec.h):
```cpp
class DeviceMesh {
 public:
  // [Device Mesh and Tensor Sharding for Tensor Parallel]
  // Device mesh is a tensor of device indices.
  // A tensor can then be partitioned along specific mesh axes.
  //
  // Assume we have 4 GPUs indexed by 0, 1, 2, and 3.
  // Let's consider some examples.
  //  1. 1D device mesh [0, 1, 2, 3]. In this case,
  //     device_mesh_shape is [4] and device_mesh_elements
  //     is [0, 1, 2, 3].
  //     If we want to shard a 2-D tensor along its axis 1, the
  //     corresponding sharding spec is a string "RS[0]".
  //  2. 2D device mesh [[0, 1], [2, 3]]. In this case,
  //     device_mesh_shape is [2, 2] and device_mesh_elements
  //     is [0, 1, 2, 3].
  //     If we want to shard a 2-D tensor's
  //     rows along mesh axis 1 and
  //     columns along mesh axis 0, the
  //     corresponding sharding spec is a string "S[1]S[0]".
  //     If that 2-D tensor's value is np.array([[5, 6], [7, 8]]),
  //     GPU 0/1/2/3 owns 5/7/6/8. Below is a visualization of the sharding
  //     process.
  //     - Start with a 2-D device mesh [[0, 1], [2, 3]] and
  //       a 2-D tensor [[5, 6], [7, 8]]
  //       - GPU: [[0, 1], [2, 3]], Tensor: [[5, 6], [7, 8]]
  //     - Split GPU mesh along axis 1 and tensor along
  //       axis 0 for "S[1]" in "S[1]S[0]"
  //       - GPU: [[0], [2]], Tensor: [[5, 6]]
  //         GPU: [[1], [3]], Tensor: [[7, 8]]
  //     - Split GPU mesh along axis 0 and tensor along
  //       axis 1 for "S[0]" in "S[1]S[0]"
  //       - GPU: [[0]], Tensor: [[5]]
  //       - GPU: [[2]], Tensor: [[6]]
  //       - GPU: [[1]], Tensor: [[7]]
  //       - GPU: [[3]], Tensor: [[8]]

  // Actual shape of device mesh represented by `device_mesh_elements`.
  std::vector<int64_t> device_mesh_shape;

  // Flattened device mesh.
  std::vector<int64_t> device_mesh_elements;
};

class AxisPartitionSpec {
  // [Device Mesh and Tensor Sharding for Tensor Parallel]
  // This class is the in-memory representation of
  //  1. if a tensor is sharded or not (aka replica), and
  //  2. which tensor axis is sharded by which device mesh axis.
  // Let's consider sharding 2-D tensor along column axis on
  // device mesh [0, 1] as an example.
  // The required sharding spec RS[0] can be represented by
  // - AxisPartitionSpec(Condition::Replica, -1)
  // - AxisPartitionSpec(Condition::Shard, 0)
 public:
  // Status of a tensor axis.
  // A tensor axis can be either sharded or replicated
  // along a device mesh axis.
  enum class Condition { Replica,
                         Shard };

  // This field tells if a tensor axis is sharded or not.
  Condition cond;

  // If a tensor axis is sharded, this field tells which device
  // mesh axis to distribute the shards along.
  // If a tensor axis is not sharded, this field is ignored.
  int device_mesh_axis;

  // Constructor used by the helpers below.
  AxisPartitionSpec(Condition cond_, int device_mesh_axis_)
      : cond(cond_), device_mesh_axis(device_mesh_axis_) {}

  // A helper to construct a replica spec for a tensor axis.
  static AxisPartitionSpec CreateReplica() {
    return AxisPartitionSpec(Condition::Replica, -1);
  }

  // A helper to construct a sharding spec for a tensor axis.
  // This tensor axis is sharded along `device_mesh_axis` in device mesh.
  static AxisPartitionSpec CreateShard(int device_mesh_axis) {
    return AxisPartitionSpec(Condition::Shard, device_mesh_axis);
  }
};

class TensorPartitionSpec {
  // [Device Mesh and Tensor Sharding for Tensor Parallel]
  // TensorPartitionSpec holds a collection of AxisPartitionSpec and an
  // associated DeviceMesh. It is responsible for determining how a tensor
  // should be partitioned across a device mesh.
  //
  // Example 1: RS[0]
  // In this scenario, `axis_specs` would contain two `AxisPartitionSpec` objects.
  // - The first object is a Replica, denoting that the first axis of the tensor is
  //   not sharded but is instead replicated.
  // - The second object is a Shard along the 0-th axis of the device mesh. It denotes
  //   that the second axis of the tensor is sharded along the first axis of the
  //   device mesh.
  //
  // Example 2: S[0]RR
  // In this scenario, `axis_specs` would contain three `AxisPartitionSpec` objects.
  // - The first object is a Shard along the 0-th axis of the device mesh, indicating
  //   that the first axis of the tensor is sharded along the first axis of the
  //   device mesh.
  // - The second and third objects are Replicas, indicating that the second and third
  //   axes of the tensor are not sharded but are instead replicated.
 public:
  // axis_specs[i]: AxisPartitionSpec for tensor axis i. For a 2-D tensor,
  //                axis_specs[0] is for row axis and axis_specs[1] is for
  //                column axis. axis_specs[i].device_mesh_axis = j means that
  //                tensor axis i is sharded along device mesh axis j.
  std::vector<AxisPartitionSpec> axis_specs;

  // device_mesh: DeviceMesh for sharding the associated tensor.
  // Read [Device Mesh and Tensor Sharding for Tensor Parallel] in DeviceMesh's comment.
  DeviceMesh device_mesh;
};
```
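
The spec strings and the ownership rule described in the comments above can be checked with a small pure-Python sketch. It is not the parser from sharding_spec.cc; `parse_spec` and `shard_2d` are hypothetical helpers, and the sketch assumes that `device_mesh_elements` is the row-major flattening of the mesh and that each tensor axis divides evenly by the mesh axis it is sharded along.

```python
import re
from itertools import product

def parse_spec(spec):
    """Parse e.g. "RS[0]" into [None, 0].

    None means the tensor axis is replicated; an int is the device mesh
    axis that the tensor axis is sharded along.
    """
    if not re.fullmatch(r"(?:R|S\[\d+\])+", spec):
        raise ValueError(f"bad sharding spec: {spec!r}")
    return [int(m) if m else None for m in re.findall(r"R|S\[(\d+)\]", spec)]

def _axis_slice(dim, mesh_axis, coord, mesh_shape):
    """Slice of one tensor axis owned by the device at mesh coordinate `coord`."""
    if mesh_axis is None:                      # replicated: whole axis
        return slice(0, dim)
    step = dim // mesh_shape[mesh_axis]        # assumes even divisibility
    start = coord[mesh_axis] * step
    return slice(start, start + step)

def shard_2d(tensor, spec, mesh_shape, mesh_elements):
    """Return {device_id: sub-matrix} for a 2-D tensor under `spec`."""
    row_axis, col_axis = parse_spec(spec)
    owned = {}
    # itertools.product varies the last index fastest, matching the assumed
    # row-major flattening of the mesh.
    coords = product(*(range(n) for n in mesh_shape))
    for flat_index, coord in enumerate(coords):
        rows = _axis_slice(len(tensor), row_axis, coord, mesh_shape)
        cols = _axis_slice(len(tensor[0]), col_axis, coord, mesh_shape)
        owned[mesh_elements[flat_index]] = [row[cols] for row in tensor[rows]]
    return owned

# The 2-D mesh example from the DeviceMesh comment.
shards = shard_2d([[5, 6], [7, 8]], "S[1]S[0]", [2, 2], [0, 1, 2, 3])
assert shards == {0: [[5]], 1: [[7]], 2: [[6]], 3: [[8]]}
```

With these assumptions the sketch reproduces the "GPU 0/1/2/3 owns 5/7/6/8" assignment from the `DeviceMesh` comment, and `parse_spec("S[0]RR")` yields one entry per tensor axis, mirroring `axis_specs` in `TensorPartitionSpec`.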
2023-10-05 14:22:25 -07:00
nodejs/templates
nuget/templates
templates [js/webgpu] support IO binding (#17480) 2023-09-29 11:24:42 -07:00
triggers
android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml [QNN EP] Update QNN SDK to version 2.14.1 (#17467) 2023-09-11 21:07:50 -07:00
android-x86_64-crosscompile-ci-pipeline.yml
binary-size-checks-pipeline.yml
build-perf-test-binaries-pipeline.yml
c-api-noopenmp-packaging-pipelines.yml Run Final_Jar_Testing_Linux_GPU in docker (#17533) 2023-09-15 08:35:55 -07:00
clean-build-docker-image-cache-pipeline.yml
linux-ci-pipeline.yml Update nodejs to 18.x (#17657) 2023-09-25 14:12:11 -07:00
linux-cpu-aten-pipeline.yml Update nodejs to 18.x (#17657) 2023-09-25 14:12:11 -07:00
linux-cpu-eager-pipeline.yml
linux-cpu-minimal-build-ci-pipeline.yml
linux-dnnl-ci-pipeline.yml
linux-gpu-ci-pipeline.yml
linux-gpu-tensorrt-ci-pipeline.yml
linux-gpu-tensorrt-daily-perf-pipeline.yml
linux-migraphx-ci-pipeline.yml
linux-multi-gpu-tensorrt-ci-pipeline.yml Update nodejs to 18.x (#17657) 2023-09-25 14:12:11 -07:00
linux-openvino-ci-pipeline.yml Update nodejs to 18.x (#17657) 2023-09-25 14:12:11 -07:00
linux-qnn-ci-pipeline.yml [QNN EP] Update QNN SDK to version 2.14.1 (#17467) 2023-09-11 21:07:50 -07:00
mac-ci-pipeline.yml
mac-coreml-ci-pipeline.yml
mac-ios-ci-pipeline.yml
mac-ios-packaging-pipeline.yml
mac-objc-static-analysis-ci-pipeline.yml
mac-react-native-ci-pipeline.yml
npm-packaging-pipeline.yml Update npm-packaging-pipeline.yml to always use artifacts from main branch (#17604) 2023-09-19 14:42:08 -07:00
orttraining-linux-ci-pipeline.yml Update nodejs to 18.x (#17657) 2023-09-25 14:12:11 -07:00
orttraining-linux-gpu-ci-pipeline.yml Update nodejs to 18.x (#17657) 2023-09-25 14:12:11 -07:00
orttraining-linux-gpu-ortmodule-distributed-test-ci-pipeline.yml ONNX-Native Tensor Parallel: Using Distributed MatMul as Example (#17695) 2023-10-05 14:22:25 -07:00
orttraining-linux-nightly-ortmodule-test-pipeline.yml update acpt image for the training ci nightly (#17521) 2023-09-12 22:32:20 -07:00
orttraining-mac-ci-pipeline.yml
orttraining-pai-ci-pipeline.yml [ROCm] fix CI (#17648) 2023-09-21 07:37:50 -07:00
orttraining-py-packaging-pipeline-cpu.yml
orttraining-py-packaging-pipeline-cuda.yml
orttraining-py-packaging-pipeline-rocm.yml [ROCm] Remove ROCm5.4.2, ROCm 5.5 and add ROCm5.7 to python package pipeline (#17668) 2023-09-25 10:35:28 +08:00
post-merge-jobs.yml Add test for iOS dynamic framework (#17790) 2023-10-05 11:18:51 -07:00
py-package-build-pipeline.yml
py-package-test-pipeline.yml Remove dnf update from docker build scripts (#17551) 2023-09-21 07:33:29 -07:00
py-packaging-pipeline.yml
qnn-ep-nuget-packaging-pipeline.yml [QNN EP] Update QNN SDK to version 2.14.1 (#17467) 2023-09-11 21:07:50 -07:00
web-ci-pipeline.yml
win-ci-fuzz-testing.yml Update nodejs to 18.x (#17657) 2023-09-25 14:12:11 -07:00
win-ci-pipeline.yml Delete all Prefast tasks (#17522) 2023-09-12 17:40:49 -07:00
win-gpu-ci-pipeline.yml Delete all Prefast tasks (#17522) 2023-09-12 17:40:49 -07:00
win-gpu-reduce-op-ci-pipeline.yml
win-gpu-tensorrt-ci-pipeline.yml
win-qnn-arm64-ci-pipeline.yml [QNN EP] Update QNN SDK to version 2.14.1 (#17467) 2023-09-11 21:07:50 -07:00
win-qnn-ci-pipeline.yml Improve Win QNNEP pipeline (#17586) 2023-09-19 07:36:17 +08:00