From d0544a80824de2fd3d18cc60dd39f99f2c781d9d Mon Sep 17 00:00:00 2001 From: Randy <45701928+RandyShuai@users.noreply.github.com> Date: Mon, 17 Dec 2018 13:47:20 -0800 Subject: [PATCH] Rashuai/gathernd op (#170) * define gather_nd op * add test cases * add test file * refactor the code and doc * add test cases * fix win compile err * fix win compile err * adjust indent * make constructor explicit * add coment * remove templates * remove wrong def * migrate macros * fix an issue in shape inference --- onnxruntime/contrib_ops/contrib_kernels.cc | 2 + onnxruntime/contrib_ops/cpu/gather_nd.cc | 114 +++++++++++++ onnxruntime/contrib_ops/cpu/gather_nd.h | 49 ++++++ .../core/graph/contrib_ops/contrib_defs.cc | 65 ++++++++ .../test/contrib_ops/gather_nd_op_test.cc | 157 ++++++++++++++++++ 5 files changed, 387 insertions(+) create mode 100644 onnxruntime/contrib_ops/cpu/gather_nd.cc create mode 100644 onnxruntime/contrib_ops/cpu/gather_nd.h create mode 100644 onnxruntime/test/contrib_ops/gather_nd_op_test.cc diff --git a/onnxruntime/contrib_ops/contrib_kernels.cc b/onnxruntime/contrib_ops/contrib_kernels.cc index 11b7e653f3..a8f8876e77 100644 --- a/onnxruntime/contrib_ops/contrib_kernels.cc +++ b/onnxruntime/contrib_ops/contrib_kernels.cc @@ -17,6 +17,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, StringNormalizer); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, NonMaxSuppression); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Range); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GatherND); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MurmurHash3); void RegisterContribKernels(std::function fn) { @@ -33,6 +34,7 @@ void RegisterContribKernels(std::function fn) { fn(BuildKernel()); fn(BuildKernel()); fn(BuildKernel()); + fn(BuildKernel()); fn(BuildKernel()); } } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/gather_nd.cc b/onnxruntime/contrib_ops/cpu/gather_nd.cc new file mode 100644 index 0000000000..f13794aa6b --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/gather_nd.cc @@ -0,0 +1,114 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cpu/gather_nd.h" + +namespace onnxruntime { +namespace contrib { + +ONNX_OPERATOR_KERNEL_EX( + GatherND, + kMSDomain, + 1, + kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::AllTensorTypes()) + .TypeConstraint("Tind", {DataTypeImpl::GetTensorType(),DataTypeImpl::GetTensorType()}), + GatherND); + +template +Status GatherNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) const { + + auto input_tensor = context->Input(0); + auto indice_tensor = context->Input(1); + ORT_ENFORCE(input_tensor != nullptr); + ORT_ENFORCE(indice_tensor != nullptr); + + auto input_shape = input_tensor->Shape(); + auto indice_shape = indice_tensor->Shape(); + if (indice_shape.NumDimensions() == 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "indices tensor must has rank larger than 0"); + } + + auto last_indice_dimension = indice_shape[indice_shape.NumDimensions() - 1]; + if (last_indice_dimension > static_cast(input_shape.NumDimensions())) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "last dimension of indices must not be larger than rank of input tensor"); + } + + std::vector shape(indice_shape.GetDims().begin(), + indice_shape.GetDims().end() - 1); + shape.insert(shape.end(), + input_shape.GetDims().begin() + last_indice_dimension, + input_shape.GetDims().end()); + auto output_tensor = context->Output(0,TensorShape(shape)); + std::vector element_counts(last_indice_dimension, 0LL); // Number of elements for each input dimension + +#pragma omp parallel for + for (int64_t i = 0; i < last_indice_dimension; ++i) { + element_counts[i] = input_shape.SizeFromDimension(i + 1); + } + + int64_t err_indice = 0; + p.element_bytes = input_tensor->DataType()->Size(); + p.element_to_copy = input_shape.SizeFromDimension(last_indice_dimension); + p.bytes_to_copy = p.element_bytes * p.element_to_copy; + auto indice_offset = static_cast(context->Input(1)->DataRaw()); + auto offset_count = indice_shape.Size() / last_indice_dimension; // Times to copy + p.element_offsets.assign(offset_count, 0LL); + + if (input_tensor->DataType() == DataTypeImpl::GetType()) { + p.input_str_base = static_cast(input_tensor->DataRaw()); + p.output_str_base = static_cast(output_tensor->MutableDataRaw()); + } else { + p.input_base = static_cast(context->Input(0)->DataRaw()); + p.output_base = static_cast(output_tensor->MutableDataRaw()); + } + +#pragma omp parallel for + for (int64_t i = 0; i < offset_count; ++i) { + for (int64_t j = 0; j < last_indice_dimension; ++j) { + auto indice = *(indice_offset + i * last_indice_dimension + j); + if (indice < 0 || indice >= input_shape[j]) { + err_indice = indice; + } + p.element_offsets[i] += indice * element_counts[j]; + } + } + return err_indice == 0 ? Status::OK() : + ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid indice found, indice = ", err_indice); +} + +template Status GatherNDBase::PrepareForCompute(OpKernelContext*, Prepare&) const; +template Status GatherNDBase::PrepareForCompute(OpKernelContext*, Prepare&) const; + +Status GatherND::Compute(OpKernelContext* context) const { + Prepare p; + ORT_RETURN_IF_ERROR(context->Input(1)->DataType() == DataTypeImpl::GetType() ? + PrepareForCompute(context, p) : PrepareForCompute(context, p)); + return nullptr == p.input_str_base ? GatherNumber(p) : GatherString(p); +} + +Status GatherND::GatherNumber(const Prepare& p) const { +#pragma omp parallel for + for (int64_t i = 0; i < static_cast(p.element_offsets.size()); ++i) { + memcpy(p.output_base + i * p.bytes_to_copy, + p.input_base + p.element_offsets[i] * p.element_bytes, + p.bytes_to_copy); + } + return Status::OK(); +} + +Status GatherND::GatherString(const Prepare& p) const { +#pragma omp parallel for + for (int64_t i = 0; i < static_cast(p.element_offsets.size()); ++i) { + for (int64_t j = 0; j < static_cast(p.element_to_copy); ++j) { + p.output_str_base[i * p.element_to_copy + j] = p.input_str_base[p.element_offsets[i] + j]; + } + } + return Status::OK(); +} + +} +} diff --git a/onnxruntime/contrib_ops/cpu/gather_nd.h b/onnxruntime/contrib_ops/cpu/gather_nd.h new file mode 100644 index 0000000000..6d256e07fd --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/gather_nd.h @@ -0,0 +1,49 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" + +namespace onnxruntime { +namespace contrib { + +class GatherNDBase +{ +protected: + struct Prepare { + const uint8_t* input_base; + const std::string* input_str_base; + uint8_t* output_base; + std::string* output_str_base; + uint64_t bytes_to_copy; + uint64_t element_bytes; + uint64_t element_to_copy; + std::vector element_offsets; + + Prepare(): input_base (nullptr), + input_str_base (nullptr), + output_base (nullptr), + output_str_base (nullptr), + bytes_to_copy (0), + element_bytes (0), + element_to_copy (0), + element_offsets (0) {} + }; // struct Prepare + + template + Status PrepareForCompute(OpKernelContext* context, Prepare& p) const; +}; // class GatherNDBase + +class GatherND final : public OpKernel, protected GatherNDBase { +public: + explicit GatherND(const OpKernelInfo& info) : OpKernel(info) {} + Status Compute(OpKernelContext* context) const override; +private: + Status GatherNumber(const Prepare& p) const; + Status GatherString(const Prepare& p) const; +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 7316b85235..a125c72068 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -588,6 +588,71 @@ The bounding box coordinates corresponding to the selected indices can then be o output_elem_type->set_elem_type(ONNX_NAMESPACE::TensorProto::STRING); }) .SetDoc(R"DOC([optional] Step1: Remove elements in X if they match any of the stop words so that the output tensor will not contain any stop words. This operator only accepts [C]- and [1, C]-tensors. If all elements in X are dropped, the output will be the default value of string tensor with shape [1] if input shape is [C] and shape [1, 1] if input shape is [1, C].)DOC"); + + ONNX_CONTRIB_OPERATOR_SCHEMA(GatherND) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Input (0, "data", "Tensor of rank r >= 1.", "T" ) + .Input (1, "indices", "Tensor of rank q >= 1.", "Tind" ) + .Output (0, "output", "Tensor of rank q-1+r-indices[-1].", "T" ) + .TypeConstraint( + "T", + OpSchema::all_tensor_types(), + "Constrain input and output types to any tensor type.") + .TypeConstraint( + "Tind", + {"tensor(int32)", "tensor(int64)"}, + "Constrain indice type to int32 or int64") + .TypeAndShapeInferenceFunction( [] (ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + if (!hasNInputShapes(ctx, 2)) { + fail_shape_inference("GatherND requires two tensor inputs."); + } + auto& data_shape = ctx.getInputType(0)->tensor_type().shape(); + auto& indices_shape = ctx.getInputType(1)->tensor_type().shape(); + auto data_rank = data_shape.dim_size(); + auto indices_rank = indices_shape.dim_size(); + if (data_rank < 1 || indices_rank < 1) { + fail_shape_inference("both data and indices tensor need to have rank larger than zero."); + } + auto last_indice_dimension = indices_shape.dim(indices_rank - 1).dim_value(); + if (last_indice_dimension > data_rank) { + fail_shape_inference("last dimension of indices must not be larger and rank of data tensor"); + } + for (int i = 0; i < indices_rank - 1; ++i) { + *ctx.getOutputType(0) + ->mutable_tensor_type() + ->mutable_shape() + ->add_dim() = indices_shape.dim(i); + } + for (int i = static_cast(last_indice_dimension); i < data_rank; ++i) { + *ctx.getOutputType(0) + ->mutable_tensor_type() + ->mutable_shape() + ->add_dim() = data_shape.dim(i); + } + }) + .SetDoc(R"DOC( +Given `data` tensor of rank r >= 1, and `indices` tensor of rank q >= 1, gather +slices of `data` into an output tensor of rank q - 1 + r - indices[-1]. +Example 1: + data = [[0,1],[2,3]] + indices = [[0,0],[1,1]] + output = [0,3] +Example 2: + data = [[0,1],[2,3]] + indices = [[1],[0]] + output = [[2,3],[0,1]] +Example 3: + data = [[[0,1],[2,3]],[[4,5],[6,7]]] + indices = [[0,1],[1,0]] + output = [[2,3],[4,5]] +Example 4: + data = [[[0,1],[2,3]],[[4,5],[6,7]]] + indices = [[[0,1]],[[1,0]]] + output = [[[2,3]],[[4,5]]] +)DOC"); + } } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/gather_nd_op_test.cc b/onnxruntime/test/contrib_ops/gather_nd_op_test.cc new file mode 100644 index 0000000000..f7af11bb1f --- /dev/null +++ b/onnxruntime/test/contrib_ops/gather_nd_op_test.cc @@ -0,0 +1,157 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace test { + +TEST(GatherNDOpTest, GatherND_scaler_string_int32) { + OpTester test1("GatherND", 1, onnxruntime::kMSDomain); + test1.AddInput("data", {2,2}, {"h","k","o","z"}); + test1.AddInput("indices", {2}, {0,1}); + test1.AddOutput("output", {}, {"k"}); + test1.Run(); + + OpTester test2("GatherND", 1, onnxruntime::kMSDomain); + test2.AddInput("data", {6}, {"h","k","o","z","l","t"}); + test2.AddInput("indices", {1}, {3}); + test2.AddOutput("output", {}, {"z"}); + test2.Run(); + + OpTester test3("GatherND", 1, onnxruntime::kMSDomain); + test3.AddInput("data", {3,2}, {"h","k","o","z","l","t"}); + test3.AddInput("indices", {2}, {2,1}); + test3.AddOutput("output", {}, {"t"}); + test3.Run(); +} + +TEST(GatherNDOpTest, GatherND_matrice_int64_int64) { + OpTester test("GatherND", 1, onnxruntime::kMSDomain); + test.AddInput ("data", {2,2}, {0LL,1LL,2LL,3LL}); + test.AddInput ("indices", {2,2}, {0LL,0LL,1LL,1LL}); + test.AddOutput("output", {2}, {0LL,3LL}); + test.Run(); +} + +TEST(GatherNDOpTest, GatherND_matrice_string_int64) { + OpTester test("GatherND", 1, onnxruntime::kMSDomain); + test.AddInput("data", {2,2}, {"a","b","c","d"}); + test.AddInput("indices", {2,2}, {0LL,0LL,1LL,1LL}); + test.AddOutput("output", {2}, {"a","d"}); + test.Run(); +} + +TEST(GatherNDOpTest, GatherND_matrice_int64_int32) { + OpTester test("GatherND", 1, onnxruntime::kMSDomain); + test.AddInput("data", {2,2}, {0LL,1LL,2LL,3LL}); + test.AddInput("indices", {2,2}, {0,0,1,1}); + test.AddOutput("output", {2}, {0LL,3LL}); + test.Run(); +} + +TEST(GatherNDOpTest, GatherND_matrice_string_int32) { + OpTester test1("GatherND", 1, onnxruntime::kMSDomain); + test1.AddInput("data", {2,2,2}, {"egg","dance","air","bob","terry","smart","laugh","kite"}); + test1.AddInput("indices", {2,1,2}, {0,1,1,0}); + test1.AddOutput("output", {2,1,2}, {"air","bob","terry","smart"}); + test1.Run(); + + OpTester test2("GatherND", 1, onnxruntime::kMSDomain); + test2.AddInput("data", {3,3}, {"egg","dance","air","bob","terry","smart","laugh","kite","hop"}); + test2.AddInput("indices", {3,2}, {2,1,1,0,0,1}); + test2.AddOutput("output", {3}, {"kite","bob","dance"}); + test2.Run(); +} + +TEST(GatherNDOpTest, GatherND_slice_float_int64_t) { + OpTester test("GatherND", 1, onnxruntime::kMSDomain); + test.AddInput("data", {2,2}, {0.0f,0.1f,0.2f,0.3f}); + test.AddInput("indices", {2,1}, {1LL,0LL}); + test.AddOutput("output", {2,2}, {0.2f,0.3f,0.0f,0.1f}); + test.Run(); +} + +TEST(GatherNDOpTest, GatherND_slice_double_int32_t) { + OpTester test("GatherND", 1, onnxruntime::kMSDomain); + test.AddInput("data", {2,2}, {0.0f,0.1f,0.2f,0.3f}); + test.AddInput("indices", {2,1}, {1LL,0LL}); + test.AddOutput("output", {2,2}, {0.2f,0.3f,0.0f,0.1f}); + test.Run(); +} + +TEST(GatherNDOpTest, GatherND_3tensor_int64) { + OpTester test1("GatherND", 1, onnxruntime::kMSDomain); + test1.AddInput("data", {2,2,2}, {0LL,1LL,2LL,3LL,4LL,5LL,6LL,7LL}); + test1.AddInput("indices", {2,2}, {0LL,1LL,1LL,0LL}); + test1.AddOutput("output", {2,2}, {2LL,3LL,4LL,5LL}); + test1.Run(); + + OpTester test2("GatherND", 1, onnxruntime::kMSDomain); + test2.AddInput("data", {2,2,2}, {0,1,2,3,4,5,6,7}); + test2.AddInput("indices", {2,3}, {0,0,1,1,0,1}); + test2.AddOutput("output", {2}, {1,5}); + test2.Run(); + + OpTester test3("GatherND", 1, onnxruntime::kMSDomain); + test3.AddInput("data", {2,2,2}, {0,1,2,3,4,5,6,7}); + test3.AddInput("indices", {1,1}, {1LL}); + test3.AddOutput("output", {1,2,2}, {4,5,6,7}); + test3.Run(); +} + +TEST(GatherNDOpTest, GatherND_batched_index_int64) { + OpTester test("GatherND", 1, onnxruntime::kMSDomain); + test.AddInput("data", {2,2}, {0LL,1LL,2LL,3LL}); + test.AddInput("indices", {2,1,2}, {0LL,0LL,0LL,1LL}); + test.AddOutput("output", {2,1}, {0LL,1LL}); + test.Run(); +} + +TEST(GatherNDOpTest, GatherND_batched_index_bool_int64) { + OpTester test("GatherND", 1, onnxruntime::kMSDomain); + test.AddInput("data", {2,2}, {true,false,false,true}); + test.AddInput("indices", {2,1,2}, {0LL,0LL,0LL,1LL}); + test.AddOutput("output", {2,1}, {true,false}); + test.Run(); +} + +TEST(GatherNDOpTest, GatherND_sliced_index_int64) { + OpTester test("GatherND", 1, onnxruntime::kMSDomain); + test.AddInput("data", {2,2}, {0LL,1LL,2LL,3LL}); + test.AddInput("indices", {2,1,1}, {1LL,0LL}); + test.AddOutput("output", {2,1,2}, {2LL,3LL,0LL,1LL}); + test.Run(); +} + +TEST(GatherNDOpTest, GatherND_sliced_index_string_int32) { + OpTester test("GatherND", 1, onnxruntime::kMSDomain); + test.AddInput("data", {2,2}, {"ab","cde","f","ghi"}); + test.AddInput("indices", {2,1,1}, {1LL,0LL}); + test.AddOutput("output", {2,1,2}, {"f","ghi","ab","cde"}); + test.Run(); +} + +TEST(GatherNDOpTest, GatherND_batched_3tensor_int64) { + OpTester test1("GatherND", 1, onnxruntime::kMSDomain); + test1.AddInput("data", {2,2,2}, {0,1,2,3,4,5,6,7}); + test1.AddInput("indices", {2,2,2}, {0LL,1LL,1LL,0LL,0LL,0LL,1LL,1LL}); + test1.AddOutput("output", {2,2,2}, {2,3,4,5,0,1,6,7}); + test1.Run(); + + OpTester test2("GatherND", 1, onnxruntime::kMSDomain); + test2.AddInput("data", {2,2,2}, {0,1,2,3,4,5,6,7}); + test2.AddInput("indices", {2,2,3}, {0,0,1,1,0,1,0,1,1,1,1,0}); + test2.AddOutput("output", {2,2}, {1,5,3,6}); + test2.Run(); + + OpTester test3("GatherND", 1, onnxruntime::kMSDomain); + test3.AddInput("data", {2,2,2}, {0LL,1LL,2LL,3LL,4LL,5LL,6LL,7LL}); + test3.AddInput("indices", {2,1,1}, {1,0}); + test3.AddOutput("output", {2,1,2,2}, {4LL,5LL,6LL,7LL,0LL,1LL,2LL,3LL}); + test3.Run(); +} + +} // namespace test +} // namespace onnxruntime