diff --git a/onnxruntime/contrib_ops/cpu/gather_nd.cc b/onnxruntime/contrib_ops/cpu/gather_nd.cc deleted file mode 100644 index 462367671f..0000000000 --- a/onnxruntime/contrib_ops/cpu/gather_nd.cc +++ /dev/null @@ -1,126 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "contrib_ops/cpu/gather_nd.h" - -namespace onnxruntime { -namespace contrib { - -ONNX_OPERATOR_KERNEL_EX( - GatherND, - kMSDomain, - 1, - kCpuExecutionProvider, - KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::AllTensorTypes()) - .TypeConstraint("Tind", {DataTypeImpl::GetTensorType(),DataTypeImpl::GetTensorType()}), - GatherND); - -template -Status GatherNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) const { - - auto input_tensor = context->Input(0); - auto indice_tensor = context->Input(1); - ORT_ENFORCE(input_tensor != nullptr); - ORT_ENFORCE(indice_tensor != nullptr); - - auto input_shape = input_tensor->Shape(); - auto indice_shape = indice_tensor->Shape(); - if (indice_shape.NumDimensions() == 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "indices tensor must has rank larger than 0"); - } - - auto last_indice_dimension = indice_shape[indice_shape.NumDimensions() - 1]; - if (last_indice_dimension > static_cast(input_shape.NumDimensions())) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "last dimension of indices must not be larger than rank of input tensor"); - } - - std::vector shape(indice_shape.GetDims().begin(), - indice_shape.GetDims().end() - 1); - shape.insert(shape.end(), - input_shape.GetDims().begin() + last_indice_dimension, - input_shape.GetDims().end()); - auto output_tensor = context->Output(0,TensorShape(shape)); - std::vector element_counts(last_indice_dimension, 0LL); // Number of elements for each input dimension - -#ifdef USE_OPENMP -#pragma omp parallel for -#endif - for (int64_t i = 0; i < last_indice_dimension; ++i) { - element_counts[i] = input_shape.SizeFromDimension(i + 1); -} - - int64_t err_indice = 0; - p.element_bytes = input_tensor->DataType()->Size(); - p.element_to_copy = input_shape.SizeFromDimension(last_indice_dimension); - p.bytes_to_copy = p.element_bytes * p.element_to_copy; - auto indice_offset = indice_tensor->Data(); - auto offset_count = indice_shape.Size() / last_indice_dimension; // Times to copy - p.element_offsets.assign(offset_count, 0LL); - - if (input_tensor->DataType() == DataTypeImpl::GetType()) { - p.input_str_base = static_cast(input_tensor->DataRaw()); - p.output_str_base = static_cast(output_tensor->MutableDataRaw()); - } else { - p.input_base = static_cast(input_tensor->DataRaw()); - p.output_base = static_cast(output_tensor->MutableDataRaw()); - } - -#ifdef USE_OPENMP -#pragma omp parallel for -#endif - for (int64_t i = 0; i < offset_count; ++i) { - for (int64_t j = 0; j < last_indice_dimension; ++j) { - auto indice = *(indice_offset + i * last_indice_dimension + j); - if (indice < 0 || indice >= input_shape[j]) { - err_indice = indice; - } - p.element_offsets[i] += indice * element_counts[j]; - } - } - - return err_indice == 0 ? Status::OK() : - ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid indice found, indice = ", err_indice); -} - -template Status GatherNDBase::PrepareForCompute(OpKernelContext*, Prepare&) const; -template Status GatherNDBase::PrepareForCompute(OpKernelContext*, Prepare&) const; - -Status GatherND::Compute(OpKernelContext* context) const { - Prepare p; - ORT_RETURN_IF_ERROR(context->Input(1)->DataType() == DataTypeImpl::GetType() ? - PrepareForCompute(context, p) : PrepareForCompute(context, p)); - - return nullptr == p.input_str_base ? GatherNumber(p) : GatherString(p); -} - -Status GatherND::GatherNumber(const Prepare& p) const { -#ifdef USE_OPENMP -#pragma omp parallel for -#endif - for (int64_t i = 0; i < static_cast(p.element_offsets.size()); ++i) { - memcpy(p.output_base + i * p.bytes_to_copy, - p.input_base + p.element_offsets[i] * p.element_bytes, - p.bytes_to_copy); - } - - return Status::OK(); -} - -Status GatherND::GatherString(const Prepare& p) const { -#ifdef USE_OPENMP -#pragma omp parallel for -#endif - for (int64_t i = 0; i < static_cast(p.element_offsets.size()); ++i) { - for (int64_t j = 0; j < static_cast(p.element_to_copy); ++j) { - p.output_str_base[i * p.element_to_copy + j] = p.input_str_base[p.element_offsets[i] + j]; - } - } - - return Status::OK(); -} - -} -} \ No newline at end of file diff --git a/onnxruntime/contrib_ops/cpu/gather_nd.h b/onnxruntime/contrib_ops/cpu/gather_nd.h deleted file mode 100644 index 2642f2a28b..0000000000 --- a/onnxruntime/contrib_ops/cpu/gather_nd.h +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/common/common.h" -#include "core/framework/op_kernel.h" -#include "core/platform/threadpool.h" - -namespace onnxruntime { -namespace contrib { - -class GatherNDBase -{ -protected: - struct Prepare { - const uint8_t* input_base; - const std::string* input_str_base; - uint8_t* output_base; - std::string* output_str_base; - uint64_t bytes_to_copy; - uint64_t element_bytes; - uint64_t element_to_copy; - std::vector element_offsets; - - Prepare(): input_base (nullptr), - input_str_base (nullptr), - output_base (nullptr), - output_str_base (nullptr), - bytes_to_copy (0), - element_bytes (0), - element_to_copy (0), - element_offsets (0) {} - }; // struct Prepare - - template - Status PrepareForCompute(OpKernelContext* context, Prepare& p) const; -}; // class GatherNDBase - -class GatherND final : public OpKernel, protected GatherNDBase { -public: - explicit GatherND(const OpKernelInfo& info) : OpKernel(info) {} - Status Compute(OpKernelContext* context) const override; -private: - Status GatherNumber(const Prepare& p) const; - Status GatherString(const Prepare& p) const; -}; - -} // namespace contrib -} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 55510ae5c3..d766371837 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -128,9 +128,12 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, ReduceL2); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceL2); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int32_t, + ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, + ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int32_t, + ReduceLogSumExp); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, ReduceMax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceMax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int64_t, ReduceMax); @@ -145,9 +148,12 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceSum); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, ReduceSum); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int64_t, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, + ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int32_t, + ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, + ReduceSumSquare); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, ArgMax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int32_t, ArgMax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, ArgMin); @@ -387,10 +393,12 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Sc class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Gemm); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, GatherElements); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Pad); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, GatherND); void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -400,7 +408,8 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -408,63 +417,92 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -474,251 +512,449 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 9 - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 10 BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - //opset 11 + // opset 11 BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -729,25 +965,39 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -759,6 +1009,7 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { @@ -821,55 +1072,93 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, void RegisterOnnxMLOperatorKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/cpu/tensor/gather_nd.cc b/onnxruntime/core/providers/cpu/tensor/gather_nd.cc new file mode 100644 index 0000000000..4be09345aa --- /dev/null +++ b/onnxruntime/core/providers/cpu/tensor/gather_nd.cc @@ -0,0 +1,141 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gather_nd.h" + +namespace onnxruntime { + +// Register a kernel for kMsDomain (contrib op) GatherND +#ifndef DISABLE_CONTRIB_OPS + +namespace contrib { +// TODO: Remove this contrib kernel registration and the schema from the appropriate places +// once Keras Mask RCNN is shipped with all ONNX domain ops + +// Currently this kernel is required to support Keras Mask-RCNN +ONNX_OPERATOR_KERNEL_EX(GatherND, kMSDomain, 1, kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::AllTensorTypes()) + // contrib spec supports `int32_t` and `int64_t` for indices + .TypeConstraint("Tind", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + GatherND); + +} // namespace contrib + +#endif + +ONNX_CPU_OPERATOR_KERNEL(GatherND, 11, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::AllTensorTypes()) + // official ONNX spec only supports `int64_t` for indices + .TypeConstraint("Tind", DataTypeImpl::GetTensorType()), + GatherND); + +template +Status GatherNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) const { + const auto* input_tensor = context->Input(0); + const auto* indices_tensor = context->Input(1); + ORT_ENFORCE(input_tensor != nullptr && indices_tensor != nullptr, "GatherND op: Input count mismatch"); + + const auto& input_shape = input_tensor->Shape(); + const auto& indices_shape = indices_tensor->Shape(); + if (indices_shape.NumDimensions() == 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "indices tensor must has rank larger than 0"); + } + + int64_t last_indices_dimension = indices_shape[indices_shape.NumDimensions() - 1]; + if (last_indices_dimension > static_cast(input_shape.NumDimensions())) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "last dimension of indices must not be larger than rank of input tensor"); + } + + std::vector shape(indices_shape.GetDims().begin(), indices_shape.GetDims().end() - 1); + shape.insert(shape.end(), input_shape.GetDims().begin() + last_indices_dimension, input_shape.GetDims().end()); + auto* output_tensor = context->Output(0, TensorShape(std::move(shape))); + std::vector element_counts(last_indices_dimension, + 0LL); // Number of elements for each input dimension + +#ifdef USE_OPENMP +#pragma omp parallel for +#endif + for (int64_t i = 0; i < last_indices_dimension; ++i) { + element_counts[i] = input_shape.SizeFromDimension(i + 1); + } + + int64_t err_index = 0; + p.element_bytes = input_tensor->DataType()->Size(); + p.element_to_copy = input_shape.SizeFromDimension(last_indices_dimension); + p.bytes_to_copy = p.element_bytes * p.element_to_copy; + const auto* indices_data = indices_tensor->Data(); + const int64_t offset_count = indices_shape.Size() / last_indices_dimension; // Times to copy + p.element_offsets.assign(offset_count, 0LL); + + if (input_tensor->DataType() == DataTypeImpl::GetType()) { + p.input_str_base = static_cast(input_tensor->DataRaw()); + p.output_str_base = static_cast(output_tensor->MutableDataRaw()); + } else { + p.input_base = static_cast(input_tensor->DataRaw()); + p.output_base = static_cast(output_tensor->MutableDataRaw()); + } + +#ifdef USE_OPENMP +#pragma omp parallel for +#endif + for (int64_t i = 0; i < offset_count; ++i) { + for (int64_t j = 0; j < last_indices_dimension; ++j) { + auto index = *(indices_data + i * last_indices_dimension + j); + auto upper_limit = input_shape[j]; + auto lower_limit = -upper_limit; + if (index < lower_limit || index >= upper_limit) { + err_index = index; + } + if (index < 0) { + index += static_cast(upper_limit); + } + p.element_offsets[i] += index * element_counts[j]; + } + } + + return err_index == 0 ? Status::OK() + : ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid index found, index = ", err_index); +} + +template Status GatherNDBase::PrepareForCompute(OpKernelContext*, Prepare&) const; +template Status GatherNDBase::PrepareForCompute(OpKernelContext*, Prepare&) const; + +Status GatherND::Compute(OpKernelContext* context) const { + Prepare p; + ORT_RETURN_IF_ERROR(context->Input(1)->DataType() == DataTypeImpl::GetType() + ? PrepareForCompute(context, p) + : PrepareForCompute(context, p)); + + return nullptr == p.input_str_base ? GatherNumber(p) : GatherString(p); +} + +Status GatherND::GatherNumber(const Prepare& p) const { +#ifdef USE_OPENMP +#pragma omp parallel for +#endif + for (int64_t i = 0; i < static_cast(p.element_offsets.size()); ++i) { + memcpy(p.output_base + i * p.bytes_to_copy, p.input_base + p.element_offsets[i] * p.element_bytes, + p.bytes_to_copy); + } + + return Status::OK(); +} + +Status GatherND::GatherString(const Prepare& p) const { +#ifdef USE_OPENMP +#pragma omp parallel for +#endif + for (int64_t i = 0; i < static_cast(p.element_offsets.size()); ++i) { + for (int64_t j = 0; j < static_cast(p.element_to_copy); ++j) { + p.output_str_base[i * p.element_to_copy + j] = p.input_str_base[p.element_offsets[i] + j]; + } + } + + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/tensor/gather_nd.h b/onnxruntime/core/providers/cpu/tensor/gather_nd.h new file mode 100644 index 0000000000..a169c5ab78 --- /dev/null +++ b/onnxruntime/core/providers/cpu/tensor/gather_nd.h @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "core/platform/threadpool.h" + +namespace onnxruntime { + +class GatherNDBase { + protected: + struct Prepare { + const uint8_t* input_base; + const std::string* input_str_base; + uint8_t* output_base; + std::string* output_str_base; + uint64_t bytes_to_copy; + uint64_t element_bytes; + uint64_t element_to_copy; + std::vector element_offsets; + + Prepare() : input_base(nullptr), + input_str_base(nullptr), + output_base(nullptr), + output_str_base(nullptr), + bytes_to_copy(0), + element_bytes(0), + element_to_copy(0), + element_offsets(0) {} + }; // struct Prepare + + template + Status PrepareForCompute(OpKernelContext* context, Prepare& p) const; +}; // class GatherNDBase + +class GatherND final : public OpKernel, protected GatherNDBase { + public: + explicit GatherND(const OpKernelInfo& info) : OpKernel(info) {} + Status Compute(OpKernelContext* context) const override; + + private: + Status GatherNumber(const Prepare& p) const; + Status GatherString(const Prepare& p) const; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/gather_nd_op_test.cc b/onnxruntime/test/contrib_ops/gather_nd_op_test.cc deleted file mode 100644 index f7af11bb1f..0000000000 --- a/onnxruntime/test/contrib_ops/gather_nd_op_test.cc +++ /dev/null @@ -1,157 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "gtest/gtest.h" -#include "test/providers/provider_test_utils.h" - -namespace onnxruntime { -namespace test { - -TEST(GatherNDOpTest, GatherND_scaler_string_int32) { - OpTester test1("GatherND", 1, onnxruntime::kMSDomain); - test1.AddInput("data", {2,2}, {"h","k","o","z"}); - test1.AddInput("indices", {2}, {0,1}); - test1.AddOutput("output", {}, {"k"}); - test1.Run(); - - OpTester test2("GatherND", 1, onnxruntime::kMSDomain); - test2.AddInput("data", {6}, {"h","k","o","z","l","t"}); - test2.AddInput("indices", {1}, {3}); - test2.AddOutput("output", {}, {"z"}); - test2.Run(); - - OpTester test3("GatherND", 1, onnxruntime::kMSDomain); - test3.AddInput("data", {3,2}, {"h","k","o","z","l","t"}); - test3.AddInput("indices", {2}, {2,1}); - test3.AddOutput("output", {}, {"t"}); - test3.Run(); -} - -TEST(GatherNDOpTest, GatherND_matrice_int64_int64) { - OpTester test("GatherND", 1, onnxruntime::kMSDomain); - test.AddInput ("data", {2,2}, {0LL,1LL,2LL,3LL}); - test.AddInput ("indices", {2,2}, {0LL,0LL,1LL,1LL}); - test.AddOutput("output", {2}, {0LL,3LL}); - test.Run(); -} - -TEST(GatherNDOpTest, GatherND_matrice_string_int64) { - OpTester test("GatherND", 1, onnxruntime::kMSDomain); - test.AddInput("data", {2,2}, {"a","b","c","d"}); - test.AddInput("indices", {2,2}, {0LL,0LL,1LL,1LL}); - test.AddOutput("output", {2}, {"a","d"}); - test.Run(); -} - -TEST(GatherNDOpTest, GatherND_matrice_int64_int32) { - OpTester test("GatherND", 1, onnxruntime::kMSDomain); - test.AddInput("data", {2,2}, {0LL,1LL,2LL,3LL}); - test.AddInput("indices", {2,2}, {0,0,1,1}); - test.AddOutput("output", {2}, {0LL,3LL}); - test.Run(); -} - -TEST(GatherNDOpTest, GatherND_matrice_string_int32) { - OpTester test1("GatherND", 1, onnxruntime::kMSDomain); - test1.AddInput("data", {2,2,2}, {"egg","dance","air","bob","terry","smart","laugh","kite"}); - test1.AddInput("indices", {2,1,2}, {0,1,1,0}); - test1.AddOutput("output", {2,1,2}, {"air","bob","terry","smart"}); - test1.Run(); - - OpTester test2("GatherND", 1, onnxruntime::kMSDomain); - test2.AddInput("data", {3,3}, {"egg","dance","air","bob","terry","smart","laugh","kite","hop"}); - test2.AddInput("indices", {3,2}, {2,1,1,0,0,1}); - test2.AddOutput("output", {3}, {"kite","bob","dance"}); - test2.Run(); -} - -TEST(GatherNDOpTest, GatherND_slice_float_int64_t) { - OpTester test("GatherND", 1, onnxruntime::kMSDomain); - test.AddInput("data", {2,2}, {0.0f,0.1f,0.2f,0.3f}); - test.AddInput("indices", {2,1}, {1LL,0LL}); - test.AddOutput("output", {2,2}, {0.2f,0.3f,0.0f,0.1f}); - test.Run(); -} - -TEST(GatherNDOpTest, GatherND_slice_double_int32_t) { - OpTester test("GatherND", 1, onnxruntime::kMSDomain); - test.AddInput("data", {2,2}, {0.0f,0.1f,0.2f,0.3f}); - test.AddInput("indices", {2,1}, {1LL,0LL}); - test.AddOutput("output", {2,2}, {0.2f,0.3f,0.0f,0.1f}); - test.Run(); -} - -TEST(GatherNDOpTest, GatherND_3tensor_int64) { - OpTester test1("GatherND", 1, onnxruntime::kMSDomain); - test1.AddInput("data", {2,2,2}, {0LL,1LL,2LL,3LL,4LL,5LL,6LL,7LL}); - test1.AddInput("indices", {2,2}, {0LL,1LL,1LL,0LL}); - test1.AddOutput("output", {2,2}, {2LL,3LL,4LL,5LL}); - test1.Run(); - - OpTester test2("GatherND", 1, onnxruntime::kMSDomain); - test2.AddInput("data", {2,2,2}, {0,1,2,3,4,5,6,7}); - test2.AddInput("indices", {2,3}, {0,0,1,1,0,1}); - test2.AddOutput("output", {2}, {1,5}); - test2.Run(); - - OpTester test3("GatherND", 1, onnxruntime::kMSDomain); - test3.AddInput("data", {2,2,2}, {0,1,2,3,4,5,6,7}); - test3.AddInput("indices", {1,1}, {1LL}); - test3.AddOutput("output", {1,2,2}, {4,5,6,7}); - test3.Run(); -} - -TEST(GatherNDOpTest, GatherND_batched_index_int64) { - OpTester test("GatherND", 1, onnxruntime::kMSDomain); - test.AddInput("data", {2,2}, {0LL,1LL,2LL,3LL}); - test.AddInput("indices", {2,1,2}, {0LL,0LL,0LL,1LL}); - test.AddOutput("output", {2,1}, {0LL,1LL}); - test.Run(); -} - -TEST(GatherNDOpTest, GatherND_batched_index_bool_int64) { - OpTester test("GatherND", 1, onnxruntime::kMSDomain); - test.AddInput("data", {2,2}, {true,false,false,true}); - test.AddInput("indices", {2,1,2}, {0LL,0LL,0LL,1LL}); - test.AddOutput("output", {2,1}, {true,false}); - test.Run(); -} - -TEST(GatherNDOpTest, GatherND_sliced_index_int64) { - OpTester test("GatherND", 1, onnxruntime::kMSDomain); - test.AddInput("data", {2,2}, {0LL,1LL,2LL,3LL}); - test.AddInput("indices", {2,1,1}, {1LL,0LL}); - test.AddOutput("output", {2,1,2}, {2LL,3LL,0LL,1LL}); - test.Run(); -} - -TEST(GatherNDOpTest, GatherND_sliced_index_string_int32) { - OpTester test("GatherND", 1, onnxruntime::kMSDomain); - test.AddInput("data", {2,2}, {"ab","cde","f","ghi"}); - test.AddInput("indices", {2,1,1}, {1LL,0LL}); - test.AddOutput("output", {2,1,2}, {"f","ghi","ab","cde"}); - test.Run(); -} - -TEST(GatherNDOpTest, GatherND_batched_3tensor_int64) { - OpTester test1("GatherND", 1, onnxruntime::kMSDomain); - test1.AddInput("data", {2,2,2}, {0,1,2,3,4,5,6,7}); - test1.AddInput("indices", {2,2,2}, {0LL,1LL,1LL,0LL,0LL,0LL,1LL,1LL}); - test1.AddOutput("output", {2,2,2}, {2,3,4,5,0,1,6,7}); - test1.Run(); - - OpTester test2("GatherND", 1, onnxruntime::kMSDomain); - test2.AddInput("data", {2,2,2}, {0,1,2,3,4,5,6,7}); - test2.AddInput("indices", {2,2,3}, {0,0,1,1,0,1,0,1,1,1,1,0}); - test2.AddOutput("output", {2,2}, {1,5,3,6}); - test2.Run(); - - OpTester test3("GatherND", 1, onnxruntime::kMSDomain); - test3.AddInput("data", {2,2,2}, {0LL,1LL,2LL,3LL,4LL,5LL,6LL,7LL}); - test3.AddInput("indices", {2,1,1}, {1,0}); - test3.AddOutput("output", {2,1,2,2}, {4LL,5LL,6LL,7LL,0LL,1LL,2LL,3LL}); - test3.Run(); -} - -} // namespace test -} // namespace onnxruntime diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index f323dbd38e..3819272a71 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -406,8 +406,6 @@ int real_main(int argc, char* argv[], Ort::Env& env) { {"range_int32_type_negative_delta_expanded", "not implemented yet"}, {"det_2d", "not implemented yet"}, {"det_nd", "not implemented yet"}, - {"gathernd_example_float32", "not implemented yet"}, - {"gathernd_example_int32", "not implemented yet"}, {"resize_downsample_scales_cubic_A_n0p5_exclude_outside", "not implemented yet"}, {"resize_downsample_scales_cubic_align_corners", "not implemented yet"}, {"resize_downsample_scales_cubic", "not implemented yet"}, diff --git a/onnxruntime/test/providers/cpu/tensor/gather_nd_op_test.cc b/onnxruntime/test/providers/cpu/tensor/gather_nd_op_test.cc new file mode 100644 index 0000000000..91164a608f --- /dev/null +++ b/onnxruntime/test/providers/cpu/tensor/gather_nd_op_test.cc @@ -0,0 +1,118 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace test { + +template +static void RunTest(const std::vector& input_dims, const std::initializer_list& input, + const std::vector& indices_dims, const std::initializer_list& indices, + const std::vector& output_dims, const std::initializer_list& output) { + // ONNX domain opset-11 + OpTester test1("GatherND", 11); + test1.AddInput("data", input_dims, input); + test1.AddInput("indices", indices_dims, indices); + test1.AddOutput("output", output_dims, output); + test1.Run(); + +#ifndef DISABLE_CONTRIB_OPS + + // MSFT domain opset-1 (contrib op) + OpTester test2("GatherND", 1, kMSDomain); + test2.AddInput("data", input_dims, input); + test2.AddInput("indices", indices_dims, indices); + test2.AddOutput("output", output_dims, output); + test2.Run(); + +#endif +} + +TEST(GatherNDOpTest, string) { + RunTest({2, 2}, {"h", "k", "o", "z"}, {2}, {0, 1}, {}, {"k"}); + + RunTest({6}, {"h", "k", "o", "z", "l", "t"}, {1}, {3}, {}, {"z"}); + + RunTest({3, 2}, {"h", "k", "o", "z", "l", "t"}, {2}, {2, 1}, {}, {"t"}); + + RunTest({2, 2}, {"a", "b", "c", "d"}, {2, 2}, {0LL, 0LL, 1LL, 1LL}, {2}, {"a", "d"}); + + RunTest({2, 2, 2}, {"egg", "dance", "air", "bob", "terry", "smart", "laugh", "kite"}, {2, 1, 2}, + {0LL, 1LL, 1LL, 0LL}, {2, 1, 2}, {"air", "bob", "terry", "smart"}); + + RunTest({3, 3}, {"egg", "dance", "air", "bob", "terry", "smart", "laugh", "kite", "hop"}, {3, 2}, + {2, 1, 1, 0, 0, 1}, {3}, {"kite", "bob", "dance"}); + + RunTest({2, 2}, {"ab", "cde", "f", "ghi"}, {2, 1, 1}, {1LL, 0LL}, {2, 1, 2}, {"f", "ghi", "ab", "cde"}); + + // with negative indices + RunTest({2, 2}, {"ab", "cde", "f", "ghi"}, {2, 1, 1}, {-1, 0}, {2, 1, 2}, {"f", "ghi", "ab", "cde"}); +} + +TEST(GatherNDOpTest, int64_t) { + RunTest({2, 2}, {0LL, 1LL, 2LL, 3LL}, {2, 2}, {0LL, 0LL, 1LL, 1LL}, {2}, {0LL, 3LL}); + + RunTest({2, 2, 2}, {0LL, 1LL, 2LL, 3LL, 4LL, 5LL, 6LL, 7LL}, {2, 2}, {0LL, 1LL, 1LL, 0LL}, {2, 2}, + {2LL, 3LL, 4LL, 5LL}); + + RunTest({2, 2}, {0LL, 1LL, 2LL, 3LL}, {2, 1, 2}, {0LL, 0LL, 0LL, 1LL}, {2, 1}, {0LL, 1LL}); + + RunTest({2, 2}, {0LL, 1LL, 2LL, 3LL}, {2, 1, 1}, {1LL, 0LL}, {2, 1, 2}, {2LL, 3LL, 0LL, 1LL}); + + RunTest({2, 2, 2}, {0LL, 1LL, 2LL, 3LL, 4LL, 5LL, 6LL, 7LL}, {2, 1, 1}, {1, 0}, {2, 1, 2, 2}, + {4LL, 5LL, 6LL, 7LL, 0LL, 1LL, 2LL, 3LL}); + + // with negative indices + RunTest({2, 2, 2}, {0LL, 1LL, 2LL, 3LL, 4LL, 5LL, 6LL, 7LL}, {2, 1, 1}, {-1, 0}, {2, 1, 2, 2}, + {4LL, 5LL, 6LL, 7LL, 0LL, 1LL, 2LL, 3LL}); +} + +TEST(GatherNDOpTest, float) { + RunTest({2, 2}, {0.0f, 0.1f, 0.2f, 0.3f}, {2, 1}, {1LL, 0LL}, {2, 2}, {0.2f, 0.3f, 0.0f, 0.1f}); +} + +TEST(GatherNDOpTest, double) { + RunTest({2, 2}, {0.0, 0.1, 0.2, 0.3}, {2, 1}, {1LL, 0LL}, {2, 2}, {0.2, 0.3, 0.0, 0.1}); +} + +TEST(GatherNDOpTest, int8_t) { + RunTest({2, 2, 2}, {0, 1, 2, 3, 4, 5, 6, 7}, {2, 3}, {0, 0, 1, 1, 0, 1}, {2}, {1, 5}); +} + +TEST(GatherNDOpTest, int16_t) { + RunTest({2, 2, 2}, {0, 1, 2, 3, 4, 5, 6, 7}, {1, 1}, {1}, {1, 2, 2}, {4, 5, 6, 7}); +} + +TEST(GatherNDOpTest, uint32_t) { + RunTest({2, 2, 2}, {0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}, {0LL, 1LL, 1LL, 0LL, 0LL, 0LL, 1LL, 1LL}, + {2, 2, 2}, {2, 3, 4, 5, 0, 1, 6, 7}); + + RunTest({2, 2, 2}, {0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 3}, {0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0}, {2, 2}, + {1, 5, 3, 6}); +} + +TEST(GatherNDOpTest, bool) { + RunTest({2, 2}, {true, false, false, true}, {2, 1, 2}, {0LL, 0LL, 0LL, 1LL}, {2, 1}, {true, false}); +} + +#ifndef DISABLE_CONTRIB_OPS + +// The contrib spec of GatherND supports `int64` AND `int32` type for `indices` +// The official spec only support `int64` +// This test covers `int32` indices just for the contrib kernel + +TEST(GatherNDOpTest, ContribOpInt32Indices) { + // MSFT domain opset-1 (contrib op) + OpTester test2("GatherND", 1, kMSDomain); + test2.AddInput("data", {2, 2, 2}, {0, 1, 2, 3, 4, 5, 6, 7}); + test2.AddInput("indices", {2, 3}, {0, 0, 1, 1, 0, 1}); + test2.AddOutput("output", {2}, {1, 5}); + test2.Run(); +} + +#endif + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/python/onnx_backend_test_series.py b/onnxruntime/test/python/onnx_backend_test_series.py index 11c28d65c9..6dbc4b290a 100644 --- a/onnxruntime/test/python/onnx_backend_test_series.py +++ b/onnxruntime/test/python/onnx_backend_test_series.py @@ -124,8 +124,6 @@ def create_backend_test(testname=None): '^test_range_int32_type_negative_delta_expanded_cpu', '^test_det_2d_cpu', '^test_det_nd_cpu', - '^test_gathernd_example_float32_cpu', - '^test_gathernd_example_int32_cpu', '^test_resize_downsample_scales_cubic_A_n0p5_exclude_outside_cpu', '^test_resize_downsample_scales_cubic_align_corners_cpu', '^test_resize_downsample_scales_cubic_cpu',