mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-01 23:30:35 +00:00
fast reduction for reducemean (#8976)
This commit is contained in:
parent
1c872f9d74
commit
b7b42e0c5d
7 changed files with 130 additions and 41 deletions
|
|
@ -92,7 +92,7 @@ ApplicableMatrixReduction get_applicable_matrix_reduction(
|
|||
const cudnnReduceTensorOp_t cudnn_reduce_op,
|
||||
const std::vector<int64_t>& dims, const std::vector<int64_t>& original_axes,
|
||||
int& m_out, int& n_out) {
|
||||
if (cudnn_reduce_op != CUDNN_REDUCE_TENSOR_ADD) {
|
||||
if (cudnn_reduce_op != CUDNN_REDUCE_TENSOR_ADD && cudnn_reduce_op != CUDNN_REDUCE_TENSOR_AVG) {
|
||||
return ApplicableMatrixReduction::None;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@
|
|||
#include "core/providers/cuda/cu_inc/common.cuh"
|
||||
#include "core/providers/cuda/shared_inc/cuda_utils.h"
|
||||
#include "core/providers/cuda/reduction/reduction_utils.cuh"
|
||||
#include "core/providers/cuda/cu_inc/unary_elementwise_impl.cuh"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
|
@ -458,6 +459,32 @@ Status call_reduce_matrix_rows(cudaStream_t stream, const TIn* input, TOut* outp
|
|||
}
|
||||
} // namespace detail
|
||||
|
||||
template <typename T>
|
||||
struct OP_Div {
|
||||
__device__ __inline__ T operator()(const T& a) const {
|
||||
return a / v_;
|
||||
}
|
||||
|
||||
OP_Div(T v) : v_(v) {}
|
||||
|
||||
T v_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void UnaryDiv(cudaStream_t stream, const T* input, T* output, T denominator, size_t count) {
|
||||
UnaryElementWiseImpl(stream, input, output, OP_Div<T>(denominator), count);
|
||||
}
|
||||
|
||||
#define INSTANTIATE_UNARY_DIV(T) \
|
||||
template void UnaryDiv<T>(cudaStream_t stream, const T* input, T* output, T denominator, size_t count)
|
||||
INSTANTIATE_UNARY_DIV(half);
|
||||
INSTANTIATE_UNARY_DIV(float);
|
||||
INSTANTIATE_UNARY_DIV(double);
|
||||
#if CUDA_VERSION >= 11000 && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
|
||||
INSTANTIATE_UNARY_DIV(nv_bfloat16);
|
||||
#endif
|
||||
#undef INSTANTIATE_UNARY_DIV
|
||||
|
||||
template <typename TIn, typename TOut>
|
||||
Status reduce_matrix_rows(cudaStream_t stream, const TIn* input, TOut* output, int m, int n, bool reset_initial_output) {
|
||||
using TBuf = AccumulationType_t<TIn>;
|
||||
|
|
|
|||
|
|
@ -103,5 +103,9 @@ Status reduce_matrix_rows(cudaStream_t stream, const TIn* input, TOut* output, i
|
|||
template <typename TIn, typename TOut>
|
||||
Status reduce_matrix_columns(cudaStream_t stream, const TIn* input, TOut* output, int m, int n, void* buffer, size_t buffer_size);
|
||||
|
||||
/** Apply unary elementwise division. */
|
||||
template <typename T>
|
||||
void UnaryDiv(cudaStream_t stream, const T* input, T* output, T denominator, size_t count);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -455,27 +455,52 @@ Status ReduceComputeCore(CUDAExecutionProvider& cuda_ep, const Tensor& input, Pr
|
|||
// Block of fast matrix reduction.
|
||||
if (fast_reduction) {
|
||||
int m{}, n{};
|
||||
const auto applicable_matrix_reduction = get_applicable_matrix_reduction(
|
||||
cudnn_reduce_op, input_shape.GetDims(), axes, m, n);
|
||||
switch (applicable_matrix_reduction) {
|
||||
case ApplicableMatrixReduction::Rows: {
|
||||
return reduce_matrix_rows(
|
||||
stream,
|
||||
reinterpret_cast<const CudaT*>(input.template Data<T>()),
|
||||
reinterpret_cast<CudaT*>(output.template MutableData<T>()),
|
||||
m, n);
|
||||
const auto applicable_matrix_reduction =
|
||||
get_applicable_matrix_reduction(cudnn_reduce_op, input_shape.GetDims(), axes, m, n);
|
||||
if (applicable_matrix_reduction != ApplicableMatrixReduction::None) {
|
||||
IAllocatorUniquePtr<T> input_data_buffer(nullptr, [](T*) {});
|
||||
const CudaT* input_data = reinterpret_cast<const CudaT*>(input.template Data<T>());
|
||||
if (calculate_sqt) {
|
||||
input_data_buffer = cuda_ep.GetScratchBuffer<T>(input_count);
|
||||
input_data = reinterpret_cast<CudaT*>(input_data_buffer.get());
|
||||
fast_divmod tmp_div;
|
||||
Impl_Mul<CudaT>(stream, static_cast<int32_t>(SimpleBroadcast::NoBroadcast), nullptr,
|
||||
reinterpret_cast<const CudaT*>(input.template Data<T>()), nullptr,
|
||||
reinterpret_cast<const CudaT*>(input.template Data<T>()), nullptr, tmp_div, tmp_div,
|
||||
reinterpret_cast<CudaT*>(input_data_buffer.get()), input_count);
|
||||
input_data = reinterpret_cast<const CudaT*>(input_data_buffer.get());
|
||||
}
|
||||
case ApplicableMatrixReduction::Columns: {
|
||||
const auto buffer_size_bytes = compute_reduce_matrix_columns_buffer_size<CudaT>(m, n);
|
||||
auto buffer = cuda_ep.GetScratchBuffer<void>(buffer_size_bytes);
|
||||
return reduce_matrix_columns(
|
||||
stream,
|
||||
reinterpret_cast<const CudaT*>(input.template Data<T>()),
|
||||
reinterpret_cast<CudaT*>(output.template MutableData<T>()),
|
||||
m, n, buffer.get(), buffer_size_bytes);
|
||||
|
||||
switch (applicable_matrix_reduction) {
|
||||
case ApplicableMatrixReduction::Rows: {
|
||||
ORT_RETURN_IF_ERROR(reduce_matrix_rows(
|
||||
stream, input_data, reinterpret_cast<CudaT*>(output.template MutableData<T>()), m, n));
|
||||
} break;
|
||||
case ApplicableMatrixReduction::Columns: {
|
||||
const auto buffer_size_bytes = compute_reduce_matrix_columns_buffer_size<CudaT>(m, n);
|
||||
auto buffer = cuda_ep.GetScratchBuffer<void>(buffer_size_bytes);
|
||||
ORT_RETURN_IF_ERROR(reduce_matrix_columns(stream, input_data,
|
||||
reinterpret_cast<CudaT*>(output.template MutableData<T>()), m, n,
|
||||
buffer.get(), buffer_size_bytes));
|
||||
} break;
|
||||
default: {
|
||||
ORT_ENFORCE(false, "Invild matrix reduction type.");
|
||||
}
|
||||
}
|
||||
default:
|
||||
break;
|
||||
|
||||
if (calculate_log) {
|
||||
Impl_Log<CudaT>(stream, reinterpret_cast<const CudaT*>(output.template Data<T>()),
|
||||
reinterpret_cast<CudaT*>(output.template MutableData<T>()), output_count);
|
||||
} else if (cudnn_reduce_op == CUDNN_REDUCE_TENSOR_AVG) {
|
||||
float denominator_float = applicable_matrix_reduction == ApplicableMatrixReduction::Rows
|
||||
? static_cast<float>(m)
|
||||
: static_cast<float>(n);
|
||||
CudaT denominator = ToCudaType<T>::FromFloat(denominator_float);
|
||||
UnaryDiv(stream, reinterpret_cast<const CudaT*>(output.template Data<T>()),
|
||||
reinterpret_cast<CudaT*>(output.template MutableData<T>()), denominator, output_count);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -138,7 +138,9 @@ class ReduceMax final : public ReduceKernel<true> {
|
|||
template <typename T>
|
||||
class ReduceMean final : public ReduceKernel<true> {
|
||||
public:
|
||||
ReduceMean(const OpKernelInfo& info) : ReduceKernel<true>(info) {}
|
||||
ReduceMean(const OpKernelInfo& info) : ReduceKernel<true>(info) {
|
||||
fast_reduction_ = true;
|
||||
}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* ctx) const override {
|
||||
return ComputeImpl<T>(ctx, CUDNN_REDUCE_TENSOR_AVG);
|
||||
|
|
@ -182,6 +184,7 @@ class ReduceLogSum final : public ReduceKernel<true> {
|
|||
public:
|
||||
ReduceLogSum(const OpKernelInfo& info) : ReduceKernel<true>(info) {
|
||||
ReduceKernel<true>::calculate_log_ = true;
|
||||
fast_reduction_ = true;
|
||||
}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* ctx) const override {
|
||||
|
|
@ -194,6 +197,7 @@ class ReduceSumSquare final : public ReduceKernel<true> {
|
|||
public:
|
||||
ReduceSumSquare(const OpKernelInfo& info) : ReduceKernel<true>(info) {
|
||||
ReduceKernel<true>::calculate_sqt_ = true;
|
||||
fast_reduction_ = true;
|
||||
}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* ctx) const override {
|
||||
|
|
|
|||
|
|
@ -445,27 +445,52 @@ Status ReduceComputeCore(ROCMExecutionProvider& rocm_ep, const Tensor& input, Pr
|
|||
// Block of fast matrix reduction.
|
||||
if (fast_reduction) {
|
||||
int m{}, n{};
|
||||
const auto applicable_matrix_reduction = get_applicable_matrix_reduction(
|
||||
miopen_reduce_op, input_shape.GetDims(), axes, m, n);
|
||||
switch (applicable_matrix_reduction) {
|
||||
case ApplicableMatrixReduction::Rows: {
|
||||
return reduce_matrix_rows(
|
||||
stream,
|
||||
reinterpret_cast<const HipT*>(input.template Data<T>()),
|
||||
reinterpret_cast<HipT*>(output.template MutableData<T>()),
|
||||
m, n);
|
||||
const auto applicable_matrix_reduction =
|
||||
get_applicable_matrix_reduction(miopen_reduce_op, input_shape.GetDims(), axes, m, n);
|
||||
if (applicable_matrix_reduction != ApplicableMatrixReduction::None) {
|
||||
IAllocatorUniquePtr<T> input_data_buffer(nullptr, [](T*) {});
|
||||
const HipT* input_data = reinterpret_cast<const HipT*>(input.template Data<T>());
|
||||
if (calculate_sqt) {
|
||||
input_data_buffer = rocm_ep.GetScratchBuffer<T>(input_count);
|
||||
input_data = reinterpret_cast<HipT*>(input_data_buffer.get());
|
||||
fast_divmod tmp_div;
|
||||
Impl_Mul<HipT>(stream, static_cast<int32_t>(SimpleBroadcast::NoBroadcast), nullptr,
|
||||
reinterpret_cast<const HipT*>(input.template Data<T>()), nullptr,
|
||||
reinterpret_cast<const HipT*>(input.template Data<T>()), nullptr, tmp_div, tmp_div,
|
||||
reinterpret_cast<HipT*>(input_data_buffer.get()), input_count);
|
||||
input_data = reinterpret_cast<const HipT*>(input_data_buffer.get());
|
||||
}
|
||||
case ApplicableMatrixReduction::Columns: {
|
||||
const auto buffer_size_bytes = compute_reduce_matrix_columns_buffer_size<HipT>(m, n);
|
||||
auto buffer = rocm_ep.GetScratchBuffer<void>(buffer_size_bytes);
|
||||
return reduce_matrix_columns(
|
||||
stream,
|
||||
reinterpret_cast<const HipT*>(input.template Data<T>()),
|
||||
reinterpret_cast<HipT*>(output.template MutableData<T>()),
|
||||
m, n, buffer.get(), buffer_size_bytes);
|
||||
|
||||
switch (applicable_matrix_reduction) {
|
||||
case ApplicableMatrixReduction::Rows: {
|
||||
ORT_RETURN_IF_ERROR(reduce_matrix_rows(
|
||||
stream, input_data, reinterpret_cast<HipT*>(output.template MutableData<T>()), m, n));
|
||||
} break;
|
||||
case ApplicableMatrixReduction::Columns: {
|
||||
const auto buffer_size_bytes = compute_reduce_matrix_columns_buffer_size<HipT>(m, n);
|
||||
auto buffer = rocm_ep.GetScratchBuffer<void>(buffer_size_bytes);
|
||||
ORT_RETURN_IF_ERROR(reduce_matrix_columns(stream, input_data,
|
||||
reinterpret_cast<HipT*>(output.template MutableData<T>()), m, n,
|
||||
buffer.get(), buffer_size_bytes));
|
||||
} break;
|
||||
default: {
|
||||
ORT_ENFORCE(false, "Invild matrix reduction type.");
|
||||
}
|
||||
}
|
||||
default:
|
||||
break;
|
||||
|
||||
if (calculate_log) {
|
||||
Impl_Log<HipT>(stream, reinterpret_cast<const HipT*>(output.template Data<T>()),
|
||||
reinterpret_cast<HipT*>(output.template MutableData<T>()), output_count);
|
||||
} else if (miopen_reduce_op == MIOPEN_REDUCE_TENSOR_AVG) {
|
||||
float denominator_float = applicable_matrix_reduction == ApplicableMatrixReduction::Rows
|
||||
? static_cast<float>(m)
|
||||
: static_cast<float>(n);
|
||||
HipT denominator = ToHipType<T>::FromFloat(denominator_float);
|
||||
UnaryDiv(stream, reinterpret_cast<const HipT*>(output.template Data<T>()),
|
||||
reinterpret_cast<HipT*>(output.template MutableData<T>()), denominator, output_count);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -140,7 +140,9 @@ class ReduceMax final : public ReduceKernel<true> {
|
|||
template <typename T>
|
||||
class ReduceMean final : public ReduceKernel<true> {
|
||||
public:
|
||||
ReduceMean(const OpKernelInfo& info) : ReduceKernel<true>(info) {}
|
||||
ReduceMean(const OpKernelInfo& info) : ReduceKernel<true>(info) {
|
||||
fast_reduction_ = true;
|
||||
}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* ctx) const override {
|
||||
return ComputeImpl<T>(ctx, MIOPEN_REDUCE_TENSOR_AVG);
|
||||
|
|
@ -184,6 +186,7 @@ class ReduceLogSum final : public ReduceKernel<true> {
|
|||
public:
|
||||
ReduceLogSum(const OpKernelInfo& info) : ReduceKernel<true>(info) {
|
||||
ReduceKernel<true>::calculate_log_ = true;
|
||||
fast_reduction_ = true;
|
||||
}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* ctx) const override {
|
||||
|
|
@ -196,6 +199,7 @@ class ReduceSumSquare final : public ReduceKernel<true> {
|
|||
public:
|
||||
ReduceSumSquare(const OpKernelInfo& info) : ReduceKernel<true>(info) {
|
||||
ReduceKernel<true>::calculate_sqt_ = true;
|
||||
fast_reduction_ = true;
|
||||
}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* ctx) const override {
|
||||
|
|
|
|||
Loading…
Reference in a new issue