mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-26 03:00:54 +00:00
Fix CUDA Reduction kernel for ArgMax/ArgMix for when reduction dim=1 (#6490)
* Fix for when reduction dim=1 * Disable test for AMD GPUs * Specify Async
This commit is contained in:
parent
14f7d56c81
commit
85434273ff
3 changed files with 36 additions and 17 deletions
|
|
@ -154,7 +154,7 @@ Status ReduceKernel<allow_multi_axes>::ReduceKernelShared(
|
|||
m, n, false);
|
||||
}
|
||||
case ApplicableMatrixReduction::Columns:
|
||||
// don't call reduce_matrix_columns() since it will reset initial output data
|
||||
// don't call reduce_matrix_columns() since it will reset initial output data
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
|
@ -600,24 +600,30 @@ Status ReduceComputeCore(CUDAExecutionProvider& cuda_ep, const Tensor& input, Pr
|
|||
}
|
||||
}
|
||||
} else { // For ArgMax & ArgMin ops, use the indicies as the output with int64 type
|
||||
if (temp_X) {
|
||||
auto temp_output = cuda_ep.GetScratchBuffer<float>(output_count);
|
||||
CUDNN_RETURN_IF_ERROR(cudnnReduceTensor(
|
||||
cuda_ep.PerThreadCudnnHandle(), reduce_desc, indices_cuda.get(), indices_bytes,
|
||||
workspace_cuda.get(), workspace_bytes,
|
||||
&one, input_tensor, temp_X.get(),
|
||||
&zero, output_tensor, temp_output.get()));
|
||||
// cudnnReduceTensor has issue if input and output has same size, which will happen if the axis to be reduced has dim value of 1.
|
||||
// the output is zeros of the output size
|
||||
if (input_count == output_count) {
|
||||
CUDA_RETURN_IF_ERROR(cudaMemsetAsync(output.template MutableData<int64_t>(), static_cast<int64_t>(0), output_count * sizeof(int64_t)));
|
||||
} else {
|
||||
auto temp_output = cuda_ep.GetScratchBuffer<CudaT>(output_count);
|
||||
CUDNN_RETURN_IF_ERROR(cudnnReduceTensor(
|
||||
cuda_ep.PerThreadCudnnHandle(), reduce_desc, indices_cuda.get(), indices_bytes,
|
||||
workspace_cuda.get(), workspace_bytes,
|
||||
&one, input_tensor, reinterpret_cast<const CudaT*>(input.template Data<T>()),
|
||||
&zero, output_tensor, temp_output.get()));
|
||||
}
|
||||
if (temp_X) {
|
||||
auto temp_output = cuda_ep.GetScratchBuffer<float>(output_count);
|
||||
CUDNN_RETURN_IF_ERROR(cudnnReduceTensor(
|
||||
cuda_ep.PerThreadCudnnHandle(), reduce_desc, indices_cuda.get(), indices_bytes,
|
||||
workspace_cuda.get(), workspace_bytes,
|
||||
&one, input_tensor, temp_X.get(),
|
||||
&zero, output_tensor, temp_output.get()));
|
||||
} else {
|
||||
auto temp_output = cuda_ep.GetScratchBuffer<CudaT>(output_count);
|
||||
CUDNN_RETURN_IF_ERROR(cudnnReduceTensor(
|
||||
cuda_ep.PerThreadCudnnHandle(), reduce_desc, indices_cuda.get(), indices_bytes,
|
||||
workspace_cuda.get(), workspace_bytes,
|
||||
&one, input_tensor, reinterpret_cast<const CudaT*>(input.template Data<T>()),
|
||||
&zero, output_tensor, temp_output.get()));
|
||||
}
|
||||
|
||||
// CUDA reduction index is uint32_t for now, cast it to int64_t according to ONNX spec
|
||||
Impl_Cast<uint32_t, int64_t>(reinterpret_cast<uint32_t*>(indices_cuda.get()), output.template MutableData<int64_t>(), output_count);
|
||||
// CUDA reduction index is uint32_t for now, cast it to int64_t according to ONNX spec
|
||||
Impl_Cast<uint32_t, int64_t>(reinterpret_cast<uint32_t*>(indices_cuda.get()), output.template MutableData<int64_t>(), output_count);
|
||||
}
|
||||
}
|
||||
|
||||
if (calculate_log) {
|
||||
|
|
|
|||
|
|
@ -2060,6 +2060,18 @@ TEST(ReductionOpTest, ArgMax2D_select_last) {
|
|||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
|
||||
}
|
||||
|
||||
TEST(ReductionOpTest, ArgMax2D_dim1) {
|
||||
OpTester test("ArgMax", 11);
|
||||
test.AddAttribute("axis", (int64_t)1);
|
||||
test.AddInput<float>("data", {3, 1},
|
||||
{1.0f,
|
||||
6.0f,
|
||||
9.0f});
|
||||
test.AddOutput<int64_t>("reduced", {3, 1},
|
||||
{0, 0, 0});
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
|
||||
}
|
||||
|
||||
TEST(ReductionOpTest, ArgMin) {
|
||||
OpTester test("ArgMin");
|
||||
test.AddAttribute("axis", (int64_t)0);
|
||||
|
|
|
|||
|
|
@ -118,6 +118,7 @@ ReductionOpTest.ArgMax_Double_Type
|
|||
ReductionOpTest.ArgMax_do_not_keepdims
|
||||
ReductionOpTest.ArgMax_do_not_keepdims_2
|
||||
ReductionOpTest.ArgMax2D
|
||||
ReductionOpTest.ArgMax2D_dim1
|
||||
ReductionOpTest.ArgMin
|
||||
ReductionOpTest.ArgMin_Double_Type
|
||||
ReductionOpTest.ArgMin_Double_Precision
|
||||
|
|
|
|||
Loading…
Reference in a new issue