diff --git a/onnxruntime/core/providers/cuda/math/variadic_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/variadic_elementwise_ops.cc index 24e066d5d8..dc4db5a830 100644 --- a/onnxruntime/core/providers/cuda/math/variadic_elementwise_ops.cc +++ b/onnxruntime/core/providers/cuda/math/variadic_elementwise_ops.cc @@ -5,6 +5,7 @@ #include "core/providers/cuda/math/variadic_elementwise_ops.h" #include +#include #include "core/framework/data_types_internal.h" #include "core/providers/cuda/math/binary_elementwise_ops.h" @@ -17,21 +18,50 @@ namespace cuda { template template -Status VariadicElementwiseOp:: - NoBroadcastBatchImplDispatchTarget::operator()(cudaStream_t stream, const InputTensorVector& inputs, Tensor& output) const { - assert(inputs.size() > 1); - +Status VariadicElementwiseOp::NoBroadcastBatchImplDispatchTarget< + T>::operator()(cudaStream_t stream, const InputTensorVector& inputs, Tensor& output) const { using CudaT = typename ToCudaType::MappedType; - - InputBatchArray input_data_batch{static_cast(inputs.size())}; - for (size_t i = 0; i < inputs.size(); ++i) { + size_t input_count = inputs.size(); + assert(input_count > 1); + size_t index = std::min(input_count, static_cast(k_max_input_batch_size)); + InputBatchArray input_data_batch{static_cast(index)}; + for (size_t i = 0; i < index; ++i) { input_data_batch[static_cast(i)] = reinterpret_cast(inputs[i].get().template Data()); } CudaT* output_data = reinterpret_cast(output.template MutableData()); + Impl_NoBroadcastInputBatch(stream, input_data_batch, output_data, + output.Shape().Size()); - Impl_NoBroadcastInputBatch( - stream, input_data_batch, output_data, output.Shape().Size()); + while (index < input_count) { + size_t left_count = input_count - index + 1; + size_t batch = std::min(left_count, static_cast(k_max_input_batch_size)); + // Special case for 2 inputs left. + if (batch == 2) { + BinaryElementwisePreparation prepare; + ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(&output, &inputs[input_count - 1].get(), &output, &prepare)); + Impl_General( + stream, prepare.output_rank_or_simple_broadcast, &prepare.lhs_padded_strides, + reinterpret_cast(prepare.lhs_tensor->template Data()), &prepare.rhs_padded_strides, + reinterpret_cast(prepare.rhs_tensor->template Data()), &prepare.fdm_output_strides, + prepare.fdm_H, prepare.fdm_C, reinterpret_cast(prepare.output_tensor->template MutableData()), + prepare.output_tensor->Shape().Size()); + + // Must be the last. + break; + } + + InputBatchArray left_input_data_batch{static_cast(batch)}; + left_input_data_batch[0] = reinterpret_cast(output.template Data()); + for (size_t i = 1; i < batch; ++i) { + left_input_data_batch[static_cast(i)] = + reinterpret_cast(inputs[index].get().template Data()); + index++; + } + + Impl_NoBroadcastInputBatch(stream, left_input_data_batch, output_data, + output.Shape().Size()); + } return Status::OK(); } @@ -65,44 +95,57 @@ Status VariadicElementwiseOp // for more than 2 inputs, we need to accumulate into output tensor, as the shape from input0 + input1 might be different from output shape template template -Status VariadicElementwiseOp:: - GeneralImplDispatchTarget::operator()(cudaStream_t stream, const InputTensorVector& inputs, Tensor& output) const { +Status +VariadicElementwiseOp::GeneralImplDispatchTarget::operator()( + cudaStream_t stream, const InputTensorVector& inputs, Tensor& output) const { assert(inputs.size() > 1); using CudaT = typename ToCudaType::MappedType; - CUDA_RETURN_IF_ERROR(cudaMemsetAsync(output.MutableDataRaw(), 0, output.SizeInBytes(), stream)); + // If there is any input having the same shape with output, we don't need the memset. + size_t index_of_same_shape = 0; + for (; index_of_same_shape < inputs.size(); index_of_same_shape++) { + if (inputs[index_of_same_shape].get().Shape() == output.Shape()) { + break; + } + } BinaryElementwisePreparation prepare; - ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(&output, &inputs[0].get(), &output, &prepare)); - Impl_Add( - stream, - prepare.output_rank_or_simple_broadcast, - &prepare.lhs_padded_strides, - reinterpret_cast(prepare.lhs_tensor->template Data()), - &prepare.rhs_padded_strides, - reinterpret_cast(prepare.rhs_tensor->template Data()), - &prepare.fdm_output_strides, - prepare.fdm_H, - prepare.fdm_C, - reinterpret_cast(prepare.output_tensor->template MutableData()), - prepare.output_tensor->Shape().Size()); + // No input has same shape of output, memset the output, and add the 1st input as initialization. + if (index_of_same_shape == inputs.size()) { + CUDA_RETURN_IF_ERROR(cudaMemsetAsync(output.MutableDataRaw(), 0, output.SizeInBytes(), stream)); + ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(&output, &inputs[0].get(), &output, &prepare)); + Impl_Add(stream, prepare.output_rank_or_simple_broadcast, &prepare.lhs_padded_strides, + reinterpret_cast(prepare.lhs_tensor->template Data()), &prepare.rhs_padded_strides, + reinterpret_cast(prepare.rhs_tensor->template Data()), &prepare.fdm_output_strides, + prepare.fdm_H, prepare.fdm_C, reinterpret_cast(prepare.output_tensor->template MutableData()), + prepare.output_tensor->Shape().Size()); + } else { + // First operation is between input[0] and input[index_of_same_shape] if index_of_same_shape is not 0. + size_t index = index_of_same_shape == 0 ? 1 : 0; + ORT_RETURN_IF_ERROR( + BinaryElementwiseBroadcastPrepare(&inputs[index_of_same_shape].get(), &inputs[index].get(), &output, &prepare)); + Impl_General( + stream, prepare.output_rank_or_simple_broadcast, &prepare.lhs_padded_strides, + reinterpret_cast(prepare.lhs_tensor->template Data()), &prepare.rhs_padded_strides, + reinterpret_cast(prepare.rhs_tensor->template Data()), &prepare.fdm_output_strides, + prepare.fdm_H, prepare.fdm_C, reinterpret_cast(prepare.output_tensor->template MutableData()), + prepare.output_tensor->Shape().Size()); + } for (size_t index = 1; index < inputs.size(); index++) { - ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(&output, &inputs[index].get(), &output, &prepare)); + // If index_of_same_shape is 0, we already handle the 1st and 2nd inputs. + if (index == index_of_same_shape || (index_of_same_shape == 0 && index == 1)) { + continue; + } + ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(&output, &inputs[index].get(), &output, &prepare)); Impl_General( - stream, - prepare.output_rank_or_simple_broadcast, - &prepare.lhs_padded_strides, - reinterpret_cast(prepare.lhs_tensor->template Data()), - &prepare.rhs_padded_strides, - reinterpret_cast(prepare.rhs_tensor->template Data()), - &prepare.fdm_output_strides, - prepare.fdm_H, - prepare.fdm_C, - reinterpret_cast(prepare.output_tensor->template MutableData()), + stream, prepare.output_rank_or_simple_broadcast, &prepare.lhs_padded_strides, + reinterpret_cast(prepare.lhs_tensor->template Data()), &prepare.rhs_padded_strides, + reinterpret_cast(prepare.rhs_tensor->template Data()), &prepare.fdm_output_strides, + prepare.fdm_H, prepare.fdm_C, reinterpret_cast(prepare.output_tensor->template MutableData()), prepare.output_tensor->Shape().Size()); } @@ -145,23 +188,21 @@ Status VariadicElementwiseOp const auto element_type = first_input_tensor.GetElementType(); utils::MLTypeCallDispatcher dispatcher(element_type); - // special case for no broadcasting and few enough inputs - if (input_count <= k_max_input_batch_size && - std::all_of( - input_tensors.begin() + 1, input_tensors.end(), - [&first_input_tensor](InputTensorVector::value_type t) { - return first_input_tensor.Shape() == t.get().Shape(); - })) { + // Special case for no broadcasting. + if (std::all_of(input_tensors.begin() + 1, input_tensors.end(), + [&first_input_tensor](InputTensorVector::value_type t) { + return first_input_tensor.Shape() == t.get().Shape(); + })) { auto& output_tensor = context->RequiredOutput(0, first_input_tensor.Shape()); // special case for no broadcasting and 2 inputs if (input_count == 2) { - return dispatcher.template InvokeRet( - Stream(), input_tensors[0], input_tensors[1], output_tensor); + return dispatcher.template InvokeRet(Stream(), input_tensors[0], + input_tensors[1], output_tensor); } - return dispatcher.template InvokeRet( - Stream(), input_tensors, output_tensor); + return dispatcher.template InvokeRet(Stream(), input_tensors, + output_tensor); } // compute output shape first, using broadcast rule diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index c95f3daf5b..f4c1b33501 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -1083,9 +1083,18 @@ static void TestSumMultipleInputsNoBroadcasting(size_t num_inputs, const TensorS TEST(MathOpTest, SumMultipleInputsNoBroadcasting) { const TensorShape shape{3, 3, 3}; - for (size_t num_inputs = 2; num_inputs < 10; ++num_inputs) { + // Special case: + // 2: BinaryImplDispatchTarget + // 3-8: NoBroadcastBatchImplDispatchTarget(i) + // 9: NoBroadcastBatchImplDispatchTarget(8) + BinaryImplDispatchTarget + // 10: NoBroadcastBatchImplDispatchTarget(8) + NoBroadcastBatchImplDispatchTarget(3) + // 15: NoBroadcastBatchImplDispatchTarget(8) + NoBroadcastBatchImplDispatchTarget(8) + // 16: NoBroadcastBatchImplDispatchTarget(8) + NoBroadcastBatchImplDispatchTarget(8) + BinaryImplDispatchTarget + for (size_t num_inputs = 2; num_inputs <= 10; ++num_inputs) { TestSumMultipleInputsNoBroadcasting(num_inputs, shape); } + TestSumMultipleInputsNoBroadcasting(15, shape); + TestSumMultipleInputsNoBroadcasting(16, shape); } TEST(MathOpTest, SumMultipleInputsNoBroadcasting_double) {