Rashuai/gathernd op (#170)

* define gather_nd op

* add test cases

* add test file

* refactor the code and doc

* add test cases

* fix win compile err

* fix win compile err

* adjust indent

* make constructor explicit

* add coment

* remove templates

* remove wrong def

* migrate macros

* fix an issue in shape inference
This commit is contained in:
Randy 2018-12-17 13:47:20 -08:00 committed by GitHub
parent 82d04412a0
commit d0544a8082
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 387 additions and 0 deletions

View file

@ -17,6 +17,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, StringNormalizer);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, NonMaxSuppression);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Range);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GatherND);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MurmurHash3);
void RegisterContribKernels(std::function<void(KernelCreateInfo&&)> fn) {
@ -33,6 +34,7 @@ void RegisterContribKernels(std::function<void(KernelCreateInfo&&)> fn) {
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, StringNormalizer)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, NonMaxSuppression)>());
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Range)>());
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GatherND)>());
fn(BuildKernel<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MurmurHash3)>());
}
} // namespace contrib

View file

@ -0,0 +1,114 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "contrib_ops/cpu/gather_nd.h"
namespace onnxruntime {
namespace contrib {
ONNX_OPERATOR_KERNEL_EX(
GatherND,
kMSDomain,
1,
kCpuExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
.TypeConstraint("Tind", {DataTypeImpl::GetTensorType<int32_t>(),DataTypeImpl::GetTensorType<int64_t>()}),
GatherND);
template<typename Tind>
Status GatherNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) const {
auto input_tensor = context->Input<Tensor>(0);
auto indice_tensor = context->Input<Tensor>(1);
ORT_ENFORCE(input_tensor != nullptr);
ORT_ENFORCE(indice_tensor != nullptr);
auto input_shape = input_tensor->Shape();
auto indice_shape = indice_tensor->Shape();
if (indice_shape.NumDimensions() == 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"indices tensor must has rank larger than 0");
}
auto last_indice_dimension = indice_shape[indice_shape.NumDimensions() - 1];
if (last_indice_dimension > static_cast<int64_t>(input_shape.NumDimensions())) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"last dimension of indices must not be larger than rank of input tensor");
}
std::vector<int64_t> shape(indice_shape.GetDims().begin(),
indice_shape.GetDims().end() - 1);
shape.insert(shape.end(),
input_shape.GetDims().begin() + last_indice_dimension,
input_shape.GetDims().end());
auto output_tensor = context->Output(0,TensorShape(shape));
std::vector<int64_t> element_counts(last_indice_dimension, 0LL); // Number of elements for each input dimension
#pragma omp parallel for
for (int64_t i = 0; i < last_indice_dimension; ++i) {
element_counts[i] = input_shape.SizeFromDimension(i + 1);
}
int64_t err_indice = 0;
p.element_bytes = input_tensor->DataType()->Size();
p.element_to_copy = input_shape.SizeFromDimension(last_indice_dimension);
p.bytes_to_copy = p.element_bytes * p.element_to_copy;
auto indice_offset = static_cast<const Tind*>(context->Input<Tensor>(1)->DataRaw());
auto offset_count = indice_shape.Size() / last_indice_dimension; // Times to copy
p.element_offsets.assign(offset_count, 0LL);
if (input_tensor->DataType() == DataTypeImpl::GetType<std::string>()) {
p.input_str_base = static_cast<const std::string*>(input_tensor->DataRaw());
p.output_str_base = static_cast<std::string*>(output_tensor->MutableDataRaw());
} else {
p.input_base = static_cast<const uint8_t*>(context->Input<Tensor>(0)->DataRaw());
p.output_base = static_cast<uint8_t*>(output_tensor->MutableDataRaw());
}
#pragma omp parallel for
for (int64_t i = 0; i < offset_count; ++i) {
for (int64_t j = 0; j < last_indice_dimension; ++j) {
auto indice = *(indice_offset + i * last_indice_dimension + j);
if (indice < 0 || indice >= input_shape[j]) {
err_indice = indice;
}
p.element_offsets[i] += indice * element_counts[j];
}
}
return err_indice == 0 ? Status::OK() :
ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid indice found, indice = ", err_indice);
}
template Status GatherNDBase::PrepareForCompute<int32_t>(OpKernelContext*, Prepare&) const;
template Status GatherNDBase::PrepareForCompute<int64_t>(OpKernelContext*, Prepare&) const;
Status GatherND::Compute(OpKernelContext* context) const {
Prepare p;
ORT_RETURN_IF_ERROR(context->Input<Tensor>(1)->DataType() == DataTypeImpl::GetType<int32_t>() ?
PrepareForCompute<int32_t>(context, p) : PrepareForCompute<int64_t>(context, p));
return nullptr == p.input_str_base ? GatherNumber(p) : GatherString(p);
}
Status GatherND::GatherNumber(const Prepare& p) const {
#pragma omp parallel for
for (int64_t i = 0; i < static_cast<int64_t>(p.element_offsets.size()); ++i) {
memcpy(p.output_base + i * p.bytes_to_copy,
p.input_base + p.element_offsets[i] * p.element_bytes,
p.bytes_to_copy);
}
return Status::OK();
}
Status GatherND::GatherString(const Prepare& p) const {
#pragma omp parallel for
for (int64_t i = 0; i < static_cast<int64_t>(p.element_offsets.size()); ++i) {
for (int64_t j = 0; j < static_cast<int64_t>(p.element_to_copy); ++j) {
p.output_str_base[i * p.element_to_copy + j] = p.input_str_base[p.element_offsets[i] + j];
}
}
return Status::OK();
}
}
}

