mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-30 03:37:44 +00:00
[CUDA/ROCm] Conditionally support ArgMax and ArgMin for opset 12 and above (#22713)
### Description Based on https://github.com/microsoft/onnxruntime/pull/9700, and extend it to ArgMin as well. This pull request introduces several enhancements and fixes related to the `ArgMax` and `ArgMin` operators in the CUDA execution provider. The changes ensure proper handling of these operators across different versions and improve kernel registration and fallback mechanisms. Key changes include: #### Enhancements to `ArgMax` and `ArgMin` Operators: * Added new kernel class registrations for `ArgMax` and `ArgMin` for different data types and versions in `onnxruntime/core/providers/cuda/cuda_execution_provider.cc`. [[1]](diffhunk://#diff-57ba769b54dce57acd89df47140ede5f29ea670d61176096076701912d573285R966-R972) [[2]](diffhunk://#diff-57ba769b54dce57acd89df47140ede5f29ea670d61176096076701912d573285R1209-R1215) [[3]](diffhunk://#diff-57ba769b54dce57acd89df47140ede5f29ea670d61176096076701912d573285R1657-R1659) [[4]](diffhunk://#diff-57ba769b54dce57acd89df47140ede5f29ea670d61176096076701912d573285L1825-L1827) [[5]](diffhunk://#diff-57ba769b54dce57acd89df47140ede5f29ea670d61176096076701912d573285R1933-R1939) [[6]](diffhunk://#diff-57ba769b54dce57acd89df47140ede5f29ea670d61176096076701912d573285R2174-R2180) * Introduced `ArgMaxOrArgMinNeedFallbackToCPU` function to handle fallback to CPU when the `select_last_index` attribute is set to 1, as CUDA does not support this attribute. [[1]](diffhunk://#diff-57ba769b54dce57acd89df47140ede5f29ea670d61176096076701912d573285R2597-R2622) [[2]](diffhunk://#diff-57ba769b54dce57acd89df47140ede5f29ea670d61176096076701912d573285R2672-R2674) #### Macro and Kernel Registration Improvements: * Replaced `REGISTER_KERNEL_UNTIL_VERSIONED_TYPED` with `REGISTER_KERNEL_VERSIONED_RANGE_TYPED` and `REGISTER_KERNEL_VERSIONED_SINCE_TYPED` macros for better version handling. 
[[1]](diffhunk://#diff-ee5316fc3898058f70e942d9a84de36be4c7da09f144633a2504236430d5d033L19-R29) [[2]](diffhunk://#diff-ee5316fc3898058f70e942d9a84de36be4c7da09f144633a2504236430d5d033L40-R46) * Updated kernel registration for `ArgMax` and `ArgMin` to use the new macros, ensuring proper version handling and support for different data types. #### Safety Checks: * Added safety checks in the `ArgMax` and `ArgMin` classes to ensure `select_last_index` is not set to 1, as it is not supported on CUDA. [[1]](diffhunk://#diff-8ab09fef1f4a12cbf3b3432e509f8f1ef561e83c72778a0e047780060aeef6efL91-R99) [[2]](diffhunk://#diff-8ab09fef1f4a12cbf3b3432e509f8f1ef561e83c72778a0e047780060aeef6efL101-R117) #### Testing Enhancements: * Added new tests for `ArgMax` and `ArgMin` operators to verify behavior when `select_last_index` is set to 0, ensuring compatibility with both CPU and CUDA execution providers. [[1]](diffhunk://#diff-77affe1b70d1a9d38c2485f7c6b16ef2b6b541ed94dd727bc9b286f068f1481aR3340-R3360) [[2]](diffhunk://#diff-77affe1b70d1a9d38c2485f7c6b16ef2b6b541ed94dd727bc9b286f068f1481aR3679-R3699) ### Motivation and Context Improve CUDA kernel coverage for stable diffusion model and hence improve its performance on CUDA
This commit is contained in:
parent
d993ec313f
commit
ba22d7879a
7 changed files with 240 additions and 34 deletions
|
|
@ -554,8 +554,12 @@ Do not modify directly.*
|
|||
|||[7, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|
||||
|Affine|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|And|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T1**|7+|**T** = tensor(bool)<br/> **T1** = tensor(bool)|
|
||||
|ArgMax|*in* data:**T**<br> *out* reduced:**tensor(int64)**|[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|ArgMin|*in* data:**T**<br> *out* reduced:**tensor(int64)**|[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|ArgMax|*in* data:**T**<br> *out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|||12|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|||[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|ArgMin|*in* data:**T**<br> *out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|||12|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|||[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|AveragePool|*in* X:**T**<br> *out* Y:**T**|11+|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|||10|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|||[7, 9]|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|
|
|
|||
|
|
@ -963,6 +963,13 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO
|
|||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, Dropout);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, Einsum);
|
||||
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, float, ArgMax);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, double, ArgMax);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMax);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, float, ArgMin);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, double, ArgMin);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMin);
|
||||
|
||||
// OpSet 13
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 14, Pow);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int32_t, Add);
|
||||
|
|
@ -1199,6 +1206,13 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO
|
|||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, int8_t, DequantizeLinear);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, uint8_t, DequantizeLinear);
|
||||
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ArgMax);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ArgMax);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMax);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ArgMin);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ArgMin);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMin);
|
||||
|
||||
// OpSet 14
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, CumSum);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, Relu);
|
||||
|
|
@ -1640,6 +1654,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, float, ArgMin)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, double, ArgMin)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, MLFloat16, ArgMin)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, float, ArgMax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, double, ArgMax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, MLFloat16, ArgMax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, float, ReduceL1)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, double, ReduceL1)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceL1)>,
|
||||
|
|
@ -1822,9 +1839,6 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
19, IsInf)>,
|
||||
|
||||
// opset 11
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, float, ArgMax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, double, ArgMax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, MLFloat16, ArgMax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, Compress)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Concat)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Flatten)>,
|
||||
|
|
@ -1916,6 +1930,13 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, Dropout)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, Einsum)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, float, ArgMax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, double, ArgMax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, float, ArgMin)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, double, ArgMin)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMin)>,
|
||||
|
||||
// OpSet 13
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 14, Pow)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int32_t, Add)>,
|
||||
|
|
@ -2150,6 +2171,13 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, int8_t, DequantizeLinear)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, uint8_t, DequantizeLinear)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ArgMax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ArgMax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ArgMin)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ArgMin)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMin)>,
|
||||
|
||||
// OpSet 14
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, CumSum)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, Relu)>,
|
||||
|
|
@ -2566,6 +2594,32 @@ static bool CastNeedFallbackToCPU(const onnxruntime::Node& node) {
|
|||
return false;
|
||||
}
|
||||
|
||||
// Decides whether an ArgMax/ArgMin node has to be left to the CPU EP.
// Returns true only when the node carries a non-zero "select_last_index"
// attribute; in every other case the node may stay on CUDA.
static bool ArgMaxOrArgMinNeedFallbackToCPU(const onnxruntime::Node& node) {
  // "select_last_index" only exists from opset 12 on, so older nodes never need the fallback.
  if (node.SinceVersion() < 12) {
    return false;
  }

  // cuDNN offers no way to return the last index when the extreme value occurs
  // more than once. Its API documentation does not describe tie handling, but
  // testing suggests the reduction is "stable" (relative ordering is preserved),
  // which is the behavior expected for the default select_last_index == 0
  // (the most common use of this operator). Any other value is not handled here.
  for (const auto& entry : node.GetAttributes()) {
    if (entry.first == "select_last_index" && entry.second.i() != 0) {
      return true;
    }
  }

  return false;
}
|
||||
|
||||
// Returns a newly created GPUDataTransfer, exposed through the IDataTransfer
// interface. Ownership of the object passes to the caller via the unique_ptr.
std::unique_ptr<onnxruntime::IDataTransfer> CUDAExecutionProvider::GetDataTransfer() const {
|
||||
return std::make_unique<onnxruntime::GPUDataTransfer>();
|
||||
}
|
||||
|
|
@ -2615,6 +2669,9 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
|
|||
} else if ("ConvTranspose" == node.OpType()) {
|
||||
not_supported = ConvTransposeNeedFallbackToCPU(node, logger, graph, IsNHWCPreferred());
|
||||
force_inside = !not_supported;
|
||||
} else if ("ArgMax" == node.OpType() || "ArgMin" == node.OpType()) {
|
||||
not_supported = ArgMaxOrArgMinNeedFallbackToCPU(node);
|
||||
force_inside = !not_supported;
|
||||
} else if ("Cast" == node.OpType()) {
|
||||
not_supported = CastNeedFallbackToCPU(node);
|
||||
// cast is not compute heavy, and may be placed outside
|
||||
|
|
|
|||
|
|
@ -16,17 +16,17 @@ using namespace onnxruntime::common;
|
|||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
#define REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, end) \
|
||||
#define REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, begin, end) \
|
||||
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
|
||||
name, \
|
||||
kOnnxDomain, \
|
||||
1, end, \
|
||||
begin, end, \
|
||||
T, \
|
||||
kCudaExecutionProvider, \
|
||||
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
name<T>);
|
||||
|
||||
#define REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, version) \
|
||||
#define REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, version) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
name, \
|
||||
kOnnxDomain, \
|
||||
|
|
@ -37,8 +37,13 @@ namespace cuda {
|
|||
name<T>);
|
||||
|
||||
#define REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(name, T, last, cur) \
|
||||
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, last) \
|
||||
REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, cur)
|
||||
REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 1, last) \
|
||||
REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, cur)
|
||||
|
||||
// Registers the CUDA kernel `name<T>` (ArgMin or ArgMax) three times, once per
// opset span: versions [1, 11], version 12 alone, and version 13 via the
// "since" registration macro.
#define REGISTER_KERNEL_ARGMIN_OR_ARGMAX(name, T) \
|
||||
REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 1, 11) \
|
||||
REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 12, 12) \
|
||||
REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, 13)
|
||||
|
||||
// TODO ReduceKernel::ReduceKernelShared() is still used by some other training classes though it's not used here - this should be refactored.
|
||||
template <bool allow_multi_axes>
|
||||
|
|
@ -829,14 +834,13 @@ template std::unique_ptr<Tensor> ReduceCompute<MLFloat16, CUDNN_REDUCE_TENSOR_NO
|
|||
|
||||
} // namespace ReductionOps
|
||||
|
||||
// CUDA ArgMax/ArgMin doesn't have OpSet12+ implementation (with select_last_index attr) yet
|
||||
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMax, MLFloat16, 11)
|
||||
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMax, float, 11)
|
||||
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMax, double, 11)
|
||||
REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMax, MLFloat16)
|
||||
REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMax, float)
|
||||
REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMax, double)
|
||||
|
||||
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMin, MLFloat16, 11)
|
||||
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMin, float, 11)
|
||||
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMin, double, 11)
|
||||
REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMin, MLFloat16)
|
||||
REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMin, float)
|
||||
REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMin, double)
|
||||
|
||||
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, MLFloat16, 17, 18)
|
||||
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, float, 17, 18)
|
||||
|
|
|
|||
|
|
@ -88,7 +88,15 @@ class ReduceKernel : public CudaKernel, public ReduceKernelBase<allow_multi_axes
|
|||
template <typename T>
|
||||
class ArgMax final : public ReduceKernel<false> {
|
||||
public:
|
||||
ArgMax(const OpKernelInfo& info) : ReduceKernel<false>(info) {}
|
||||
ArgMax(const OpKernelInfo& info) : ReduceKernel<false>(info) {
  // Defensive check only: ArgMaxOrArgMinNeedFallbackToCPU() is expected to keep
  // ArgMax nodes with select_last_index == 1 away from the CUDA EP, so a
  // non-zero value reaching this point means a node was placed incorrectly.
  int64_t select_last_index = 0;
  const auto attr_status = info.GetAttr<int64_t>("select_last_index", &select_last_index);
  if (attr_status.IsOK()) {
    // Only enforced when the attribute lookup succeeds; otherwise the check is skipped.
    ORT_ENFORCE(select_last_index == 0, "select_last_index as 1 is not supported on CUDA");
  }
}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* ctx) const override {
|
||||
return ComputeImpl<T, CUDNN_REDUCE_TENSOR_FLATTENED_INDICES>(ctx, CUDNN_REDUCE_TENSOR_MAX);
|
||||
|
|
@ -98,7 +106,15 @@ class ArgMax final : public ReduceKernel<false> {
|
|||
template <typename T>
|
||||
class ArgMin final : public ReduceKernel<false> {
|
||||
public:
|
||||
ArgMin(const OpKernelInfo& info) : ReduceKernel<false>(info) {}
|
||||
ArgMin(const OpKernelInfo& info) : ReduceKernel<false>(info) {
  // Defensive check only: ArgMaxOrArgMinNeedFallbackToCPU() is expected to keep
  // ArgMin nodes with select_last_index == 1 away from the CUDA EP, so a
  // non-zero value reaching this point means a node was placed incorrectly.
  int64_t select_last_index = 0;
  const auto attr_status = info.GetAttr<int64_t>("select_last_index", &select_last_index);
  if (attr_status.IsOK()) {
    // Only enforced when the attribute lookup succeeds; otherwise the check is skipped.
    ORT_ENFORCE(select_last_index == 0, "select_last_index as 1 is not supported on CUDA");
  }
}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* ctx) const override {
|
||||
return ComputeImpl<T, CUDNN_REDUCE_TENSOR_FLATTENED_INDICES>(ctx, CUDNN_REDUCE_TENSOR_MIN);
|
||||
|
|
|
|||
|
|
@ -16,17 +16,17 @@ using namespace onnxruntime::common;
|
|||
namespace onnxruntime {
|
||||
namespace rocm {
|
||||
|
||||
#define REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, end) \
|
||||
#define REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, begin, end) \
|
||||
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
|
||||
name, \
|
||||
kOnnxDomain, \
|
||||
1, end, \
|
||||
begin, end, \
|
||||
T, \
|
||||
kRocmExecutionProvider, \
|
||||
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
name<T>);
|
||||
|
||||
#define REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, version) \
|
||||
#define REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, version) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
name, \
|
||||
kOnnxDomain, \
|
||||
|
|
@ -37,8 +37,13 @@ namespace rocm {
|
|||
name<T>);
|
||||
|
||||
#define REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(name, T, last, cur) \
|
||||
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, last) \
|
||||
REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, cur)
|
||||
REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 1, last) \
|
||||
REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, cur)
|
||||
|
||||
// Registers the ROCm kernel `name<T>` (ArgMin or ArgMax) three times, once per
// opset span: versions [1, 11], version 12 alone, and version 13 via the
// "since" registration macro.
#define REGISTER_KERNEL_ARGMIN_OR_ARGMAX(name, T) \
|
||||
REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 1, 11) \
|
||||
REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 12, 12) \
|
||||
REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, 13)
|
||||
|
||||
// TODO ReduceKernel::ReduceKernelShared() is still used by some other training classes though it's not used here - this should be refactored.
|
||||
template <bool allow_multi_axes>
|
||||
|
|
@ -830,14 +835,13 @@ template std::unique_ptr<Tensor> ReduceCompute<MLFloat16, MIOPEN_REDUCE_TENSOR_N
|
|||
|
||||
} // namespace ReductionOps
|
||||
|
||||
// ROCM ArgMax/ArgMin doesn't have OpSet12+ implementation (with select_last_index attr) yet
|
||||
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMax, MLFloat16, 11)
|
||||
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMax, float, 11)
|
||||
// REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMax, double, 11)
|
||||
REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMax, MLFloat16)
|
||||
REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMax, float)
|
||||
// REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMax, double)
|
||||
|
||||
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMin, MLFloat16, 11)
|
||||
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMin, float, 11)
|
||||
// REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMin, double, 11)
|
||||
REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMin, MLFloat16)
|
||||
REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMin, float)
|
||||
// REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMin, double)
|
||||
|
||||
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, MLFloat16, 17, 18)
|
||||
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, float, 17, 18)
|
||||
|
|
|
|||
|
|
@ -926,6 +926,12 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO
|
|||
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, Dropout);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, Einsum);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, float, ArgMax);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, double, ArgMax);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMax);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, float, ArgMin);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, double, ArgMin);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMin);
|
||||
|
||||
// OpSet 13
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 14, Pow);
|
||||
|
|
@ -1163,6 +1169,13 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO
|
|||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, int8_t, DequantizeLinear);
|
||||
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint8_t, DequantizeLinear);
|
||||
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, ArgMax);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, ArgMax);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMax);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, ArgMin);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, ArgMin);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMin);
|
||||
|
||||
// OpSet 14
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, CumSum);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, float, Relu);
|
||||
|
|
@ -1603,6 +1616,10 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 11, float, ArgMin)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 11, double, ArgMin)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 11, MLFloat16, ArgMin)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 11, float, ArgMax)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 11, double, ArgMax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 11, MLFloat16, ArgMax)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, float, ReduceL1)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, double, ReduceL1)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceL1)>,
|
||||
|
|
@ -1785,9 +1802,6 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
|
|||
19, IsInf)>,
|
||||
|
||||
// opset 11
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 11, float, ArgMax)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 11, double, ArgMax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 11, MLFloat16, ArgMax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, Compress)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, Concat)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, Flatten)>,
|
||||
|
|
@ -1879,6 +1893,13 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, Dropout)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, Einsum)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, float, ArgMax)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, double, ArgMax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, float, ArgMin)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, double, ArgMin)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMin)>,
|
||||
|
||||
// OpSet 13
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 14, Pow)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, int32_t, Add)>,
|
||||
|
|
@ -2112,6 +2133,12 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint8_t, QuantizeLinear)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, int8_t, DequantizeLinear)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint8_t, DequantizeLinear)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, ArgMax)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, ArgMax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, ArgMin)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, ArgMin)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMin)>,
|
||||
|
||||
// OpSet 14
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, CumSum)>,
|
||||
|
|
@ -2387,6 +2414,26 @@ static bool CastNeedFallbackToCPU(const onnxruntime::Node& node) {
|
|||
return false;
|
||||
}
|
||||
|
||||
// Decides whether an ArgMax/ArgMin node has to be left to the CPU EP.
// Returns true only when the node carries a non-zero "select_last_index"
// attribute; in every other case the node may stay on ROCm.
static bool ArgMaxOrArgMinNeedFallbackToCPU(const onnxruntime::Node& node) {
  // "select_last_index" only exists from opset 12 on, so older nodes never need the fallback.
  if (node.SinceVersion() < 12) {
    return false;
  }

  // Picking the last index among duplicate extreme values is not supported here,
  // so any non-zero select_last_index forces the fallback.
  for (const auto& entry : node.GetAttributes()) {
    if (entry.first == "select_last_index" && entry.second.i() != 0) {
      return true;
    }
  }

  return false;
}
|
||||
// Returns a newly created GPUDataTransfer, exposed through the IDataTransfer
// interface. Ownership of the object passes to the caller via the unique_ptr.
std::unique_ptr<onnxruntime::IDataTransfer> ROCMExecutionProvider::GetDataTransfer() const {
|
||||
return std::make_unique<onnxruntime::GPUDataTransfer>();
|
||||
}
|
||||
|
|
@ -2425,6 +2472,9 @@ ROCMExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
|
|||
"GRU" == node.OpType()) {
|
||||
not_supported = true;
|
||||
force_inside = !not_supported;
|
||||
} else if ("ArgMax" == node.OpType() || "ArgMin" == node.OpType()) {
|
||||
not_supported = ArgMaxOrArgMinNeedFallbackToCPU(node);
|
||||
force_inside = !not_supported;
|
||||
} else if ("Cast" == node.OpType()) {
|
||||
not_supported = CastNeedFallbackToCPU(node);
|
||||
// cast is not compute heavy, and may be placed outside
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
#include <random>
|
||||
#include <cmath>
|
||||
#include <limits>
|
||||
#include <type_traits>
|
||||
#include "gtest/gtest.h"
|
||||
#include "test/common/dnnl_op_test_utils.h"
|
||||
|
|
@ -3337,6 +3338,41 @@ TEST(ReductionOpTest, ArgMax_int32_last_index_dups) {
|
|||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
|
||||
}
|
||||
|
||||
// ArgMax (opset 12) over a large 1-D float tensor in which the maximum value is
// written at up to 8 random positions. With select_last_index == 0 the operator
// must report the smallest of those positions.
TEST(ReductionOpTest, ArgMax_float_first_index_random) {
  OpTester test("ArgMax", 12);
  test.AddAttribute("axis", static_cast<int64_t>(0));
  test.AddAttribute("keepdims", static_cast<int64_t>(1));

  // Since select_last_index is 0 by default, this test should run on both CPU and CUDA
  test.AddAttribute("select_last_index", static_cast<int64_t>(0));

  constexpr size_t vector_size = 64 * 1024;
  constexpr float max_value = std::numeric_limits<float>::infinity();

  std::random_device rd;
  std::mt19937 generator(rd());
  std::uniform_int_distribution<int> distribution(0, static_cast<int>(vector_size) - 1);

  std::vector<float> data_vec(vector_size, 0.0f);

  // Start above every valid index so the first draw always becomes the current minimum.
  int first_max_index = static_cast<int>(vector_size);

  // Write max_value at 8 randomly drawn positions; repeated positions are acceptable.
  for (int remaining = 8; remaining > 0; --remaining) {
    const int index = distribution(generator);
    data_vec[index] = max_value;
    if (index < first_max_index) {
      first_max_index = index;
    }
  }

  test.AddInput<float>("data", {vector_size}, data_vec);
  test.AddOutput<int64_t>("reduced", {1}, {first_max_index});

  // Exclude OpenVINO since it failed to handle this case.
  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
}
|
||||
|
||||
TEST(ReductionOpTest, ArgMax_int32_neg_axis) {
|
||||
OpTester test("ArgMax");
|
||||
test.AddAttribute("axis", (int64_t)(-2));
|
||||
|
|
@ -3655,6 +3691,41 @@ TEST(ReductionOpTest, ArgMin_int32_neg_axis) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
// ArgMin (opset 13) over a large 1-D float tensor in which the minimum value is
// written at up to 8 random positions. With select_last_index == 0 the operator
// must report the smallest of those positions.
TEST(ReductionOpTest, ArgMin_float_first_index_random) {
  OpTester test("ArgMin", 13);
  test.AddAttribute("axis", static_cast<int64_t>(0));
  test.AddAttribute("keepdims", static_cast<int64_t>(1));

  // Since select_last_index is 0 by default, this test should run on both CPU and CUDA
  test.AddAttribute("select_last_index", static_cast<int64_t>(0));

  constexpr size_t vector_size = 64 * 1024;
  constexpr float min_value = -std::numeric_limits<float>::infinity();

  std::random_device rd;
  std::mt19937 generator(rd());
  std::uniform_int_distribution<int> distribution(0, static_cast<int>(vector_size) - 1);

  std::vector<float> data_vec(vector_size, 0.0f);

  // Start above every valid index so the first draw always becomes the current minimum.
  int first_min_index = static_cast<int>(vector_size);

  // Write min_value at 8 randomly drawn positions; repeated positions are acceptable.
  for (int remaining = 8; remaining > 0; --remaining) {
    const int index = distribution(generator);
    data_vec[index] = min_value;
    if (index < first_min_index) {
      first_min_index = index;
    }
  }

  test.AddInput<float>("data", {vector_size}, data_vec);
  test.AddOutput<int64_t>("reduced", {1}, {first_min_index});

  // Exclude OpenVINO since it failed to handle this case.
  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
}
|
||||
|
||||
TEST(ReductionOpTest, OptimizeShapeForFastReduce_ReduceDimWithZero1) {
|
||||
FastReduceKind fast_kind;
|
||||
TensorShapeVector fast_shape, fast_output_shape, fast_axes;
|
||||
|
|
|
|||
Loading…
Reference in a new issue