Optimize Variadic Elementwise Ops (#9186)

* optimize variadic elementwise ops

* remove nvvp file

* correct comment

* resolve comments
This commit is contained in:
Vincent Wang 2021-10-08 13:45:54 +08:00 committed by GitHub
parent 5f5f28bf14
commit cd65a8089e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 98 additions and 48 deletions

View file

@ -5,6 +5,7 @@
#include "core/providers/cuda/math/variadic_elementwise_ops.h"
#include <cassert>
#include <algorithm>
#include "core/framework/data_types_internal.h"
#include "core/providers/cuda/math/binary_elementwise_ops.h"
@ -17,21 +18,50 @@ namespace cuda {
template <typename VariadicElementwiseOpTag, typename... SupportedElementTypes>
template <typename T>
Status VariadicElementwiseOp<VariadicElementwiseOpTag, SupportedElementTypes...>::
NoBroadcastBatchImplDispatchTarget<T>::operator()(cudaStream_t stream, const InputTensorVector& inputs, Tensor& output) const {
assert(inputs.size() > 1);
Status VariadicElementwiseOp<VariadicElementwiseOpTag, SupportedElementTypes...>::NoBroadcastBatchImplDispatchTarget<
T>::operator()(cudaStream_t stream, const InputTensorVector& inputs, Tensor& output) const {
using CudaT = typename ToCudaType<T>::MappedType;
InputBatchArray<CudaT> input_data_batch{static_cast<int32_t>(inputs.size())};
for (size_t i = 0; i < inputs.size(); ++i) {
size_t input_count = inputs.size();
assert(input_count > 1);
size_t index = std::min(input_count, static_cast<size_t>(k_max_input_batch_size));
InputBatchArray<CudaT> input_data_batch{static_cast<int32_t>(index)};
for (size_t i = 0; i < index; ++i) {
input_data_batch[static_cast<int32_t>(i)] = reinterpret_cast<const CudaT*>(inputs[i].get().template Data<T>());
}
CudaT* output_data = reinterpret_cast<CudaT*>(output.template MutableData<T>());
Impl_NoBroadcastInputBatch<CudaT, VariadicElementwiseOpTag>(stream, input_data_batch, output_data,
output.Shape().Size());
Impl_NoBroadcastInputBatch<CudaT, VariadicElementwiseOpTag>(
stream, input_data_batch, output_data, output.Shape().Size());
while (index < input_count) {
size_t left_count = input_count - index + 1;
size_t batch = std::min(left_count, static_cast<size_t>(k_max_input_batch_size));
// Special case for 2 inputs left.
if (batch == 2) {
BinaryElementwisePreparation prepare;
ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(&output, &inputs[input_count - 1].get(), &output, &prepare));
Impl_General<CudaT, VariadicElementwiseOpTag>(
stream, prepare.output_rank_or_simple_broadcast, &prepare.lhs_padded_strides,
reinterpret_cast<const CudaT*>(prepare.lhs_tensor->template Data<T>()), &prepare.rhs_padded_strides,
reinterpret_cast<const CudaT*>(prepare.rhs_tensor->template Data<T>()), &prepare.fdm_output_strides,
prepare.fdm_H, prepare.fdm_C, reinterpret_cast<CudaT*>(prepare.output_tensor->template MutableData<T>()),
prepare.output_tensor->Shape().Size());
// Must be the last.
break;
}
InputBatchArray<CudaT> left_input_data_batch{static_cast<int32_t>(batch)};
left_input_data_batch[0] = reinterpret_cast<const CudaT*>(output.template Data<T>());
for (size_t i = 1; i < batch; ++i) {
left_input_data_batch[static_cast<int32_t>(i)] =
reinterpret_cast<const CudaT*>(inputs[index].get().template Data<T>());
index++;
}
Impl_NoBroadcastInputBatch<CudaT, VariadicElementwiseOpTag>(stream, left_input_data_batch, output_data,
output.Shape().Size());
}
return Status::OK();
}
@ -65,44 +95,57 @@ Status VariadicElementwiseOp<VariadicElementwiseOpTag, SupportedElementTypes...>
// for more than 2 inputs, we need to accumulate into output tensor, as the shape from input0 + input1 might be different from output shape
template <typename VariadicElementwiseOpTag, typename... SupportedElementTypes>
template <typename T>
Status VariadicElementwiseOp<VariadicElementwiseOpTag, SupportedElementTypes...>::
GeneralImplDispatchTarget<T>::operator()(cudaStream_t stream, const InputTensorVector& inputs, Tensor& output) const {
Status
VariadicElementwiseOp<VariadicElementwiseOpTag, SupportedElementTypes...>::GeneralImplDispatchTarget<T>::operator()(
cudaStream_t stream, const InputTensorVector& inputs, Tensor& output) const {
assert(inputs.size() > 1);
using CudaT = typename ToCudaType<T>::MappedType;
CUDA_RETURN_IF_ERROR(cudaMemsetAsync(output.MutableDataRaw(), 0, output.SizeInBytes(), stream));
// If there is any input having the same shape with output, we don't need the memset.
size_t index_of_same_shape = 0;
for (; index_of_same_shape < inputs.size(); index_of_same_shape++) {
if (inputs[index_of_same_shape].get().Shape() == output.Shape()) {
break;
}
}
BinaryElementwisePreparation prepare;
ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(&output, &inputs[0].get(), &output, &prepare));
Impl_Add(
stream,
prepare.output_rank_or_simple_broadcast,
&prepare.lhs_padded_strides,
reinterpret_cast<const CudaT*>(prepare.lhs_tensor->template Data<T>()),
&prepare.rhs_padded_strides,
reinterpret_cast<const CudaT*>(prepare.rhs_tensor->template Data<T>()),
&prepare.fdm_output_strides,
prepare.fdm_H,
prepare.fdm_C,
reinterpret_cast<CudaT*>(prepare.output_tensor->template MutableData<T>()),
prepare.output_tensor->Shape().Size());
// No input has same shape of output, memset the output, and add the 1st input as initialization.
if (index_of_same_shape == inputs.size()) {
CUDA_RETURN_IF_ERROR(cudaMemsetAsync(output.MutableDataRaw(), 0, output.SizeInBytes(), stream));
ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(&output, &inputs[0].get(), &output, &prepare));
Impl_Add(stream, prepare.output_rank_or_simple_broadcast, &prepare.lhs_padded_strides,
reinterpret_cast<const CudaT*>(prepare.lhs_tensor->template Data<T>()), &prepare.rhs_padded_strides,
reinterpret_cast<const CudaT*>(prepare.rhs_tensor->template Data<T>()), &prepare.fdm_output_strides,
prepare.fdm_H, prepare.fdm_C, reinterpret_cast<CudaT*>(prepare.output_tensor->template MutableData<T>()),
prepare.output_tensor->Shape().Size());
} else {
// First operation is between input[0] and input[index_of_same_shape] if index_of_same_shape is not 0.
size_t index = index_of_same_shape == 0 ? 1 : 0;
ORT_RETURN_IF_ERROR(
BinaryElementwiseBroadcastPrepare(&inputs[index_of_same_shape].get(), &inputs[index].get(), &output, &prepare));
Impl_General<CudaT, VariadicElementwiseOpTag>(
stream, prepare.output_rank_or_simple_broadcast, &prepare.lhs_padded_strides,
reinterpret_cast<const CudaT*>(prepare.lhs_tensor->template Data<T>()), &prepare.rhs_padded_strides,
reinterpret_cast<const CudaT*>(prepare.rhs_tensor->template Data<T>()), &prepare.fdm_output_strides,
prepare.fdm_H, prepare.fdm_C, reinterpret_cast<CudaT*>(prepare.output_tensor->template MutableData<T>()),
prepare.output_tensor->Shape().Size());
}
for (size_t index = 1; index < inputs.size(); index++) {
ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(&output, &inputs[index].get(), &output, &prepare));
// If index_of_same_shape is 0, we already handle the 1st and 2nd inputs.
if (index == index_of_same_shape || (index_of_same_shape == 0 && index == 1)) {
continue;
}
ORT_RETURN_IF_ERROR(BinaryElementwiseBroadcastPrepare(&output, &inputs[index].get(), &output, &prepare));
Impl_General<CudaT, VariadicElementwiseOpTag>(
stream,
prepare.output_rank_or_simple_broadcast,
&prepare.lhs_padded_strides,
reinterpret_cast<const CudaT*>(prepare.lhs_tensor->template Data<T>()),
&prepare.rhs_padded_strides,
reinterpret_cast<const CudaT*>(prepare.rhs_tensor->template Data<T>()),
&prepare.fdm_output_strides,
prepare.fdm_H,
prepare.fdm_C,
reinterpret_cast<CudaT*>(prepare.output_tensor->template MutableData<T>()),
stream, prepare.output_rank_or_simple_broadcast, &prepare.lhs_padded_strides,
reinterpret_cast<const CudaT*>(prepare.lhs_tensor->template Data<T>()), &prepare.rhs_padded_strides,
reinterpret_cast<const CudaT*>(prepare.rhs_tensor->template Data<T>()), &prepare.fdm_output_strides,
prepare.fdm_H, prepare.fdm_C, reinterpret_cast<CudaT*>(prepare.output_tensor->template MutableData<T>()),
prepare.output_tensor->Shape().Size());
}
@ -145,23 +188,21 @@ Status VariadicElementwiseOp<VariadicElementwiseOpTag, SupportedElementTypes...>
const auto element_type = first_input_tensor.GetElementType();
utils::MLTypeCallDispatcher<SupportedElementTypes...> dispatcher(element_type);
// special case for no broadcasting and few enough inputs
if (input_count <= k_max_input_batch_size &&
std::all_of(
input_tensors.begin() + 1, input_tensors.end(),
[&first_input_tensor](InputTensorVector::value_type t) {
return first_input_tensor.Shape() == t.get().Shape();
})) {
// Special case for no broadcasting.
if (std::all_of(input_tensors.begin() + 1, input_tensors.end(),
[&first_input_tensor](InputTensorVector::value_type t) {
return first_input_tensor.Shape() == t.get().Shape();
})) {
auto& output_tensor = context->RequiredOutput(0, first_input_tensor.Shape());
// special case for no broadcasting and 2 inputs
if (input_count == 2) {
return dispatcher.template InvokeRet<Status, BinaryImplDispatchTarget>(
Stream(), input_tensors[0], input_tensors[1], output_tensor);
return dispatcher.template InvokeRet<Status, BinaryImplDispatchTarget>(Stream(), input_tensors[0],
input_tensors[1], output_tensor);
}
return dispatcher.template InvokeRet<Status, NoBroadcastBatchImplDispatchTarget>(
Stream(), input_tensors, output_tensor);
return dispatcher.template InvokeRet<Status, NoBroadcastBatchImplDispatchTarget>(Stream(), input_tensors,
output_tensor);
}
// compute output shape first, using broadcast rule

View file

@ -1083,9 +1083,18 @@ static void TestSumMultipleInputsNoBroadcasting(size_t num_inputs, const TensorS
TEST(MathOpTest, SumMultipleInputsNoBroadcasting) {
const TensorShape shape{3, 3, 3};
for (size_t num_inputs = 2; num_inputs < 10; ++num_inputs) {
// Special case:
// 2: BinaryImplDispatchTarget
// 3-8: NoBroadcastBatchImplDispatchTarget(i)
// 9: NoBroadcastBatchImplDispatchTarget(8) + BinaryImplDispatchTarget
// 10: NoBroadcastBatchImplDispatchTarget(8) + NoBroadcastBatchImplDispatchTarget(3)
// 15: NoBroadcastBatchImplDispatchTarget(8) + NoBroadcastBatchImplDispatchTarget(8)
// 16: NoBroadcastBatchImplDispatchTarget(8) + NoBroadcastBatchImplDispatchTarget(8) + BinaryImplDispatchTarget
for (size_t num_inputs = 2; num_inputs <= 10; ++num_inputs) {
TestSumMultipleInputsNoBroadcasting<float>(num_inputs, shape);
}
TestSumMultipleInputsNoBroadcasting<float>(15, shape);
TestSumMultipleInputsNoBroadcasting<float>(16, shape);
}
TEST(MathOpTest, SumMultipleInputsNoBroadcasting_double) {