onnxruntime/tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-ortmodule-distributed-test-ci-pipeline.yml
Wei-Sheng Chin faef9c32fa
ONNX-Native Tensor Parallel: Using Distributed MatMul as Example (#17695)
This PR introduces
- New data structure to represent kernel-level (aka node-level or
op-level) tensor sharding informaiton. I consider it as the
fundamentaion of ONNX distribtued inference.
- Building blocks for distribtued kernels implementation especially
stateless implementation for communication ops.
- Implementation of DistributedMatMul and its tests.

Code structure:
- sharding.h/.cc: Function to shard and reshard tensors (calling into
NCCL).
- sharding_spec.h/.cc: Representation of how a tensor is sharded.
- distributed_matmul.h/.cc: Implementation of tensor parallel MatMul.
Inputs and outputs are sharded across devices.
- onnxruntime_test_distributed.py: distributed operator tests.

Example of specifying sharding information
```python
        @onnxscript.script()
        def matmul_rs_sr_rr(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT:
            # Run MatMul by sharding x along column axis and w along row axis on
            # 2 GPUs.
            return MICROSOFT_OPSET.DistributedMatMul(
                tensor_x,
                tensor_w,
                device_mesh_shape=[2],
                device_mesh_elements=[0, 1],
                input_shard_specs=["RS[0]", "S[0]R"],
                output_shard_specs=["RR"],
            )
        onnx_model = matmul_rs_sr_rr.to_model_proto(
            input_types=[FLOAT[2, "s"], FLOAT["s", 2]],
            output_types=[FLOAT[2, 2]],
        )
```

In this example, the device mesh can be visualized as 1-D tensor, `[0,
1]`. The 2nd axis of `tensor_x` is sharded across `[0, 1]` (i.e., the
0-axis of the device mesh). Similarly, the 1st axis of `tensor_w` is
sharded across `[0, 1]` as well.

C++ classes to represent tensor sharding (copied from sharding_spec.h):
```cpp
class DeviceMesh {
 public:
  // [Device Mesh and Tensor Sharding for Tensor Parallel]
  // Device mesh is a tensor of device indices.
  // A tensor can then be partitioned along specific mesh axes.
  //
  // Assume we have 4 GPUs indexed by 0, 1, 2, and 3.
  // Let's consider some examples.
  //  1. 1D device mesh [0, 1, 2, 3]. In this case,
  //     device_mesh_shape is [4] and device_mesh_elements
  //     is [0, 1, 2, 3].
  //     If we want to shard a 2-D tensor along its axis 1, the
  //     corresponding sharding spec is a string "RS[0]".
  //  2. 2D device mesh [[0, 1], [2, 3]]. In this case,
  //     device_mesh_shape is [2, 2] and device_mesh_elements
  //     is [0, 1, 2, 3].
  //     If we want to shard a 2-D tensor's
  //     rows along mesh axis 1 and
  //     columns along mesh axis 0, the
  //     corresponding sharding spec is a string "S[1]S[0]".
  //     If that 2-D tensor's value is np.array([[5, 6], [7, 8]]),
  //     GPU 0/1/2/3 owns 5/7/6/8.  Below is a visualization the sharding
  //     proccess.
  //     - Start with a 2-D device mesh [[0, 1], [2, 3]] and
  //       a 2-D tensor [[5, 6], [7, 8]]
  //       - GPU: [[0, 1], [2, 3]], Tensor: [[5, 6], [7, 8]]
  //     - Split GPU mesh along axis 1 and tensor along
  //       axis 0 for "S[1]" in "S[1]S[0]"
  //       - GPU: [[0], [2]], Tensor: [[5, 6]]
  //         GPU: [[1], [3]], Tensor: [[7, 8]]
  //     - Split GPU mesh along axis 0 and tensor along
  //       axis 1 for "S[0]" in "S[1]S[0]"
  //       - GPU: [[0]], Tensor: [[5]]
  //       - GPU: [[2]], Tensor: [[6]]
  //       - GPU: [[1]], Tensor: [[7]]
  //       - GPU: [[3]], Tensor: [[8]]

  // Actual shape of device mesh represented by `device_mesh_elements`.
  std::vector<int64_t> device_mesh_shape;

  // Flattened device mesh.
  std::vector<int64_t> device_mesh_elements;
};

class AxisPartitionSpec {
  // [Device Mesh and Tensor Sharding for Tensor Parallel]
  // This class is the in-memory representation of
  //  1. if a tensor is sharded or not (aka replica), and
  //  2. which tensor axis is shard by which device mesh axis.
  // Let's consider sharding 2-D tensor along column axis on
  // device mesh [0, 1] as an example.
  // The required sharding spec RS[0] can be represented by
  // - AxisPartitionSpec(Condition::Replica, -1)
  // - AxisPartitionSpec(Condition::Shard, 0)
 public:
  // Status of a tensor axis.
  // A tensor axis can be either sharded or replicated
  // along a device mesh axis.
  enum class Condition { Replica,
                         Shard };

  // This field tells if a tensor axis is sharded or not.
  Condition cond;

  // If a tensor axis is sharded, this field tells which device
  // mesh axis to distribute the shards along.
  // If a tensor axis is not sharded, this field is ignored.
  int device_mesh_axis;

  // A helper to construct a replica spec for a tensor axis.
  static AxisPartitionSpec CreateReplica() {
    return AxisPartitionSpec(Condition::Replica, -1);
  }

  // A helper to construct a sharding spec for a tensor axis.
  // This tensor axis is sharded along `device_mesh_axis` in device mesh.
  static AxisPartitionSpec CreateShard(int device_mesh_axis) {
    return AxisPartitionSpec(Condition::Shard, device_mesh_axis);
  }
};

class TensorPartitionSpec {
  // [Device Mesh and Tensor Sharding for Tensor Parallel]
  // TensorPartitionSpec holds a collection of AxisPartitionSpec and an
  // associated DeviceMesh. It is responsible for determining how a tensor
  // should be partitioned across a device mesh.
  //
  // Example 1: RS[0]
  // In this scenario, `axis_specs` would contain two `AxisPartitionSpec` objects.
  // - The first object is a Replica, denoting that the first axis of the tensor is
  //   not sharded but is instead replicated.
  // - The second object is a Shard along the 0-th axis of the device mesh. It denotes
  //   that the second axis of the tensor is sharded along the first axis of the
  //   device mesh.
  //
  // Example 2: S[0]RR
  // In this scenario, `axis_specs` would contain three `AxisPartitionSpec` objects.
  // - The first object is a Shard along the 0-th axis of the device mesh, indicating
  //   that the first axis of the tensor is sharded along the first axis of the
  //   device mesh.
  // - The second and third objects are Replicas, indicating that the second and third
  //   axes of the tensor are not sharded but are instead replicated.
 public:
  // axis_specs[i]: AxisPartitionSpec for tensor axis i. For a 2-D tensor,
  //                axis_specs[0] is for row axis and axis_specs[1] is for
  //                column axis. axis_specs[i].device_mesh_axis = j means that
  //                tensor axis i is sharded along device mesh axis j.
  std::vector<AxisPartitionSpec> axis_specs;

  // device_mesh: DeviceMesh for sharding the associated tensor.
  // Read [Device Mesh and Tensor Sharding for Tensor Parallel] in DeviceMesh's comment.
  DeviceMesh device_mesh;
};
```
2023-10-05 14:22:25 -07:00

136 lines
4.6 KiB
YAML

##### start trigger Don't edit it manually, Please do edit set-trigger-rules.py ####
trigger:
branches:
include:
- main
- rel-*
paths:
exclude:
- docs/**
- README.md
- CONTRIBUTING.md
- BUILD.md
- 'js/web'
- 'onnxruntime/core/providers/js'
pr:
branches:
include:
- main
- rel-*
paths:
exclude:
- docs/**
- README.md
- CONTRIBUTING.md
- BUILD.md
- 'js/web'
- 'onnxruntime/core/providers/js'
#### end trigger ####
stages:
- stage: ORTModuleDistributedTest
dependsOn: []
jobs:
- job: Onnxruntime_Linux_GPU_ORTModule_Distributed_Test
timeoutInMinutes: 120
pool: 'Onnxruntime-Linux-GPU-NC24sv3'
steps:
- checkout: self
clean: true
submodules: recursive
- template: templates/run-docker-build-steps.yml
parameters:
RunDockerBuildArgs: |
-o ubuntu20.04 -d gpu \
-t onnxruntime_ortmodule_distributed_tests_image \
-x " \
--config RelWithDebInfo \
--use_cuda --cuda_version=11.8 --cuda_home=/usr/local/cuda-11.8 --cudnn_home=/usr/local/cuda-11.8 \
--enable_training \
--update --build \
--build_wheel \
" \
-m \
-u \
-e
DisplayName: 'Build'
- bash: tools/ci_build/github/linux/docker/scripts/training/azure_scale_set_vm_mount_test_data.sh -p $(orttrainingtestdatascus-storage-key) -s "//orttrainingtestdatascus.file.core.windows.net/mnist" -d "/mnist"
displayName: 'Mount MNIST'
condition: succeededOrFailed()
# Entry point for all ORTModule distributed tests
# Refer to orttraining/orttraining/test/python/how_to_add_ortmodule_distributed_ci_pipeline_tests.md for guidelines on how to add new tests to this pipeline.
- script: |
docker run \
--gpus all \
--shm-size=1024m \
--rm \
--volume $(Build.SourcesDirectory):/onnxruntime_src \
--volume $(Build.BinariesDirectory):/build \
--volume /mnist:/mnist \
onnxruntime_ortmodule_distributed_tests_image \
bash -c "rm -rf /build/RelWithDebInfo/onnxruntime/ && python3 -m pip install /build/RelWithDebInfo/dist/onnxruntime*.whl && python3 -m onnxruntime.training.ortmodule.torch_cpp_extensions.install && /build/RelWithDebInfo/launch_test.py --cmd_line_with_args 'python orttraining_ortmodule_distributed_tests.py --mnist /mnist' --cwd /build/RelWithDebInfo" \
displayName: 'Run orttraining_ortmodule_distributed_tests.py'
condition: succeededOrFailed()
timeoutInMinutes: 30
- template: templates/component-governance-component-detection-steps.yml
parameters:
condition: 'succeeded'
- template: templates/clean-agent-build-directory-step.yml
- stage: DistributedInferenceTest
dependsOn: []
jobs:
- job: Onnxruntime_Linux_GPU_Inference_Distributed_Test
timeoutInMinutes: 120
pool: 'Onnxruntime-Linux-GPU-NC24sv3'
steps:
- checkout: self
clean: true
submodules: recursive
- template: templates/run-docker-build-steps.yml
parameters:
RunDockerBuildArgs: |
-o ubuntu20.04 -d gpu \
-t onnxruntime_ortmodule_distributed_tests_image \
-x " \
--config RelWithDebInfo \
--use_cuda --cuda_version=11.8 --cuda_home=/usr/local/cuda-11.8 --cudnn_home=/usr/local/cuda-11.8 \
--update --build \
--build_wheel \
--use_mpi \
--enable_nccl \
" \
-m \
-u \
-e
DisplayName: 'Build'
- script: |
docker run \
--gpus all \
--shm-size=1024m \
--rm \
--volume $(Build.SourcesDirectory):/onnxruntime_src \
--volume $(Build.BinariesDirectory):/build \
--volume /mnist:/mnist \
onnxruntime_ortmodule_distributed_tests_image \
bash -c "rm -rf /build/RelWithDebInfo/onnxruntime/ && python3 -m pip install mpi4py onnxscript && python3 -m pip install /build/RelWithDebInfo/dist/onnxruntime*.whl && mpirun -n 4 -x NCCL_DEBUG=INFO python /onnxruntime_src/onnxruntime/test/python/onnxruntime_test_collective.py && mpirun -n 2 -x NCCL_DEBUG=INFO python /onnxruntime_src/onnxruntime/test/python/onnxruntime_test_distributed.py" \
displayName: 'Run onnxruntime_test_collective.py'
condition: succeededOrFailed()
timeoutInMinutes: 30
- template: templates/component-governance-component-detection-steps.yml
parameters:
condition: 'succeeded'
- template: templates/clean-agent-build-directory-step.yml