From ba22d7879a4981f54079a582b1735d7d5fd38730 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 6 Nov 2024 09:54:32 -0800 Subject: [PATCH] [CUDA/ROCm] Conditionally support ArgMax and ArgMin for opset 12 and above (#22713) ### Description Based on https://github.com/microsoft/onnxruntime/pull/9700, and extend it to ArgMin as well. This pull request introduces several enhancements and fixes related to the `ArgMax` and `ArgMin` operators in the CUDA execution provider. The changes ensure proper handling of these operators across different versions and improve kernel registration and fallback mechanisms. Key changes include: #### Enhancements to `ArgMax` and `ArgMin` Operators: * Added new kernel class registrations for `ArgMax` and `ArgMin` for different data types and versions in `onnxruntime/core/providers/cuda/cuda_execution_provider.cc`. [[1]](diffhunk://#diff-57ba769b54dce57acd89df47140ede5f29ea670d61176096076701912d573285R966-R972) [[2]](diffhunk://#diff-57ba769b54dce57acd89df47140ede5f29ea670d61176096076701912d573285R1209-R1215) [[3]](diffhunk://#diff-57ba769b54dce57acd89df47140ede5f29ea670d61176096076701912d573285R1657-R1659) [[4]](diffhunk://#diff-57ba769b54dce57acd89df47140ede5f29ea670d61176096076701912d573285L1825-L1827) [[5]](diffhunk://#diff-57ba769b54dce57acd89df47140ede5f29ea670d61176096076701912d573285R1933-R1939) [[6]](diffhunk://#diff-57ba769b54dce57acd89df47140ede5f29ea670d61176096076701912d573285R2174-R2180) * Introduced `ArgMaxOrArgMinNeedFallbackToCPU` function to handle fallback to CPU when the `select_last_index` attribute is set to 1, as CUDA does not support this attribute. [[1]](diffhunk://#diff-57ba769b54dce57acd89df47140ede5f29ea670d61176096076701912d573285R2597-R2622) [[2]](diffhunk://#diff-57ba769b54dce57acd89df47140ede5f29ea670d61176096076701912d573285R2672-R2674) #### Macro and Kernel Registration Improvements: * Replaced `REGISTER_KERNEL_UNTIL_VERSIONED_TYPED` with `REGISTER_KERNEL_VERSIONED_RANGE_TYPED` and `REGISTER_KERNEL_VERSIONED_SINCE_TYPED` macros for better version handling. [[1]](diffhunk://#diff-ee5316fc3898058f70e942d9a84de36be4c7da09f144633a2504236430d5d033L19-R29) [[2]](diffhunk://#diff-ee5316fc3898058f70e942d9a84de36be4c7da09f144633a2504236430d5d033L40-R46) * Updated kernel registration for `ArgMax` and `ArgMin` to use the new macros, ensuring proper version handling and support for different data types. #### Safety Checks: * Added safety checks in the `ArgMax` and `ArgMin` classes to ensure `select_last_index` is not set to 1, as it is not supported on CUDA. [[1]](diffhunk://#diff-8ab09fef1f4a12cbf3b3432e509f8f1ef561e83c72778a0e047780060aeef6efL91-R99) [[2]](diffhunk://#diff-8ab09fef1f4a12cbf3b3432e509f8f1ef561e83c72778a0e047780060aeef6efL101-R117) #### Testing Enhancements: * Added new tests for `ArgMax` and `ArgMin` operators to verify behavior when `select_last_index` is set to 0, ensuring compatibility with both CPU and CUDA execution providers. [[1]](diffhunk://#diff-77affe1b70d1a9d38c2485f7c6b16ef2b6b541ed94dd727bc9b286f068f1481aR3340-R3360) [[2]](diffhunk://#diff-77affe1b70d1a9d38c2485f7c6b16ef2b6b541ed94dd727bc9b286f068f1481aR3679-R3699) ### Motivation and Context Improve CUDA kernel coverage for stable diffusion model and hence improve its performance on CUDA --- docs/OperatorKernels.md | 8 ++- .../providers/cuda/cuda_execution_provider.cc | 63 +++++++++++++++- .../providers/cuda/reduction/reduction_ops.cc | 28 ++++---- .../providers/cuda/reduction/reduction_ops.h | 20 +++++- .../providers/rocm/reduction/reduction_ops.cc | 28 ++++---- .../providers/rocm/rocm_execution_provider.cc | 56 ++++++++++++++- .../cpu/reduction/reduction_ops_test.cc | 71 +++++++++++++++++++ 7 files changed, 240 insertions(+), 34 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index bd886abc98..5fb1e54b38 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -554,8 +554,12 @@ Do not modify directly.* |||[7, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |Affine|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |And|*in* A:**T**
*in* B:**T**
*out* C:**T1**|7+|**T** = tensor(bool)
**T1** = tensor(bool)| -|ArgMax|*in* data:**T**
*out* reduced:**tensor(int64)**|[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)| -|ArgMin|*in* data:**T**
*out* reduced:**tensor(int64)**|[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)| +|ArgMax|*in* data:**T**
*out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(float16)| +|||12|**T** = tensor(double), tensor(float), tensor(float16)| +|||[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)| +|ArgMin|*in* data:**T**
*out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(float16)| +|||12|**T** = tensor(double), tensor(float), tensor(float16)| +|||[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)| |AveragePool|*in* X:**T**
*out* Y:**T**|11+|**T** = tensor(double), tensor(float), tensor(float16)| |||10|**T** = tensor(double), tensor(float), tensor(float16)| |||[7, 9]|**T** = tensor(double), tensor(float), tensor(float16)| diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 497d001479..8396e2629d 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -963,6 +963,13 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, Dropout); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, Einsum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, float, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, double, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, float, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, double, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMin); + // OpSet 13 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 14, Pow); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int32_t, Add); @@ -1199,6 +1206,13 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, int8_t, DequantizeLinear); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, uint8_t, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ArgMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ArgMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ArgMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ArgMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMin); + // OpSet 14 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, CumSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, Relu); @@ -1640,6 +1654,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1822,9 +1839,6 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { 19, IsInf)>, // opset 11 - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1916,6 +1930,13 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // OpSet 13 BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2150,6 +2171,13 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // OpSet 14 BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2566,6 +2594,32 @@ static bool CastNeedFallbackToCPU(const onnxruntime::Node& node) { return false; } +static bool ArgMaxOrArgMinNeedFallbackToCPU(const onnxruntime::Node& node) { + // Opset 12 introduced the attribute "select_last_index" + if (node.SinceVersion() >= 12) { + const auto& node_attributes = node.GetAttributes(); + + for (auto& attr : node_attributes) { + auto& attr_name = attr.first; + auto& attr_value = attr.second; + + // CuDNN doesn't support picking the last index in case of encountering + // duplicate max values. + // CuDNN's API doc doesn't mention what happens in case duplicates are encountered, + // but based on testing, the results seem to indicate a "stable" implementation + // (i.e.) relative ordering is preserved which is the expected behavior when the + // attribute takes on the default value (most common use-case for this operator). + if ("select_last_index" == attr_name) { + if (attr_value.i() != 0) { + return true; + } + } + } + } + + return false; +} + std::unique_ptr CUDAExecutionProvider::GetDataTransfer() const { return std::make_unique(); } @@ -2615,6 +2669,9 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, } else if ("ConvTranspose" == node.OpType()) { not_supported = ConvTransposeNeedFallbackToCPU(node, logger, graph, IsNHWCPreferred()); force_inside = !not_supported; + } else if ("ArgMax" == node.OpType() || "ArgMin" == node.OpType()) { + not_supported = ArgMaxOrArgMinNeedFallbackToCPU(node); + force_inside = !not_supported; } else if ("Cast" == node.OpType()) { not_supported = CastNeedFallbackToCPU(node); // cast is not compute heavy, and may be placed outside diff --git a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc index 860bea67dc..4f8e6605ce 100644 --- a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc @@ -16,17 +16,17 @@ using namespace onnxruntime::common; namespace onnxruntime { namespace cuda { -#define REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, end) \ +#define REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, begin, end) \ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ name, \ kOnnxDomain, \ - 1, end, \ + begin, end, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ name); -#define REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, version) \ +#define REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, version) \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ name, \ kOnnxDomain, \ @@ -37,8 +37,13 @@ namespace cuda { name); #define REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(name, T, last, cur) \ - REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, last) \ - REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, cur) + REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 1, last) \ + REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, cur) + +#define REGISTER_KERNEL_ARGMIN_OR_ARGMAX(name, T) \ + REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 1, 11) \ + REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 12, 12) \ + REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, 13) // TODO ReduceKernel::ReduceKernelShared() is still used by some other training classes though it's not used here - this should be refactored. template @@ -829,14 +834,13 @@ template std::unique_ptr ReduceCompute class ArgMax final : public ReduceKernel { public: - ArgMax(const OpKernelInfo& info) : ReduceKernel(info) {} + ArgMax(const OpKernelInfo& info) : ReduceKernel(info) { + // The following is just a safety check. + // The logic in ArgMaxOrArgMinNeedFallbackToCPU() makes sure to not assign ArgMax + // nodes with select_last_index == 1 to the CUDA EP. + int64_t select_last_index = 0; + if (info.GetAttr("select_last_index", &select_last_index).IsOK()) { + ORT_ENFORCE(select_last_index == 0, "select_last_index as 1 is not supported on CUDA"); + } + } Status ComputeInternal(OpKernelContext* ctx) const override { return ComputeImpl(ctx, CUDNN_REDUCE_TENSOR_MAX); @@ -98,7 +106,15 @@ class ArgMax final : public ReduceKernel { template class ArgMin final : public ReduceKernel { public: - ArgMin(const OpKernelInfo& info) : ReduceKernel(info) {} + ArgMin(const OpKernelInfo& info) : ReduceKernel(info) { + // The following is just a safety check. + // The logic in ArgMaxOrArgMinNeedFallbackToCPU() makes sure to not assign ArgMin + // nodes with select_last_index == 1 to the CUDA EP. + int64_t select_last_index = 0; + if (info.GetAttr("select_last_index", &select_last_index).IsOK()) { + ORT_ENFORCE(select_last_index == 0, "select_last_index as 1 is not supported on CUDA"); + } + } Status ComputeInternal(OpKernelContext* ctx) const override { return ComputeImpl(ctx, CUDNN_REDUCE_TENSOR_MIN); diff --git a/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc b/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc index 1340c49c38..d8b7e26d17 100644 --- a/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc @@ -16,17 +16,17 @@ using namespace onnxruntime::common; namespace onnxruntime { namespace rocm { -#define REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, end) \ +#define REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, begin, end) \ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ name, \ kOnnxDomain, \ - 1, end, \ + begin, end, \ T, \ kRocmExecutionProvider, \ (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ name); -#define REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, version) \ +#define REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, version) \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ name, \ kOnnxDomain, \ @@ -37,8 +37,13 @@ namespace rocm { name); #define REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(name, T, last, cur) \ - REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, last) \ - REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, cur) + REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 1, last) \ + REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, cur) + +#define REGISTER_KERNEL_ARGMIN_OR_ARGMAX(name, T) \ + REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 1, 11) \ + REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 12, 12) \ + REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, 13) // TODO ReduceKernel::ReduceKernelShared() is still used by some other training classes though it's not used here - this should be refactored. template @@ -830,14 +835,13 @@ template std::unique_ptr ReduceCompute, // BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1785,9 +1802,6 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { 19, IsInf)>, // opset 11 - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1879,6 +1893,13 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + // OpSet 13 BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2112,6 +2133,12 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, // OpSet 14 BuildKernelCreateInfo, @@ -2387,6 +2414,26 @@ static bool CastNeedFallbackToCPU(const onnxruntime::Node& node) { return false; } +static bool ArgMaxOrArgMinNeedFallbackToCPU(const onnxruntime::Node& node) { + // Opset 12 introduced the attribute "select_last_index" + if (node.SinceVersion() >= 12) { + const auto& node_attributes = node.GetAttributes(); + + for (auto& attr : node_attributes) { + auto& attr_name = attr.first; + auto& attr_value = attr.second; + + // It is not supported to pick the last index in case of encountering duplicate max values. + if ("select_last_index" == attr_name) { + if (attr_value.i() != 0) { + return true; + } + } + } + } + + return false; +} std::unique_ptr ROCMExecutionProvider::GetDataTransfer() const { return std::make_unique(); } @@ -2425,6 +2472,9 @@ ROCMExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, "GRU" == node.OpType()) { not_supported = true; force_inside = !not_supported; + } else if ("ArgMax" == node.OpType() || "ArgMin" == node.OpType()) { + not_supported = ArgMaxOrArgMinNeedFallbackToCPU(node); + force_inside = !not_supported; } else if ("Cast" == node.OpType()) { not_supported = CastNeedFallbackToCPU(node); // cast is not compute heavy, and may be placed outside diff --git a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc index bb6d732fcc..c1c049ae5f 100644 --- a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc +++ b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc @@ -3,6 +3,7 @@ #include #include +#include #include #include "gtest/gtest.h" #include "test/common/dnnl_op_test_utils.h" @@ -3337,6 +3338,41 @@ TEST(ReductionOpTest, ArgMax_int32_last_index_dups) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } +TEST(ReductionOpTest, ArgMax_float_first_index_random) { + OpTester test("ArgMax", 12); + test.AddAttribute("axis", static_cast(0)); + test.AddAttribute("keepdims", static_cast(1)); + + // Since select_last_index is 0 by default, this test should run on both CPU and CUDA + test.AddAttribute("select_last_index", static_cast(0)); + + constexpr size_t vector_size = 64 * 1024; + constexpr float max_value = std::numeric_limits::infinity(); + + std::random_device rd; + std::mt19937 generator(rd()); + std::uniform_int_distribution distribution(0, static_cast(vector_size) - 1); + + std::vector data_vec(vector_size, 0.0f); + + int min_index = -1; + + // Try replace 8 elements with max_value. It is fine that some elements hit same index. + for (int i = 0; i < 8; ++i) { + int index = distribution(generator); + data_vec[index] = max_value; + if (i == 0 || index < min_index) { + min_index = index; + } + } + + test.AddInput("data", {vector_size}, data_vec); + test.AddOutput("reduced", {1}, {min_index}); + + // Exclude OpenVINO since it failed to handle this case. + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); +} + TEST(ReductionOpTest, ArgMax_int32_neg_axis) { OpTester test("ArgMax"); test.AddAttribute("axis", (int64_t)(-2)); @@ -3655,6 +3691,41 @@ TEST(ReductionOpTest, ArgMin_int32_neg_axis) { test.Run(); } +TEST(ReductionOpTest, ArgMin_float_first_index_random) { + OpTester test("ArgMin", 13); + test.AddAttribute("axis", static_cast(0)); + test.AddAttribute("keepdims", static_cast(1)); + + // Since select_last_index is 0 by default, this test should run on both CPU and CUDA + test.AddAttribute("select_last_index", static_cast(0)); + + constexpr size_t vector_size = 64 * 1024; + constexpr float min_value = -std::numeric_limits::infinity(); + + std::random_device rd; + std::mt19937 generator(rd()); + std::uniform_int_distribution distribution(0, static_cast(vector_size) - 1); + + std::vector data_vec(vector_size, 0.0f); + + int min_index = -1; + + // Try replace 8 elements with min_value. It is fine that some elements hit same index. + for (int i = 0; i < 8; ++i) { + int index = distribution(generator); + data_vec[index] = min_value; + if (i == 0 || index < min_index) { + min_index = index; + } + } + + test.AddInput("data", {vector_size}, data_vec); + test.AddOutput("reduced", {1}, {min_index}); + + // Exclude OpenVINO since it failed to handle this case. + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); +} + TEST(ReductionOpTest, OptimizeShapeForFastReduce_ReduceDimWithZero1) { FastReduceKind fast_kind; TensorShapeVector fast_shape, fast_output_shape, fast_axes;