Mirror of https://github.com/saymrwulf/onnxruntime.git (synced 2026-05-16 21:00:14 +00:00)
This PR introduces:
- A new data structure to represent kernel-level (aka node-level or op-level) tensor sharding information. I consider it the foundation of ONNX distributed inference.
- Building blocks for implementing distributed kernels, especially stateless implementations of communication ops.
- An implementation of DistributedMatMul and its tests.
Code structure:
- sharding.h/.cc: Functions to shard and reshard tensors (calling into NCCL).
- sharding_spec.h/.cc: Representation of how a tensor is sharded.
- distributed_matmul.h/.cc: Implementation of tensor-parallel MatMul. Inputs and outputs are sharded across devices.
- onnxruntime_test_distributed.py: Distributed operator tests.
Example of specifying sharding information:
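```python
import onnxscript
from onnxscript import FLOAT

# Opset for contrib ops in the "com.microsoft" domain
# (defined here so the snippet is self-contained).
MICROSOFT_OPSET = onnxscript.values.Opset(domain="com.microsoft", version=1)


@onnxscript.script()
def matmul_rs_sr_rr(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT:
    # Run MatMul by sharding x along its column axis and w along its row
    # axis on 2 GPUs; the output is replicated on both GPUs.
    return MICROSOFT_OPSET.DistributedMatMul(
        tensor_x,
        tensor_w,
        device_mesh_shape=[2],
        device_mesh_elements=[0, 1],
        input_shard_specs=["RS[0]", "S[0]R"],
        output_shard_specs=["RR"],
    )


onnx_model = matmul_rs_sr_rr.to_model_proto(
    input_types=[FLOAT[2, "s"], FLOAT["s", 2]],
    output_types=[FLOAT[2, 2]],
)
```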
In this example, the device mesh can be visualized as a 1-D tensor, `[0, 1]`. The 2nd axis of `tensor_x` is sharded across `[0, 1]` (i.e., along the 0-th axis of the device mesh). Similarly, the 1st axis of `tensor_w` is sharded across `[0, 1]`. The output spec `RR` means that neither output axis is sharded, so both GPUs hold the full 2x2 result.
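To see why the output can be replicated (`RR`), note that sharding the shared dimension `s` turns the product into a sum of partial products: each device multiplies its column block of `tensor_x` by the matching row block of `tensor_w`, and adding the per-device results gives the full product. The pure-Python sketch below simulates this in a single process. It only illustrates the math and is not how DistributedMatMul is implemented. The helper names (`matmul`, `add`, `sharded_matmul_rs_sr`) are hypothetical, and on real GPUs the final summation would have to be a cross-device reduction (for example, an all-reduce).

```python
def matmul(a, b):
    """Nested-list matrix product: (m x k) @ (k x n) -> (m x n)."""
    k, n = len(b), len(b[0])
    return [[sum(a[i][p] * b[p][j] for p in range(k)) for j in range(n)]
            for i in range(len(a))]


def add(a, b):
    """Element-wise sum of two equally shaped nested-list matrices."""
    return [[u + v for u, v in zip(ra, rb)] for ra, rb in zip(a, b)]


def sharded_matmul_rs_sr(x, w, num_devices):
    """Simulate MatMul with x sharded as "RS[0]" and w sharded as "S[0]R".

    Device d holds columns [lo, hi) of x and rows [lo, hi) of w and computes
    a local partial product. Summing the partials yields the full ("RR") result.
    """
    k = len(w)
    if k % num_devices:
        raise ValueError("this sketch assumes s divides evenly across devices")
    chunk = k // num_devices
    partials = []
    for d in range(num_devices):
        lo, hi = d * chunk, (d + 1) * chunk
        x_local = [row[lo:hi] for row in x]  # column block of x on device d
        w_local = w[lo:hi]                   # matching row block of w
        partials.append(matmul(x_local, w_local))
    # Stand-in for the cross-device reduction: sum the partial products.
    result = partials[0]
    for p in partials[1:]:
        result = add(result, p)
    return result
```

For example, with `x = [[1, 2, 3, 4], [5, 6, 7, 8]]` and `w = [[1, 0], [0, 1], [1, 1], [2, -1]]`, `sharded_matmul_rs_sr(x, w, 2)` returns `[[12, 1], [28, 5]]`, which is the same as `matmul(x, w)`.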
C++ classes to represent tensor sharding (copied from sharding_spec.h):
```cpp
#include <cstdint>
#include <vector>

class DeviceMesh {
 public:
  // [Device Mesh and Tensor Sharding for Tensor Parallel]
  // Device mesh is a tensor of device indices.
  // A tensor can then be partitioned along specific mesh axes.
  //
  // Assume we have 4 GPUs indexed by 0, 1, 2, and 3.
  // Let's consider some examples.
  // 1. 1-D device mesh [0, 1, 2, 3]. In this case,
  //    device_mesh_shape is [4] and device_mesh_elements
  //    is [0, 1, 2, 3].
  //    If we want to shard a 2-D tensor along its axis 1, the
  //    corresponding sharding spec is a string "RS[0]".
  // 2. 2-D device mesh [[0, 1], [2, 3]]. In this case,
  //    device_mesh_shape is [2, 2] and device_mesh_elements
  //    is [0, 1, 2, 3].
  //    If we want to shard a 2-D tensor's rows along mesh axis 1 and
  //    its columns along mesh axis 0, the corresponding sharding spec
  //    is a string "S[1]S[0]".
  //    If that 2-D tensor's value is np.array([[5, 6], [7, 8]]),
  //    GPU 0/1/2/3 owns 5/7/6/8. Below is a visualization of the
  //    sharding process.
  //    - Start with a 2-D device mesh [[0, 1], [2, 3]] and
  //      a 2-D tensor [[5, 6], [7, 8]]
  //      - GPU: [[0, 1], [2, 3]], Tensor: [[5, 6], [7, 8]]
  //    - Split GPU mesh along axis 1 and tensor along
  //      axis 0 for "S[1]" in "S[1]S[0]"
  //      - GPU: [[0], [2]], Tensor: [[5, 6]]
  //      - GPU: [[1], [3]], Tensor: [[7, 8]]
  //    - Split GPU mesh along axis 0 and tensor along
  //      axis 1 for "S[0]" in "S[1]S[0]"
  //      - GPU: [[0]], Tensor: [[5]]
  //      - GPU: [[2]], Tensor: [[6]]
  //      - GPU: [[1]], Tensor: [[7]]
  //      - GPU: [[3]], Tensor: [[8]]

  // Actual shape of the device mesh represented by `device_mesh_elements`.
  std::vector<int64_t> device_mesh_shape;

  // Flattened device mesh.
  std::vector<int64_t> device_mesh_elements;
};

class AxisPartitionSpec {
  // [Device Mesh and Tensor Sharding for Tensor Parallel]
  // This class is the in-memory representation of
  // 1. whether a tensor axis is sharded or not (i.e., replicated), and
  // 2. which device mesh axis a sharded tensor axis is sharded along.
  // Let's consider sharding a 2-D tensor along its column axis on
  // device mesh [0, 1] as an example.
  // The required sharding spec RS[0] can be represented by
  // - AxisPartitionSpec(Condition::Replica, -1)
  // - AxisPartitionSpec(Condition::Shard, 0)
 public:
  // Status of a tensor axis.
  // A tensor axis can be either sharded or replicated
  // along a device mesh axis.
  enum class Condition { Replica, Shard };

  // This field tells if a tensor axis is sharded or not.
  Condition cond;

  // If a tensor axis is sharded, this field tells which device
  // mesh axis to distribute the shards along.
  // If a tensor axis is not sharded, this field is ignored.
  int device_mesh_axis;

  // Plain constructor used by the helpers below.
  AxisPartitionSpec(Condition cond_, int device_mesh_axis_)
      : cond(cond_), device_mesh_axis(device_mesh_axis_) {}

  // A helper to construct a replica spec for a tensor axis.
  static AxisPartitionSpec CreateReplica() {
    return AxisPartitionSpec(Condition::Replica, -1);
  }

  // A helper to construct a sharding spec for a tensor axis.
  // This tensor axis is sharded along `device_mesh_axis` in the device mesh.
  static AxisPartitionSpec CreateShard(int device_mesh_axis) {
    return AxisPartitionSpec(Condition::Shard, device_mesh_axis);
  }
};

class TensorPartitionSpec {
  // [Device Mesh and Tensor Sharding for Tensor Parallel]
  // TensorPartitionSpec holds a collection of AxisPartitionSpec and an
  // associated DeviceMesh. It is responsible for determining how a tensor
  // should be partitioned across a device mesh.
  //
  // Example 1: RS[0]
  // In this scenario, `axis_specs` would contain two `AxisPartitionSpec` objects.
  // - The first object is a Replica, denoting that the first axis of the tensor is
  //   not sharded but is instead replicated.
  // - The second object is a Shard along the 0-th axis of the device mesh. It denotes
  //   that the second axis of the tensor is sharded along the first axis of the
  //   device mesh.
  //
  // Example 2: S[0]RR
  // In this scenario, `axis_specs` would contain three `AxisPartitionSpec` objects.
  // - The first object is a Shard along the 0-th axis of the device mesh, indicating
  //   that the first axis of the tensor is sharded along the first axis of the
  //   device mesh.
  // - The second and third objects are Replicas, indicating that the second and third
  //   axes of the tensor are not sharded but are instead replicated.
 public:
  // axis_specs[i]: AxisPartitionSpec for tensor axis i. For a 2-D tensor,
  //                axis_specs[0] is for row axis and axis_specs[1] is for
  //                column axis. axis_specs[i].device_mesh_axis = j means that
  //                tensor axis i is sharded along device mesh axis j.
  std::vector<AxisPartitionSpec> axis_specs;

  // device_mesh: DeviceMesh for sharding the associated tensor.
  // Read [Device Mesh and Tensor Sharding for Tensor Parallel] in DeviceMesh's comment.
  DeviceMesh device_mesh;
};
```
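The spec-string format and the device-to-shard mapping described in the comments above can be sketched in plain Python. This is a hypothetical helper, not code from sharding_spec.h. The names `parse_shard_spec` and `local_shards` are illustrative. The sketch assumes row-major flattening of `device_mesh_elements`, which matches the `[[0, 1], [2, 3]]` example. It handles only 2-D tensors and assumes each sharded axis splits evenly across its mesh axis. Each parsed entry mirrors an `AxisPartitionSpec` as a `(condition, device_mesh_axis)` pair.

```python
import re
from itertools import product

# One token per tensor axis: "R" (replicated) or "S[k]" (sharded along mesh axis k).
_TOKEN = re.compile(r"R|S\[(\d+)\]")


def parse_shard_spec(spec):
    """Parse e.g. "RS[0]" into [("R", None), ("S", 0)], one entry per tensor axis."""
    axis_specs = []
    pos = 0
    while pos < len(spec):
        m = _TOKEN.match(spec, pos)
        if m is None:
            raise ValueError(f"invalid sharding spec: {spec!r}")
        if m.group(0) == "R":
            axis_specs.append(("R", None))
        else:
            axis_specs.append(("S", int(m.group(1))))
        pos = m.end()
    return axis_specs


def local_shards(tensor, mesh_shape, mesh_elements, spec):
    """Map each device id to the block of a 2-D tensor (nested lists) it owns."""
    axis_specs = parse_shard_spec(spec)
    if len(axis_specs) != 2:
        raise ValueError("this sketch only handles 2-D tensors")
    dims = (len(tensor), len(tensor[0]))
    shards = {}
    # Visit every position in the device mesh, e.g. (0, 0), (0, 1), ... for [2, 2].
    for mesh_index in product(*(range(n) for n in mesh_shape)):
        # Row-major flattening: mesh position -> index into device_mesh_elements.
        flat = 0
        for idx, n in zip(mesh_index, mesh_shape):
            flat = flat * n + idx
        device = mesh_elements[flat]
        # For each tensor axis, pick the index range this device holds.
        ranges = []
        for length, (cond, mesh_axis) in zip(dims, axis_specs):
            if cond == "R":
                ranges.append(range(length))  # replicated: the whole axis
            else:
                chunk = length // mesh_shape[mesh_axis]  # assumes an even split
                start = mesh_index[mesh_axis] * chunk
                ranges.append(range(start, start + chunk))
        shards[device] = [[tensor[r][c] for c in ranges[1]] for r in ranges[0]]
    return shards
```

Running it on the `S[1]S[0]` example, `local_shards([[5, 6], [7, 8]], [2, 2], [0, 1, 2, 3], "S[1]S[0]")` returns `{0: [[5]], 1: [[7]], 2: [[6]], 3: [[8]]}`. This reproduces the "GPU 0/1/2/3 owns 5/7/6/8" assignment from the DeviceMesh comment.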