From b7b42e0c5d4ae6f4052a89c3f3bc8ab03d9abe9a Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Thu, 9 Sep 2021 01:28:57 +0800 Subject: [PATCH] fast reduction for reducemean (#8976) --- .../cuda/reduction/reduction_functions.cc | 2 +- .../cuda/reduction/reduction_functions.cu | 27 ++++++++ .../cuda/reduction/reduction_functions.h | 4 ++ .../providers/cuda/reduction/reduction_ops.cc | 63 +++++++++++++------ .../providers/cuda/reduction/reduction_ops.h | 6 +- .../providers/rocm/reduction/reduction_ops.cc | 63 +++++++++++++------ .../providers/rocm/reduction/reduction_ops.h | 6 +- 7 files changed, 130 insertions(+), 41 deletions(-) diff --git a/onnxruntime/core/providers/cuda/reduction/reduction_functions.cc b/onnxruntime/core/providers/cuda/reduction/reduction_functions.cc index fbf0ac0bd4..d756bd4501 100644 --- a/onnxruntime/core/providers/cuda/reduction/reduction_functions.cc +++ b/onnxruntime/core/providers/cuda/reduction/reduction_functions.cc @@ -92,7 +92,7 @@ ApplicableMatrixReduction get_applicable_matrix_reduction( const cudnnReduceTensorOp_t cudnn_reduce_op, const std::vector& dims, const std::vector& original_axes, int& m_out, int& n_out) { - if (cudnn_reduce_op != CUDNN_REDUCE_TENSOR_ADD) { + if (cudnn_reduce_op != CUDNN_REDUCE_TENSOR_ADD && cudnn_reduce_op != CUDNN_REDUCE_TENSOR_AVG) { return ApplicableMatrixReduction::None; } diff --git a/onnxruntime/core/providers/cuda/reduction/reduction_functions.cu b/onnxruntime/core/providers/cuda/reduction/reduction_functions.cu index c83060ef48..e48ad4f4f5 100644 --- a/onnxruntime/core/providers/cuda/reduction/reduction_functions.cu +++ b/onnxruntime/core/providers/cuda/reduction/reduction_functions.cu @@ -12,6 +12,7 @@ #include "core/providers/cuda/cu_inc/common.cuh" #include "core/providers/cuda/shared_inc/cuda_utils.h" #include "core/providers/cuda/reduction/reduction_utils.cuh" +#include "core/providers/cuda/cu_inc/unary_elementwise_impl.cuh" namespace onnxruntime { namespace cuda { @@ -458,6 +459,32 @@ Status 
call_reduce_matrix_rows(cudaStream_t stream, const TIn* input, TOut* outp } } // namespace detail +template +struct OP_Div { + __device__ __inline__ T operator()(const T& a) const { + return a / v_; + } + + OP_Div(T v) : v_(v) {} + + T v_; +}; + +template +void UnaryDiv(cudaStream_t stream, const T* input, T* output, T denominator, size_t count) { + UnaryElementWiseImpl(stream, input, output, OP_Div(denominator), count); +} + +#define INSTANTIATE_UNARY_DIV(T) \ + template void UnaryDiv(cudaStream_t stream, const T* input, T* output, T denominator, size_t count) +INSTANTIATE_UNARY_DIV(half); +INSTANTIATE_UNARY_DIV(float); +INSTANTIATE_UNARY_DIV(double); +#if CUDA_VERSION >= 11000 && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) +INSTANTIATE_UNARY_DIV(nv_bfloat16); +#endif +#undef INSTANTIATE_UNARY_DIV + template Status reduce_matrix_rows(cudaStream_t stream, const TIn* input, TOut* output, int m, int n, bool reset_initial_output) { using TBuf = AccumulationType_t; diff --git a/onnxruntime/core/providers/cuda/reduction/reduction_functions.h b/onnxruntime/core/providers/cuda/reduction/reduction_functions.h index 965de5a2bd..1ffcffa1d0 100644 --- a/onnxruntime/core/providers/cuda/reduction/reduction_functions.h +++ b/onnxruntime/core/providers/cuda/reduction/reduction_functions.h @@ -103,5 +103,9 @@ Status reduce_matrix_rows(cudaStream_t stream, const TIn* input, TOut* output, i template Status reduce_matrix_columns(cudaStream_t stream, const TIn* input, TOut* output, int m, int n, void* buffer, size_t buffer_size); +/** Apply unary elementwise division. 
*/ +template +void UnaryDiv(cudaStream_t stream, const T* input, T* output, T denominator, size_t count); + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc index 3809f4c559..af6a079abb 100644 --- a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc @@ -455,27 +455,52 @@ Status ReduceComputeCore(CUDAExecutionProvider& cuda_ep, const Tensor& input, Pr // Block of fast matrix reduction. if (fast_reduction) { int m{}, n{}; - const auto applicable_matrix_reduction = get_applicable_matrix_reduction( - cudnn_reduce_op, input_shape.GetDims(), axes, m, n); - switch (applicable_matrix_reduction) { - case ApplicableMatrixReduction::Rows: { - return reduce_matrix_rows( - stream, - reinterpret_cast(input.template Data()), - reinterpret_cast(output.template MutableData()), - m, n); + const auto applicable_matrix_reduction = + get_applicable_matrix_reduction(cudnn_reduce_op, input_shape.GetDims(), axes, m, n); + if (applicable_matrix_reduction != ApplicableMatrixReduction::None) { + IAllocatorUniquePtr input_data_buffer(nullptr, [](T*) {}); + const CudaT* input_data = reinterpret_cast(input.template Data()); + if (calculate_sqt) { + input_data_buffer = cuda_ep.GetScratchBuffer(input_count); + input_data = reinterpret_cast(input_data_buffer.get()); + fast_divmod tmp_div; + Impl_Mul(stream, static_cast(SimpleBroadcast::NoBroadcast), nullptr, + reinterpret_cast(input.template Data()), nullptr, + reinterpret_cast(input.template Data()), nullptr, tmp_div, tmp_div, + reinterpret_cast(input_data_buffer.get()), input_count); + input_data = reinterpret_cast(input_data_buffer.get()); } - case ApplicableMatrixReduction::Columns: { - const auto buffer_size_bytes = compute_reduce_matrix_columns_buffer_size(m, n); - auto buffer = cuda_ep.GetScratchBuffer(buffer_size_bytes); - return 
reduce_matrix_columns( - stream, - reinterpret_cast(input.template Data()), - reinterpret_cast(output.template MutableData()), - m, n, buffer.get(), buffer_size_bytes); + + switch (applicable_matrix_reduction) { + case ApplicableMatrixReduction::Rows: { + ORT_RETURN_IF_ERROR(reduce_matrix_rows( + stream, input_data, reinterpret_cast(output.template MutableData()), m, n)); + } break; + case ApplicableMatrixReduction::Columns: { + const auto buffer_size_bytes = compute_reduce_matrix_columns_buffer_size(m, n); + auto buffer = cuda_ep.GetScratchBuffer(buffer_size_bytes); + ORT_RETURN_IF_ERROR(reduce_matrix_columns(stream, input_data, + reinterpret_cast(output.template MutableData()), m, n, + buffer.get(), buffer_size_bytes)); + } break; + default: { + ORT_ENFORCE(false, "Invalid matrix reduction type."); + } } - default: - break; + + if (calculate_log) { + Impl_Log(stream, reinterpret_cast(output.template Data()), + reinterpret_cast(output.template MutableData()), output_count); + } else if (cudnn_reduce_op == CUDNN_REDUCE_TENSOR_AVG) { + float denominator_float = applicable_matrix_reduction == ApplicableMatrixReduction::Rows + ?
static_cast(m) + : static_cast(n); + CudaT denominator = ToCudaType::FromFloat(denominator_float); + UnaryDiv(stream, reinterpret_cast(output.template Data()), + reinterpret_cast(output.template MutableData()), denominator, output_count); + } + + return Status::OK(); } } diff --git a/onnxruntime/core/providers/cuda/reduction/reduction_ops.h b/onnxruntime/core/providers/cuda/reduction/reduction_ops.h index 523d041ea0..3bb1ea42ec 100644 --- a/onnxruntime/core/providers/cuda/reduction/reduction_ops.h +++ b/onnxruntime/core/providers/cuda/reduction/reduction_ops.h @@ -138,7 +138,9 @@ class ReduceMax final : public ReduceKernel { template class ReduceMean final : public ReduceKernel { public: - ReduceMean(const OpKernelInfo& info) : ReduceKernel(info) {} + ReduceMean(const OpKernelInfo& info) : ReduceKernel(info) { + fast_reduction_ = true; + } Status ComputeInternal(OpKernelContext* ctx) const override { return ComputeImpl(ctx, CUDNN_REDUCE_TENSOR_AVG); @@ -182,6 +184,7 @@ class ReduceLogSum final : public ReduceKernel { public: ReduceLogSum(const OpKernelInfo& info) : ReduceKernel(info) { ReduceKernel::calculate_log_ = true; + fast_reduction_ = true; } Status ComputeInternal(OpKernelContext* ctx) const override { @@ -194,6 +197,7 @@ class ReduceSumSquare final : public ReduceKernel { public: ReduceSumSquare(const OpKernelInfo& info) : ReduceKernel(info) { ReduceKernel::calculate_sqt_ = true; + fast_reduction_ = true; } Status ComputeInternal(OpKernelContext* ctx) const override { diff --git a/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc b/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc index 8c43b2e5b3..f1fcbcc635 100644 --- a/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc @@ -445,27 +445,52 @@ Status ReduceComputeCore(ROCMExecutionProvider& rocm_ep, const Tensor& input, Pr // Block of fast matrix reduction. 
if (fast_reduction) { int m{}, n{}; - const auto applicable_matrix_reduction = get_applicable_matrix_reduction( - miopen_reduce_op, input_shape.GetDims(), axes, m, n); - switch (applicable_matrix_reduction) { - case ApplicableMatrixReduction::Rows: { - return reduce_matrix_rows( - stream, - reinterpret_cast(input.template Data()), - reinterpret_cast(output.template MutableData()), - m, n); + const auto applicable_matrix_reduction = + get_applicable_matrix_reduction(miopen_reduce_op, input_shape.GetDims(), axes, m, n); + if (applicable_matrix_reduction != ApplicableMatrixReduction::None) { + IAllocatorUniquePtr input_data_buffer(nullptr, [](T*) {}); + const HipT* input_data = reinterpret_cast(input.template Data()); + if (calculate_sqt) { + input_data_buffer = rocm_ep.GetScratchBuffer(input_count); + input_data = reinterpret_cast(input_data_buffer.get()); + fast_divmod tmp_div; + Impl_Mul(stream, static_cast(SimpleBroadcast::NoBroadcast), nullptr, + reinterpret_cast(input.template Data()), nullptr, + reinterpret_cast(input.template Data()), nullptr, tmp_div, tmp_div, + reinterpret_cast(input_data_buffer.get()), input_count); + input_data = reinterpret_cast(input_data_buffer.get()); } - case ApplicableMatrixReduction::Columns: { - const auto buffer_size_bytes = compute_reduce_matrix_columns_buffer_size(m, n); - auto buffer = rocm_ep.GetScratchBuffer(buffer_size_bytes); - return reduce_matrix_columns( - stream, - reinterpret_cast(input.template Data()), - reinterpret_cast(output.template MutableData()), - m, n, buffer.get(), buffer_size_bytes); + + switch (applicable_matrix_reduction) { + case ApplicableMatrixReduction::Rows: { + ORT_RETURN_IF_ERROR(reduce_matrix_rows( + stream, input_data, reinterpret_cast(output.template MutableData()), m, n)); + } break; + case ApplicableMatrixReduction::Columns: { + const auto buffer_size_bytes = compute_reduce_matrix_columns_buffer_size(m, n); + auto buffer = rocm_ep.GetScratchBuffer(buffer_size_bytes); + 
ORT_RETURN_IF_ERROR(reduce_matrix_columns(stream, input_data, + reinterpret_cast(output.template MutableData()), m, n, + buffer.get(), buffer_size_bytes)); + } break; + default: { + ORT_ENFORCE(false, "Invalid matrix reduction type."); + } } - default: - break; + + if (calculate_log) { + Impl_Log(stream, reinterpret_cast(output.template Data()), + reinterpret_cast(output.template MutableData()), output_count); + } else if (miopen_reduce_op == MIOPEN_REDUCE_TENSOR_AVG) { + float denominator_float = applicable_matrix_reduction == ApplicableMatrixReduction::Rows + ? static_cast(m) + : static_cast(n); + HipT denominator = ToHipType::FromFloat(denominator_float); + UnaryDiv(stream, reinterpret_cast(output.template Data()), + reinterpret_cast(output.template MutableData()), denominator, output_count); + } + + return Status::OK(); } } diff --git a/onnxruntime/core/providers/rocm/reduction/reduction_ops.h b/onnxruntime/core/providers/rocm/reduction/reduction_ops.h index 6078000cdd..e3842ac6a7 100644 --- a/onnxruntime/core/providers/rocm/reduction/reduction_ops.h +++ b/onnxruntime/core/providers/rocm/reduction/reduction_ops.h @@ -140,7 +140,9 @@ class ReduceMax final : public ReduceKernel { template class ReduceMean final : public ReduceKernel { public: - ReduceMean(const OpKernelInfo& info) : ReduceKernel(info) {} + ReduceMean(const OpKernelInfo& info) : ReduceKernel(info) { + fast_reduction_ = true; + } Status ComputeInternal(OpKernelContext* ctx) const override { return ComputeImpl(ctx, MIOPEN_REDUCE_TENSOR_AVG); @@ -184,6 +186,7 @@ class ReduceLogSum final : public ReduceKernel { public: ReduceLogSum(const OpKernelInfo& info) : ReduceKernel(info) { ReduceKernel::calculate_log_ = true; + fast_reduction_ = true; } Status ComputeInternal(OpKernelContext* ctx) const override { @@ -196,6 +199,7 @@ class ReduceSumSquare final : public ReduceKernel { public: ReduceSumSquare(const OpKernelInfo& info) : ReduceKernel(info) { ReduceKernel::calculate_sqt_ = true; + fast_reduction_ =
true; } Status ComputeInternal(OpKernelContext* ctx) const override {