View file

@ -0,0 +1,49 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
namespace onnxruntime {
namespace contrib {
class GatherNDBase
{
protected:
struct Prepare {
const uint8_t* input_base;
const std::string* input_str_base;
uint8_t* output_base;
std::string* output_str_base;
uint64_t bytes_to_copy;
uint64_t element_bytes;
uint64_t element_to_copy;
std::vector<uint64_t> element_offsets;
Prepare(): input_base (nullptr),
input_str_base (nullptr),
output_base (nullptr),
output_str_base (nullptr),
bytes_to_copy (0),
element_bytes (0),
element_to_copy (0),
element_offsets (0) {}
}; // struct Prepare
template<typename Tind>
Status PrepareForCompute(OpKernelContext* context, Prepare& p) const;
}; // class GatherNDBase
class GatherND final : public OpKernel, protected GatherNDBase {
public:
explicit GatherND(const OpKernelInfo& info) : OpKernel(info) {}
Status Compute(OpKernelContext* context) const override;
private:
Status GatherNumber(const Prepare& p) const;
Status GatherString(const Prepare& p) const;
};
} // namespace contrib
} // namespace onnxruntime

View file

@ -588,6 +588,71 @@ The bounding box coordinates corresponding to the selected indices can then be o
output_elem_type->set_elem_type(ONNX_NAMESPACE::TensorProto::STRING);
})
.SetDoc(R"DOC([optional] Step1: Remove elements in X if they match any of the stop words so that the output tensor will not contain any stop words. This operator only accepts [C]- and [1, C]-tensors. If all elements in X are dropped, the output will be the default value of string tensor with shape [1] if input shape is [C] and shape [1, 1] if input shape is [1, C].)DOC");
ONNX_CONTRIB_OPERATOR_SCHEMA(GatherND)
.SetDomain(kMSDomain)
.SinceVersion(1)
.Input (0, "data", "Tensor of rank r >= 1.", "T" )
.Input (1, "indices", "Tensor of rank q >= 1.", "Tind" )
.Output (0, "output", "Tensor of rank q-1+r-indices[-1].", "T" )
.TypeConstraint(
"T",
OpSchema::all_tensor_types(),
"Constrain input and output types to any tensor type.")
.TypeConstraint(
"Tind",
{"tensor(int32)", "tensor(int64)"},
"Constrain indice type to int32 or int64")
.TypeAndShapeInferenceFunction( [] (ONNX_NAMESPACE::InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 0, 0);
if (!hasNInputShapes(ctx, 2)) {
fail_shape_inference("GatherND requires two tensor inputs.");
}
auto& data_shape = ctx.getInputType(0)->tensor_type().shape();
auto& indices_shape = ctx.getInputType(1)->tensor_type().shape();
auto data_rank = data_shape.dim_size();
auto indices_rank = indices_shape.dim_size();
if (data_rank < 1 || indices_rank < 1) {
fail_shape_inference("both data and indices tensor need to have rank larger than zero.");
}
auto last_indice_dimension = indices_shape.dim(indices_rank - 1).dim_value();
if (last_indice_dimension > data_rank) {
fail_shape_inference("last dimension of indices must not be larger and rank of data tensor");
}
for (int i = 0; i < indices_rank - 1; ++i) {
*ctx.getOutputType(0)
->mutable_tensor_type()
->mutable_shape()
->add_dim() = indices_shape.dim(i);
}
for (int i = static_cast<int>(last_indice_dimension); i < data_rank; ++i) {
*ctx.getOutputType(0)
->mutable_tensor_type()
->mutable_shape()
->add_dim() = data_shape.dim(i);
}
})
.SetDoc(R"DOC(
Given `data` tensor of rank r >= 1, and `indices` tensor of rank q >= 1, gather
slices of `data` into an output tensor of rank q - 1 + r - indices[-1].
Example 1:
data = [[0,1],[2,3]]
indices = [[0,0],[1,1]]
output = [0,3]
Example 2:
data = [[0,1],[2,3]]
indices = [[1],[0]]
output = [[2,3],[0,1]]
Example 3:
data = [[[0,1],[2,3]],[[4,5],[6,7]]]
indices = [[0,1],[1,0]]
output = [[2,3],[4,5]]
Example 4:
data = [[[0,1],[2,3]],[[4,5],[6,7]]]
indices = [[[0,1]],[[1,0]]]
output = [[[2,3]],[[4,5]]]
)DOC");
}
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,157 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"
namespace onnxruntime {
namespace test {
TEST(GatherNDOpTest, GatherND_scaler_string_int32) {
OpTester test1("GatherND", 1, onnxruntime::kMSDomain);
test1.AddInput<std::string>("data", {2,2}, {"h","k","o","z"});
test1.AddInput<int32_t>("indices", {2}, {0,1});
test1.AddOutput<std::string>("output", {}, {"k"});
test1.Run();
OpTester test2("GatherND", 1, onnxruntime::kMSDomain);
test2.AddInput<std::string>("data", {6}, {"h","k","o","z","l","t"});
test2.AddInput<int32_t>("indices", {1}, {3});
test2.AddOutput<std::string>("output", {}, {"z"});
test2.Run();
OpTester test3("GatherND", 1, onnxruntime::kMSDomain);
test3.AddInput<std::string>("data", {3,2}, {"h","k","o","z","l","t"});
test3.AddInput<int32_t>("indices", {2}, {2,1});
test3.AddOutput<std::string>("output", {}, {"t"});
test3.Run();
}
TEST(GatherNDOpTest, GatherND_matrice_int64_int64) {
OpTester test("GatherND", 1, onnxruntime::kMSDomain);
test.AddInput<int64_t> ("data", {2,2}, {0LL,1LL,2LL,3LL});
test.AddInput<int64_t> ("indices", {2,2}, {0LL,0LL,1LL,1LL});
test.AddOutput<int64_t>("output", {2}, {0LL,3LL});
test.Run();
}
TEST(GatherNDOpTest, GatherND_matrice_string_int64) {
OpTester test("GatherND", 1, onnxruntime::kMSDomain);
test.AddInput<std::string>("data", {2,2}, {"a","b","c","d"});
test.AddInput<int64_t>("indices", {2,2}, {0LL,0LL,1LL,1LL});
test.AddOutput<std::string>("output", {2}, {"a","d"});
test.Run();
}
TEST(GatherNDOpTest, GatherND_matrice_int64_int32) {
OpTester test("GatherND", 1, onnxruntime::kMSDomain);
test.AddInput<int64_t>("data", {2,2}, {0LL,1LL,2LL,3LL});
test.AddInput<int32_t>("indices", {2,2}, {0,0,1,1});
test.AddOutput<int64_t>("output", {2}, {0LL,3LL});
test.Run();
}
TEST(GatherNDOpTest, GatherND_matrice_string_int32) {
OpTester test1("GatherND", 1, onnxruntime::kMSDomain);
test1.AddInput<std::string>("data", {2,2,2}, {"egg","dance","air","bob","terry","smart","laugh","kite"});
test1.AddInput<int32_t>("indices", {2,1,2}, {0,1,1,0});
test1.AddOutput<std::string>("output", {2,1,2}, {"air","bob","terry","smart"});
test1.Run();
OpTester test2("GatherND", 1, onnxruntime::kMSDomain);
test2.AddInput<std::string>("data", {3,3}, {"egg","dance","air","bob","terry","smart","laugh","kite","hop"});
test2.AddInput<int32_t>("indices", {3,2}, {2,1,1,0,0,1});
test2.AddOutput<std::string>("output", {3}, {"kite","bob","dance"});
test2.Run();
}
TEST(GatherNDOpTest, GatherND_slice_float_int64_t) {
OpTester test("GatherND", 1, onnxruntime::kMSDomain);
test.AddInput<float>("data", {2,2}, {0.0f,0.1f,0.2f,0.3f});
test.AddInput<int64_t>("indices", {2,1}, {1LL,0LL});
test.AddOutput<float>("output", {2,2}, {0.2f,0.3f,0.0f,0.1f});
test.Run();
}
TEST(GatherNDOpTest, GatherND_slice_double_int32_t) {
OpTester test("GatherND", 1, onnxruntime::kMSDomain);
test.AddInput<double>("data", {2,2}, {0.0f,0.1f,0.2f,0.3f});
test.AddInput<int32_t>("indices", {2,1}, {1LL,0LL});
test.AddOutput<double>("output", {2,2}, {0.2f,0.3f,0.0f,0.1f});
test.Run();
}
TEST(GatherNDOpTest, GatherND_3tensor_int64) {
OpTester test1("GatherND", 1, onnxruntime::kMSDomain);
test1.AddInput<int64_t>("data", {2,2,2}, {0LL,1LL,2LL,3LL,4LL,5LL,6LL,7LL});
test1.AddInput<int64_t>("indices", {2,2}, {0LL,1LL,1LL,0LL});
test1.AddOutput<int64_t>("output", {2,2}, {2LL,3LL,4LL,5LL});
test1.Run();
OpTester test2("GatherND", 1, onnxruntime::kMSDomain);
test2.AddInput<int8_t>("data", {2,2,2}, {0,1,2,3,4,5,6,7});
test2.AddInput<int32_t>("indices", {2,3}, {0,0,1,1,0,1});
test2.AddOutput<int8_t>("output", {2}, {1,5});
test2.Run();
OpTester test3("GatherND", 1, onnxruntime::kMSDomain);
test3.AddInput<int16_t>("data", {2,2,2}, {0,1,2,3,4,5,6,7});
test3.AddInput<int64_t>("indices", {1,1}, {1LL});
test3.AddOutput<int16_t>("output", {1,2,2}, {4,5,6,7});
test3.Run();
}
TEST(GatherNDOpTest, GatherND_batched_index_int64) {
OpTester test("GatherND", 1, onnxruntime::kMSDomain);
test.AddInput<int64_t>("data", {2,2}, {0LL,1LL,2LL,3LL});
test.AddInput<int64_t>("indices", {2,1,2}, {0LL,0LL,0LL,1LL});
test.AddOutput<int64_t>("output", {2,1}, {0LL,1LL});
test.Run();
}
TEST(GatherNDOpTest, GatherND_batched_index_bool_int64) {
OpTester test("GatherND", 1, onnxruntime::kMSDomain);
test.AddInput<bool>("data", {2,2}, {true,false,false,true});
test.AddInput<int64_t>("indices", {2,1,2}, {0LL,0LL,0LL,1LL});
test.AddOutput<bool>("output", {2,1}, {true,false});
test.Run();
}
TEST(GatherNDOpTest, GatherND_sliced_index_int64) {
OpTester test("GatherND", 1, onnxruntime::kMSDomain);
test.AddInput<int64_t>("data", {2,2}, {0LL,1LL,2LL,3LL});
test.AddInput<int64_t>("indices", {2,1,1}, {1LL,0LL});
test.AddOutput<int64_t>("output", {2,1,2}, {2LL,3LL,0LL,1LL});
test.Run();
}
TEST(GatherNDOpTest, GatherND_sliced_index_string_int32) {
OpTester test("GatherND", 1, onnxruntime::kMSDomain);
test.AddInput<std::string>("data", {2,2}, {"ab","cde","f","ghi"});
test.AddInput<int32_t>("indices", {2,1,1}, {1LL,0LL});
test.AddOutput<std::string>("output", {2,1,2}, {"f","ghi","ab","cde"});
test.Run();
}
TEST(GatherNDOpTest, GatherND_batched_3tensor_int64) {
OpTester test1("GatherND", 1, onnxruntime::kMSDomain);
test1.AddInput<uint32_t>("data", {2,2,2}, {0,1,2,3,4,5,6,7});
test1.AddInput<int64_t>("indices", {2,2,2}, {0LL,1LL,1LL,0LL,0LL,0LL,1LL,1LL});
test1.AddOutput<uint32_t>("output", {2,2,2}, {2,3,4,5,0,1,6,7});
test1.Run();
OpTester test2("GatherND", 1, onnxruntime::kMSDomain);
test2.AddInput<uint32_t>("data", {2,2,2}, {0,1,2,3,4,5,6,7});
test2.AddInput<int32_t>("indices", {2,2,3}, {0,0,1,1,0,1,0,1,1,1,1,0});
test2.AddOutput<uint32_t>("output", {2,2}, {1,5,3,6});
test2.Run();
OpTester test3("GatherND", 1, onnxruntime::kMSDomain);
test3.AddInput<int64_t>("data", {2,2,2}, {0LL,1LL,2LL,3LL,4LL,5LL,6LL,7LL});
test3.AddInput<int32_t>("indices", {2,1,1}, {1,0});
test3.AddOutput<int64_t>("output", {2,1,2,2}, {4LL,5LL,6LL,7LL,0LL,1LL,2LL,3LL});
test3.Run();
}
} // namespace test
} // namespace onnxruntime