mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-17 21:10:43 +00:00
Issue#16990: Cast -> AllToAll -> Cast fails with random output (#17075)
This commit is contained in:
parent
bd8a488f4b
commit
d76dbc4fc3
2 changed files with 161 additions and 53 deletions
|
|
@ -4,6 +4,7 @@
|
|||
#include "nccl_kernels.h"
|
||||
#include "mpi_include.h"
|
||||
#include "core/providers/cuda/tensor/transpose.h"
|
||||
#include "core/providers/cuda/cuda_check_memory.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
|
@ -199,6 +200,9 @@ Status AllToAll::ComputeInternal(OpKernelContext* context) const {
|
|||
|
||||
char* output_data = static_cast<char*>(context->Output(0, out_shape)->MutableDataRaw());
|
||||
|
||||
CheckIfMemoryOnCurrentGpuDevice(input_data);
|
||||
CheckIfMemoryOnCurrentGpuDevice(output_data);
|
||||
|
||||
NCCL_RETURN_IF_ERROR(ncclGroupStart());
|
||||
for (int32_t r = 0; r < group_size_; r++) {
|
||||
NCCL_RETURN_IF_ERROR(ncclSend(input_data, rank_stride, dtype, r, comm, Stream(context)));
|
||||
|
|
@ -238,7 +242,6 @@ ONNX_OPERATOR_KERNEL_EX(
|
|||
1,
|
||||
kCudaExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.VariadicAlias(0, 0) // outputs and inputs are mapped one to one
|
||||
.AllocateInputsContiguously()
|
||||
.TypeConstraint("T", DataTypeImpl::AllTensorTypes()),
|
||||
AllToAll);
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import unittest
|
|||
import numpy as np
|
||||
from mpi4py import MPI
|
||||
from onnx import TensorProto, helper
|
||||
from parameterized import parameterized
|
||||
|
||||
import onnxruntime as ort
|
||||
|
||||
|
|
@ -154,42 +155,86 @@ class ORTBertPretrainTest(unittest.TestCase):
|
|||
)
|
||||
return ORTBertPretrainTest._create_model_with_opsets(graph_def)
|
||||
|
||||
def test_all_reduce(self):
|
||||
for np_elem_type, elem_type in ((np.float32, TensorProto.FLOAT),):
|
||||
model = self._create_allreduce_ut_model((128, 128), elem_type)
|
||||
rank, size = self._get_rank_size()
|
||||
ort_sess = ort.InferenceSession(
|
||||
model.SerializeToString(),
|
||||
providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
|
||||
provider_options=[{"device_id": str(rank)}, {}],
|
||||
)
|
||||
@parameterized.expand(
|
||||
[
|
||||
(np.float32, TensorProto.FLOAT),
|
||||
]
|
||||
)
|
||||
def test_all_reduce(self, np_elem_type, elem_type):
|
||||
model = self._create_allreduce_ut_model((128, 128), elem_type)
|
||||
rank, size = self._get_rank_size()
|
||||
ort_sess = ort.InferenceSession(
|
||||
model.SerializeToString(),
|
||||
providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
|
||||
provider_options=[{"device_id": str(rank)}, {}],
|
||||
)
|
||||
|
||||
input = np.ones((128, 128), dtype=np_elem_type)
|
||||
outputs = ort_sess.run(None, {"X": input})
|
||||
assert np.allclose(outputs[0], size * input)
|
||||
# Input for size == 4
|
||||
# For rank == 0:
|
||||
# Input: [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]
|
||||
# Output: [[4, 4, 4, 4], [4, 4, 4, 4], [4, 4, 4, 4], [4, 4, 4, 4]]
|
||||
#
|
||||
# For rank == 1:
|
||||
# Input: [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]
|
||||
# Output: [[4, 4, 4, 4], [4, 4, 4, 4], [4, 4, 4, 4], [4, 4, 4, 4]]
|
||||
#
|
||||
# For rank == 2:
|
||||
# Input: [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]
|
||||
# Output: [[4, 4, 4, 4], [4, 4, 4, 4], [4, 4, 4, 4], [4, 4, 4, 4]]
|
||||
#
|
||||
# For rank == 3:
|
||||
# Input: [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]
|
||||
# Output: [[4, 4, 4, 4], [4, 4, 4, 4], [4, 4, 4, 4], [4, 4, 4, 4]]
|
||||
|
||||
def test_all_gather(self):
|
||||
for np_elem_type, elem_type, communication_elem_type in ((np.float32, TensorProto.FLOAT, TensorProto.FLOAT),):
|
||||
model = self._create_allgather_ut_model((128, 128), 0, elem_type, communication_elem_type)
|
||||
rank, size = self._get_rank_size()
|
||||
ort_sess = ort.InferenceSession(
|
||||
model.SerializeToString(),
|
||||
providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
|
||||
provider_options=[{"device_id": str(rank)}, {}],
|
||||
)
|
||||
input = np.ones((128, 128), dtype=np_elem_type)
|
||||
outputs = ort_sess.run(None, {"X": input})
|
||||
|
||||
input = np.ones((128, 128), dtype=np.float32) * rank
|
||||
outputs = ort_sess.run(None, {"X": input})
|
||||
np.testing.assert_allclose(
|
||||
outputs[0], size * input, err_msg=f"{rank}: AllGather ({np_elem_type}, {elem_type}): results mismatch"
|
||||
)
|
||||
|
||||
expected_output = np.zeros((128, 128), dtype=np_elem_type)
|
||||
for _ in range(size - 1):
|
||||
expected_output = np.concatenate((expected_output, np.ones((128, 128), dtype=np_elem_type) * (_ + 1)))
|
||||
np.testing.assert_allclose(outputs[0], expected_output, err_msg="all gather on axis0: result mismatch")
|
||||
@parameterized.expand(
|
||||
[
|
||||
(np.float32, TensorProto.FLOAT, TensorProto.FLOAT),
|
||||
]
|
||||
)
|
||||
def test_all_gather(self, np_elem_type, elem_type, communication_elem_type):
|
||||
model = self._create_allgather_ut_model((128, 128), 0, elem_type, communication_elem_type)
|
||||
rank, size = self._get_rank_size()
|
||||
ort_sess = ort.InferenceSession(
|
||||
model.SerializeToString(),
|
||||
providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
|
||||
provider_options=[{"device_id": str(rank)}, {}],
|
||||
)
|
||||
|
||||
# Input for size == 2
|
||||
# For rank == 0:
|
||||
# Input: [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]
|
||||
# Output: [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0],
|
||||
# [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]
|
||||
#
|
||||
# For rank == 1:
|
||||
# Input: [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]
|
||||
# Output: [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0],
|
||||
# [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]
|
||||
|
||||
input = np.ones((128, 128), dtype=np.float32) * rank
|
||||
outputs = ort_sess.run(None, {"X": input})
|
||||
|
||||
expected_output = np.zeros((128, 128), dtype=np_elem_type)
|
||||
for _ in range(size - 1):
|
||||
expected_output = np.concatenate((expected_output, np.ones((128, 128), dtype=np_elem_type) * (_ + 1)))
|
||||
|
||||
np.testing.assert_allclose(
|
||||
outputs[0],
|
||||
expected_output,
|
||||
err_msg=f"{rank}: AllGather (axis0) ({np_elem_type}, {elem_type}, {communication_elem_type}): results mismatch",
|
||||
)
|
||||
|
||||
def test_all_gather_bool(self):
|
||||
model = self._create_allgather_ut_model((4,), 0, TensorProto.INT64, TensorProto.INT64)
|
||||
rank, size = self._get_rank_size()
|
||||
print(f"rank: {rank}, size: {size}")
|
||||
rank, _ = self._get_rank_size()
|
||||
|
||||
ort_sess = ort.InferenceSession(
|
||||
model.SerializeToString(),
|
||||
providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
|
||||
|
|
@ -203,7 +248,7 @@ class ORTBertPretrainTest(unittest.TestCase):
|
|||
[True, True, False, False] * 4,
|
||||
).astype(np.int64)
|
||||
|
||||
np.testing.assert_allclose(y, y_expected)
|
||||
np.testing.assert_allclose(y, y_expected, err_msg=f"{rank}: AllGather (bool): results mismatch")
|
||||
|
||||
def test_all_gather_axis1(self):
|
||||
model = self._create_allgather_ut_model((128, 128), 1)
|
||||
|
|
@ -221,41 +266,101 @@ class ORTBertPretrainTest(unittest.TestCase):
|
|||
for _ in range(size - 1):
|
||||
expected_output = np.concatenate((expected_output, np.ones((128, 128), dtype=np.float32) * (_ + 1)), axis=1)
|
||||
|
||||
np.testing.assert_allclose(outputs[0], expected_output, err_msg="all gather on axis1: result mismatch")
|
||||
np.testing.assert_allclose(outputs[0], expected_output, err_msg=f"{rank}: AllGather (axis1): results mismatch")
|
||||
|
||||
def test_all_to_all(self):
|
||||
for np_elem_type, elem_type, communication_elem_type in (
|
||||
@parameterized.expand(
|
||||
[
|
||||
(np.float32, TensorProto.FLOAT, TensorProto.FLOAT),
|
||||
(np.int64, TensorProto.INT64, TensorProto.INT64),
|
||||
# TODO: Fix the following case, which triggers random number-mismatch error.
|
||||
# (np.float32, TensorProto.INT64, TensorProto.BOOL),
|
||||
):
|
||||
model = self._create_alltoall_ut_model(
|
||||
(128, 128), elem_type=elem_type, communication_elem_type=communication_elem_type
|
||||
)
|
||||
rank, size = self._get_rank_size()
|
||||
ort_sess = ort.InferenceSession(
|
||||
model.SerializeToString(),
|
||||
providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
|
||||
provider_options=[{"device_id": str(rank)}, {}],
|
||||
)
|
||||
(np.int64, TensorProto.INT64, TensorProto.BOOL),
|
||||
]
|
||||
)
|
||||
def test_all_to_all(self, np_elem_type, elem_type, communication_elem_type):
|
||||
model = self._create_alltoall_ut_model(
|
||||
(8, 8), elem_type=elem_type, communication_elem_type=communication_elem_type
|
||||
)
|
||||
rank, size = self._get_rank_size()
|
||||
ort_sess = ort.InferenceSession(
|
||||
model.SerializeToString(),
|
||||
providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
|
||||
provider_options=[{"device_id": str(rank)}, {}],
|
||||
)
|
||||
|
||||
input = np.ones((128, 128), dtype=np_elem_type) * rank
|
||||
outputs = ort_sess.run(None, {"X": input})
|
||||
# Casting to Boolean needs to be handled as a special case because only absoluate zero equates
|
||||
# to False and everything else to True. Post casting, however, the original input can never be
|
||||
# receovered when going in the opposite direction. The end results are always going to be in the
|
||||
# set [0, 1].
|
||||
|
||||
expected_output = np.zeros((int(128 / size), 128), dtype=np_elem_type)
|
||||
input = np.ones((8, 8), dtype=np_elem_type) * rank
|
||||
|
||||
if communication_elem_type == TensorProto.BOOL:
|
||||
# Input for size == 4
|
||||
# For rank == 0:
|
||||
# Input: [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]
|
||||
# Output: [[0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]
|
||||
#
|
||||
# For rank == 1:
|
||||
# Input: [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]
|
||||
# Output: [[0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]
|
||||
#
|
||||
# For rank == 2:
|
||||
# Input: [[2, 2, 2, 2], [2, 2, 2, 2], [2, 2, 2, 2], [2, 2, 2, 2]]
|
||||
# Output: [[0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]
|
||||
#
|
||||
# For rank == 3:
|
||||
# Input: [[3, 3, 3, 3], [3, 3, 3, 3], [3, 3, 3, 3], [3, 3, 3, 3]]
|
||||
# Output: [[0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]
|
||||
|
||||
expected_output = np.concatenate(
|
||||
(
|
||||
np.zeros((8 // size, 8), dtype=np_elem_type),
|
||||
np.ones((8 - (8 // size), 8), dtype=np_elem_type),
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Input for size == 4
|
||||
# For rank == 0:
|
||||
# Input: [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]
|
||||
# Output: [[0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]]
|
||||
#
|
||||
# For rank == 1:
|
||||
# Input: [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]
|
||||
# Output: [[0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]]
|
||||
#
|
||||
# For rank == 2:
|
||||
# Input: [[2, 2, 2, 2], [2, 2, 2, 2], [2, 2, 2, 2], [2, 2, 2, 2]]
|
||||
# Output: [[0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]]
|
||||
#
|
||||
# For rank == 3:
|
||||
# Input: [[3, 3, 3, 3], [3, 3, 3, 3], [3, 3, 3, 3], [3, 3, 3, 3]]
|
||||
# Output: [[0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]]
|
||||
|
||||
expected_output = np.zeros((8 // size, 8), dtype=np_elem_type)
|
||||
for _ in range(size - 1):
|
||||
expected_output = np.concatenate(
|
||||
(expected_output, np.ones((int(128 / size), 128), dtype=np_elem_type) * (_ + 1))
|
||||
(expected_output, np.ones((8 // size, 8), dtype=np_elem_type) * (_ + 1))
|
||||
)
|
||||
|
||||
print("outputs[0]: ", outputs[0] - expected_output)
|
||||
outputs = ort_sess.run(None, {"X": input})
|
||||
|
||||
assert np.allclose(outputs[0], expected_output)
|
||||
np.testing.assert_allclose(
|
||||
outputs[0],
|
||||
expected_output,
|
||||
err_msg=f"{rank}: AllToAll ({np_elem_type}, {elem_type}, {communication_elem_type}): results mismatch",
|
||||
)
|
||||
|
||||
def test_all_to_all_bool(self):
|
||||
rank, _ = self._get_rank_size()
|
||||
|
||||
# Input for size == 2
|
||||
# For rank == 0:
|
||||
# Input: [True, True, True, True]
|
||||
# Output: [True, True, False, False]
|
||||
#
|
||||
# For rank == 1:
|
||||
# Input: [False, False, False, False]
|
||||
# Output: [True, True, False, False]
|
||||
|
||||
if rank == 0:
|
||||
x = [True, True, True, True]
|
||||
else:
|
||||
|
|
@ -275,7 +380,7 @@ class ORTBertPretrainTest(unittest.TestCase):
|
|||
[True, False, False, False],
|
||||
).astype(np.int64)
|
||||
|
||||
np.testing.assert_allclose(y[0], y_expected)
|
||||
np.testing.assert_allclose(y[0], y_expected, err_msg=f"{rank}: AllToAll (bool): results mismatch")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Reference in a new issue