Mirror of https://github.com/saymrwulf/onnxruntime.git
(synced 2026-05-28 22:56:32 +00:00)
Optimize Variadic Elementwise Ops (#9186)
* optimize variadic elementwise ops
* remove nvvp file
* correct comment
* resolve comments
This commit is contained in:
parent 5f5f28bf14
commit cd65a8089e
2 changed files with 98 additions and 48 deletions
|
|
@ -5,6 +5,7 @@
|
|||
#include "core/providers/cuda/math/variadic_elementwise_ops.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <algorithm>
|
||||
|
||||
#include "core/framework/data_types_internal.h"
|
||||
#include "core/providers/cuda/math/binary_elementwise_ops.h"
|
||||
|
|
@ -17,21 +18,50 @@ namespace cuda {
|
|||
|
||||
template <typename VariadicElementwiseOpTag, typename... SupportedElementTypes>
|
||||
template <typename T>
|
||||
Status VariadicElementwiseOp<VariadicElementwiseOpTag, SupportedElementTypes...>::
|
||||
NoBroadcastBatchImplDispatchTarget<T>::operator()(cudaStream_t stream, const InputTensorVector& inputs, Tensor& output) const {
|
||||
assert(inputs.size() > 1);
|
||||
|
||||
Status VariadicElementwiseOp<VariadicElementwiseOpTag, SupportedElementTypes...>::NoBroadcastBatchImplDispatchTarget<
|
||||
T>::operator()(cudaStream_t stream, const InputTensorVector& inputs, Tensor& output) const {
|
||||
using CudaT = typename ToCudaType<T>::MappedType;
|
||||
|
||||
InputBatchArray<CudaT> input_data_batch{static_cast<int32_t>(inputs.size())};
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
size_t input_count = inputs.size();
|
||||
assert(input_count > 1);
|
||||
size_t index = std::min(input_count, static_cast<size_t>(k_max_input_batch_size));
|
||||
InputBatchArray<CudaT> input_data_batch{static_cast<int32_t>(index)};
|
||||
for (size_t i = 0; i < index; ++i) {
|
||||
input_data_batch[static_cast<int32_t>(i)] = reinterpret_cast<const CudaT*>(inputs[i].get().template Data<T>());
|
||||
}
|
||||
|
||||
CudaT* output_data = reinterpret_cast<CudaT*>(output.template MutableData<T>());
|
||||
Impl_NoBroadcastInputBatch<CudaT, VariadicElementwiseOpTag>(stream, input_data_batch, output_data,
|
||||
output.Shape().Size());
|
||||
|
||||
Impl_NoBroadcastInputBatch<CudaT, VariadicElementwiseOpTag>(
|
||||
stream, input_data_batch, output_data, output.Shape().Size());
|
||||
while (index < input_count) {
|
||||
size_t left_count = input_count - index + 1;
|
||||
size_t batch = std::min(left_count, static_cast<size_t>(k_max_input_batch_size));
|
||||
// Special case for 2 inputs left.
|
||||
if (batch == 2) {
|
||||
BinaryElementwisePreparation prepare;
|
||||
ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(&output, &inputs[input_count - 1].get(), &output, &prepare));
|
||||
Impl_General<CudaT, VariadicElementwiseOpTag>(
|
||||
stream, prepare.output_rank_or_simple_broadcast, &prepare.lhs_padded_strides,
|
||||
reinterpret_cast<const CudaT*>(prepare.lhs_tensor->template Data<T>()), &prepare.rhs_padded_strides,
|
||||
reinterpret_cast<const CudaT*>(prepare.rhs_tensor->template Data<T>()), &prepare.fdm_output_strides,
|
||||
prepare.fdm_H, prepare.fdm_C, reinterpret_cast<CudaT*>(prepare.output_tensor->template MutableData<T>()),
|
||||
prepare.output_tensor->Shape().Size());
|
||||
|
||||
// Must be the last.
|
||||
break;
|
||||
}
|
||||
|
||||
InputBatchArray<CudaT> left_input_data_batch{static_cast<int32_t>(batch)};
|
||||
left_input_data_batch[0] = reinterpret_cast<const CudaT*>(output.template Data<T>());
|
||||
for (size_t i = 1; i < batch; ++i) {
|
||||
left_input_data_batch[static_cast<int32_t>(i)] =
|
||||
reinterpret_cast<const CudaT*>(inputs[index].get().template Data<T>());
|
||||
index++;
|
||||
}
|
||||
|
||||
Impl_NoBroadcastInputBatch<CudaT, VariadicElementwiseOpTag>(stream, left_input_data_batch, output_data,
|
||||
output.Shape().Size());
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
@ -65,44 +95,57 @@ Status VariadicElementwiseOp<VariadicElementwiseOpTag, SupportedElementTypes...>
|
|||
// for more than 2 inputs, we need to accumulate into output tensor, as the shape from input0 + input1 might be different from output shape
|
||||
template <typename VariadicElementwiseOpTag, typename... SupportedElementTypes>
|
||||
template <typename T>
|
||||
Status VariadicElementwiseOp<VariadicElementwiseOpTag, SupportedElementTypes...>::
|
||||
GeneralImplDispatchTarget<T>::operator()(cudaStream_t stream, const InputTensorVector& inputs, Tensor& output) const {
|
||||
Status
|
||||
VariadicElementwiseOp<VariadicElementwiseOpTag, SupportedElementTypes...>::GeneralImplDispatchTarget<T>::operator()(
|
||||
cudaStream_t stream, const InputTensorVector& inputs, Tensor& output) const {
|
||||
assert(inputs.size() > 1);
|
||||
|
||||
using CudaT = typename ToCudaType<T>::MappedType;
|
||||
|
||||
CUDA_RETURN_IF_ERROR(cudaMemsetAsync(output.MutableDataRaw(), 0, output.SizeInBytes(), stream));
|
||||
// If there is any input having the same shape with output, we don't need the memset.
|
||||
size_t index_of_same_shape = 0;
|
||||
for (; index_of_same_shape < inputs.size(); index_of_same_shape++) {
|
||||
if (inputs[index_of_same_shape].get().Shape() == output.Shape()) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
BinaryElementwisePreparation prepare;
|
||||
ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(&output, &inputs[0].get(), &output, &prepare));
|
||||
|
||||
Impl_Add(
|
||||
stream,
|
||||
prepare.output_rank_or_simple_broadcast,
|
||||
&prepare.lhs_padded_strides,
|
||||
reinterpret_cast<const CudaT*>(prepare.lhs_tensor->template Data<T>()),
|
||||
&prepare.rhs_padded_strides,
|
||||
reinterpret_cast<const CudaT*>(prepare.rhs_tensor->template Data<T>()),
|
||||
&prepare.fdm_output_strides,
|
||||
prepare.fdm_H,
|
||||
prepare.fdm_C,
|
||||
reinterpret_cast<CudaT*>(prepare.output_tensor->template MutableData<T>()),
|
||||
prepare.output_tensor->Shape().Size());
|
||||
// No input has same shape of output, memset the output, and add the 1st input as initialization.
|
||||
if (index_of_same_shape == inputs.size()) {
|
||||
CUDA_RETURN_IF_ERROR(cudaMemsetAsync(output.MutableDataRaw(), 0, output.SizeInBytes(), stream));
|
||||
ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(&output, &inputs[0].get(), &output, &prepare));
|
||||
Impl_Add(stream, prepare.output_rank_or_simple_broadcast, &prepare.lhs_padded_strides,
|
||||
reinterpret_cast<const CudaT*>(prepare.lhs_tensor->template Data<T>()), &prepare.rhs_padded_strides,
|
||||
reinterpret_cast<const CudaT*>(prepare.rhs_tensor->template Data<T>()), &prepare.fdm_output_strides,
|
||||
prepare.fdm_H, prepare.fdm_C, reinterpret_cast<CudaT*>(prepare.output_tensor->template MutableData<T>()),
|
||||
prepare.output_tensor->Shape().Size());
|
||||
} else {
|
||||
// First operation is between input[0] and input[index_of_same_shape] if index_of_same_shape is not 0.
|
||||
size_t index = index_of_same_shape == 0 ? 1 : 0;
|
||||
ORT_RETURN_IF_ERROR(
|
||||
BinaryElementwiseBroadcastPrepare(&inputs[index_of_same_shape].get(), &inputs[index].get(), &output, &prepare));
|
||||
Impl_General<CudaT, VariadicElementwiseOpTag>(
|
||||
stream, prepare.output_rank_or_simple_broadcast, &prepare.lhs_padded_strides,
|
||||
reinterpret_cast<const CudaT*>(prepare.lhs_tensor->template Data<T>()), &prepare.rhs_padded_strides,
|
||||
reinterpret_cast<const CudaT*>(prepare.rhs_tensor->template Data<T>()), &prepare.fdm_output_strides,
|
||||
prepare.fdm_H, prepare.fdm_C, reinterpret_cast<CudaT*>(prepare.output_tensor->template MutableData<T>()),
|
||||
prepare.output_tensor->Shape().Size());
|
||||
}
|
||||
|
||||
for (size_t index = 1; index < inputs.size(); index++) {
|
||||
ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(&output, &inputs[index].get(), &output, &prepare));
|
||||
// If index_of_same_shape is 0, we already handle the 1st and 2nd inputs.
|
||||
if (index == index_of_same_shape || (index_of_same_shape == 0 && index == 1)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(&output, &inputs[index].get(), &output, &prepare));
|
||||
Impl_General<CudaT, VariadicElementwiseOpTag>(
|
||||
stream,
|
||||
prepare.output_rank_or_simple_broadcast,
|
||||
&prepare.lhs_padded_strides,
|
||||
reinterpret_cast<const CudaT*>(prepare.lhs_tensor->template Data<T>()),
|
||||
&prepare.rhs_padded_strides,
|
||||
reinterpret_cast<const CudaT*>(prepare.rhs_tensor->template Data<T>()),
|
||||
&prepare.fdm_output_strides,
|
||||
prepare.fdm_H,
|
||||
prepare.fdm_C,
|
||||
reinterpret_cast<CudaT*>(prepare.output_tensor->template MutableData<T>()),
|
||||
stream, prepare.output_rank_or_simple_broadcast, &prepare.lhs_padded_strides,
|
||||
reinterpret_cast<const CudaT*>(prepare.lhs_tensor->template Data<T>()), &prepare.rhs_padded_strides,
|
||||
reinterpret_cast<const CudaT*>(prepare.rhs_tensor->template Data<T>()), &prepare.fdm_output_strides,
|
||||
prepare.fdm_H, prepare.fdm_C, reinterpret_cast<CudaT*>(prepare.output_tensor->template MutableData<T>()),
|
||||
prepare.output_tensor->Shape().Size());
|
||||
}
|
||||
|
||||
|
|
@ -145,23 +188,21 @@ Status VariadicElementwiseOp<VariadicElementwiseOpTag, SupportedElementTypes...>
|
|||
const auto element_type = first_input_tensor.GetElementType();
|
||||
utils::MLTypeCallDispatcher<SupportedElementTypes...> dispatcher(element_type);
|
||||
|
||||
// special case for no broadcasting and few enough inputs
|
||||
if (input_count <= k_max_input_batch_size &&
|
||||
std::all_of(
|
||||
input_tensors.begin() + 1, input_tensors.end(),
|
||||
[&first_input_tensor](InputTensorVector::value_type t) {
|
||||
return first_input_tensor.Shape() == t.get().Shape();
|
||||
})) {
|
||||
// Special case for no broadcasting.
|
||||
if (std::all_of(input_tensors.begin() + 1, input_tensors.end(),
|
||||
[&first_input_tensor](InputTensorVector::value_type t) {
|
||||
return first_input_tensor.Shape() == t.get().Shape();
|
||||
})) {
|
||||
auto& output_tensor = context->RequiredOutput(0, first_input_tensor.Shape());
|
||||
|
||||
// special case for no broadcasting and 2 inputs
|
||||
if (input_count == 2) {
|
||||
return dispatcher.template InvokeRet<Status, BinaryImplDispatchTarget>(
|
||||
Stream(), input_tensors[0], input_tensors[1], output_tensor);
|
||||
return dispatcher.template InvokeRet<Status, BinaryImplDispatchTarget>(Stream(), input_tensors[0],
|
||||
input_tensors[1], output_tensor);
|
||||
}
|
||||
|
||||
return dispatcher.template InvokeRet<Status, NoBroadcastBatchImplDispatchTarget>(
|
||||
Stream(), input_tensors, output_tensor);
|
||||
return dispatcher.template InvokeRet<Status, NoBroadcastBatchImplDispatchTarget>(Stream(), input_tensors,
|
||||
output_tensor);
|
||||
}
|
||||
|
||||
// compute output shape first, using broadcast rule
|
||||
|
|
|
|||
|
|
@ -1083,9 +1083,18 @@ static void TestSumMultipleInputsNoBroadcasting(size_t num_inputs, const TensorS
|
|||
|
||||
TEST(MathOpTest, SumMultipleInputsNoBroadcasting) {
|
||||
const TensorShape shape{3, 3, 3};
|
||||
for (size_t num_inputs = 2; num_inputs < 10; ++num_inputs) {
|
||||
// Special case:
|
||||
// 2: BinaryImplDispatchTarget
|
||||
// 3-8: NoBroadcastBatchImplDispatchTarget(i)
|
||||
// 9: NoBroadcastBatchImplDispatchTarget(8) + BinaryImplDispatchTarget
|
||||
// 10: NoBroadcastBatchImplDispatchTarget(8) + NoBroadcastBatchImplDispatchTarget(3)
|
||||
// 15: NoBroadcastBatchImplDispatchTarget(8) + NoBroadcastBatchImplDispatchTarget(8)
|
||||
// 16: NoBroadcastBatchImplDispatchTarget(8) + NoBroadcastBatchImplDispatchTarget(8) + BinaryImplDispatchTarget
|
||||
for (size_t num_inputs = 2; num_inputs <= 10; ++num_inputs) {
|
||||
TestSumMultipleInputsNoBroadcasting<float>(num_inputs, shape);
|
||||
}
|
||||
TestSumMultipleInputsNoBroadcasting<float>(15, shape);
|
||||
TestSumMultipleInputsNoBroadcasting<float>(16, shape);
|
||||
}
|
||||
|
||||
TEST(MathOpTest, SumMultipleInputsNoBroadcasting_double) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue