mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-07 00:13:17 +00:00
Support opset-11 GatherND CPU kernel (#1969)
* Initial commit * Update * Update * Update * Remove tests from exclusion * Update * Formatting * Formatting * Formatting * Update * Update * Update * Update
This commit is contained in:
parent
627f853a44
commit
74517bb742
9 changed files with 884 additions and 625 deletions
|
|
@ -1,126 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "contrib_ops/cpu/gather_nd.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
// Registers the CPU GatherND kernel for the MS (contrib) domain, opset 1.
// "T" accepts every tensor type; "Tind" (the indices) accepts int32 and int64.
ONNX_OPERATOR_KERNEL_EX(GatherND, kMSDomain, 1, kCpuExecutionProvider,
                        KernelDefBuilder()
                            .TypeConstraint("T", DataTypeImpl::AllTensorTypes())
                            .TypeConstraint("Tind", {DataTypeImpl::GetTensorType<int32_t>(),
                                                     DataTypeImpl::GetTensorType<int64_t>()}),
                        GatherND);
|
||||
|
||||
template<typename Tind>
|
||||
Status GatherNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) const {
|
||||
|
||||
auto input_tensor = context->Input<Tensor>(0);
|
||||
auto indice_tensor = context->Input<Tensor>(1);
|
||||
ORT_ENFORCE(input_tensor != nullptr);
|
||||
ORT_ENFORCE(indice_tensor != nullptr);
|
||||
|
||||
auto input_shape = input_tensor->Shape();
|
||||
auto indice_shape = indice_tensor->Shape();
|
||||
if (indice_shape.NumDimensions() == 0) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"indices tensor must has rank larger than 0");
|
||||
}
|
||||
|
||||
auto last_indice_dimension = indice_shape[indice_shape.NumDimensions() - 1];
|
||||
if (last_indice_dimension > static_cast<int64_t>(input_shape.NumDimensions())) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"last dimension of indices must not be larger than rank of input tensor");
|
||||
}
|
||||
|
||||
std::vector<int64_t> shape(indice_shape.GetDims().begin(),
|
||||
indice_shape.GetDims().end() - 1);
|
||||
shape.insert(shape.end(),
|
||||
input_shape.GetDims().begin() + last_indice_dimension,
|
||||
input_shape.GetDims().end());
|
||||
auto output_tensor = context->Output(0,TensorShape(shape));
|
||||
std::vector<int64_t> element_counts(last_indice_dimension, 0LL); // Number of elements for each input dimension
|
||||
|
||||
#ifdef USE_OPENMP
|
||||
#pragma omp parallel for
|
||||
#endif
|
||||
for (int64_t i = 0; i < last_indice_dimension; ++i) {
|
||||
element_counts[i] = input_shape.SizeFromDimension(i + 1);
|
||||
}
|
||||
|
||||
int64_t err_indice = 0;
|
||||
p.element_bytes = input_tensor->DataType()->Size();
|
||||
p.element_to_copy = input_shape.SizeFromDimension(last_indice_dimension);
|
||||
p.bytes_to_copy = p.element_bytes * p.element_to_copy;
|
||||
auto indice_offset = indice_tensor->Data<Tind>();
|
||||
auto offset_count = indice_shape.Size() / last_indice_dimension; // Times to copy
|
||||
p.element_offsets.assign(offset_count, 0LL);
|
||||
|
||||
if (input_tensor->DataType() == DataTypeImpl::GetType<std::string>()) {
|
||||
p.input_str_base = static_cast<const std::string*>(input_tensor->DataRaw());
|
||||
p.output_str_base = static_cast<std::string*>(output_tensor->MutableDataRaw());
|
||||
} else {
|
||||
p.input_base = static_cast<const uint8_t*>(input_tensor->DataRaw());
|
||||
p.output_base = static_cast<uint8_t*>(output_tensor->MutableDataRaw());
|
||||
}
|
||||
|
||||
#ifdef USE_OPENMP
|
||||
#pragma omp parallel for
|
||||
#endif
|
||||
for (int64_t i = 0; i < offset_count; ++i) {
|
||||
for (int64_t j = 0; j < last_indice_dimension; ++j) {
|
||||
auto indice = *(indice_offset + i * last_indice_dimension + j);
|
||||
if (indice < 0 || indice >= input_shape[j]) {
|
||||
err_indice = indice;
|
||||
}
|
||||
p.element_offsets[i] += indice * element_counts[j];
|
||||
}
|
||||
}
|
||||
|
||||
return err_indice == 0 ? Status::OK() :
|
||||
ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid indice found, indice = ", err_indice);
|
||||
}
|
||||
|
||||
template Status GatherNDBase::PrepareForCompute<int32_t>(OpKernelContext*, Prepare&) const;
|
||||
template Status GatherNDBase::PrepareForCompute<int64_t>(OpKernelContext*, Prepare&) const;
|
||||
|
||||
// Kernel entry point: prepares the copy plan for the indices element type in use
// (int32 or int64) and then copies either raw bytes or std::string elements.
Status GatherND::Compute(OpKernelContext* context) const {
  // Check the indices tensor before dereferencing it; PrepareForCompute repeats the
  // check, but only after the type dispatch below has already used the pointer.
  const auto* indice_tensor = context->Input<Tensor>(1);
  ORT_ENFORCE(indice_tensor != nullptr);

  Prepare p;
  const bool int32_indices = indice_tensor->DataType() == DataTypeImpl::GetType<int32_t>();
  ORT_RETURN_IF_ERROR(int32_indices ? PrepareForCompute<int32_t>(context, p)
                                    : PrepareForCompute<int64_t>(context, p));

  // input_str_base is only set when the data tensor holds strings.
  return nullptr == p.input_str_base ? GatherNumber(p) : GatherString(p);
}
|
||||
|
||||
// Copies one contiguous slice of `bytes_to_copy` bytes per prepared offset from the
// input buffer to consecutive positions of the output buffer.
Status GatherND::GatherNumber(const Prepare& p) const {
  const int64_t slice_count = static_cast<int64_t>(p.element_offsets.size());
#ifdef USE_OPENMP
#pragma omp parallel for
#endif
  for (int64_t slice = 0; slice < slice_count; ++slice) {
    // Offsets are stored in elements; scale by the element size to get bytes.
    const uint8_t* source = p.input_base + p.element_offsets[slice] * p.element_bytes;
    uint8_t* destination = p.output_base + slice * p.bytes_to_copy;
    memcpy(destination, source, p.bytes_to_copy);
  }

  return Status::OK();
}
|
||||
|
||||
// String variant of the gather: copy-assigns `element_to_copy` strings per prepared
// offset from the input to consecutive positions of the output.
Status GatherND::GatherString(const Prepare& p) const {
  const int64_t slice_count = static_cast<int64_t>(p.element_offsets.size());
  const int64_t strings_per_slice = static_cast<int64_t>(p.element_to_copy);
#ifdef USE_OPENMP
#pragma omp parallel for
#endif
  for (int64_t slice = 0; slice < slice_count; ++slice) {
    const std::string* source = p.input_str_base + p.element_offsets[slice];
    std::string* destination = p.output_str_base + slice * p.element_to_copy;
    for (int64_t k = 0; k < strings_per_slice; ++k) {
      destination[k] = source[k];
    }
  }

  return Status::OK();
}
|
||||
|
||||
}
|
||||
}
|
||||
|
|
@ -1,50 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/platform/threadpool.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
class GatherNDBase
|
||||
{
|
||||
protected:
|
||||
struct Prepare {
|
||||
const uint8_t* input_base;
|
||||
const std::string* input_str_base;
|
||||
uint8_t* output_base;
|
||||
std::string* output_str_base;
|
||||
uint64_t bytes_to_copy;
|
||||
uint64_t element_bytes;
|
||||
uint64_t element_to_copy;
|
||||
std::vector<uint64_t> element_offsets;
|
||||
|
||||
Prepare(): input_base (nullptr),
|
||||
input_str_base (nullptr),
|
||||
output_base (nullptr),
|
||||
output_str_base (nullptr),
|
||||
bytes_to_copy (0),
|
||||
element_bytes (0),
|
||||
element_to_copy (0),
|
||||
element_offsets (0) {}
|
||||
}; // struct Prepare
|
||||
|
||||
template<typename Tind>
|
||||
Status PrepareForCompute(OpKernelContext* context, Prepare& p) const;
|
||||
}; // class GatherNDBase
|
||||
|
||||
class GatherND final : public OpKernel, protected GatherNDBase {
|
||||
public:
|
||||
explicit GatherND(const OpKernelInfo& info) : OpKernel(info) {}
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
private:
|
||||
Status GatherNumber(const Prepare& p) const;
|
||||
Status GatherString(const Prepare& p) const;
|
||||
};
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
File diff suppressed because it is too large
Load diff
141
onnxruntime/core/providers/cpu/tensor/gather_nd.cc
Normal file
141
onnxruntime/core/providers/cpu/tensor/gather_nd.cc
Normal file
|
|
@ -0,0 +1,141 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "gather_nd.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
// Register a kernel for kMsDomain (contrib op) GatherND
|
||||
#ifndef DISABLE_CONTRIB_OPS
|
||||
|
||||
namespace contrib {
|
||||
// TODO: Remove this contrib kernel registration and the schema from the appropriate places
|
||||
// once Keras Mask RCNN is shipped with all ONNX domain ops
|
||||
|
||||
// Currently this kernel is required to support Keras Mask-RCNN
|
||||
// MS-domain (contrib) registration, opset 1, CPU execution provider.
ONNX_OPERATOR_KERNEL_EX(
    GatherND,
    kMSDomain,
    1,
    kCpuExecutionProvider,
    KernelDefBuilder()
        .TypeConstraint("T", DataTypeImpl::AllTensorTypes())
        // the contrib spec allows both `int32_t` and `int64_t` indices
        .TypeConstraint("Tind",
                        {DataTypeImpl::GetTensorType<int32_t>(), DataTypeImpl::GetTensorType<int64_t>()}),
    GatherND);
|
||||
|
||||
} // namespace contrib
|
||||
|
||||
#endif
|
||||
|
||||
// ONNX-domain registration of the CPU GatherND kernel, opset 11.
ONNX_CPU_OPERATOR_KERNEL(
    GatherND,
    11,
    KernelDefBuilder()
        .TypeConstraint("T", DataTypeImpl::AllTensorTypes())
        // the official ONNX spec restricts indices to `int64_t`
        .TypeConstraint("Tind", DataTypeImpl::GetTensorType<int64_t>()),
    GatherND);
|
||||
|
||||
template <typename Tind>
|
||||
Status GatherNDBase::PrepareForCompute(OpKernelContext* context, Prepare& p) const {
|
||||
const auto* input_tensor = context->Input<Tensor>(0);
|
||||
const auto* indices_tensor = context->Input<Tensor>(1);
|
||||
ORT_ENFORCE(input_tensor != nullptr && indices_tensor != nullptr, "GatherND op: Input count mismatch");
|
||||
|
||||
const auto& input_shape = input_tensor->Shape();
|
||||
const auto& indices_shape = indices_tensor->Shape();
|
||||
if (indices_shape.NumDimensions() == 0) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "indices tensor must has rank larger than 0");
|
||||
}
|
||||
|
||||
int64_t last_indices_dimension = indices_shape[indices_shape.NumDimensions() - 1];
|
||||
if (last_indices_dimension > static_cast<int64_t>(input_shape.NumDimensions())) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"last dimension of indices must not be larger than rank of input tensor");
|
||||
}
|
||||
|
||||
std::vector<int64_t> shape(indices_shape.GetDims().begin(), indices_shape.GetDims().end() - 1);
|
||||
shape.insert(shape.end(), input_shape.GetDims().begin() + last_indices_dimension, input_shape.GetDims().end());
|
||||
auto* output_tensor = context->Output(0, TensorShape(std::move(shape)));
|
||||
std::vector<int64_t> element_counts(last_indices_dimension,
|
||||
0LL); // Number of elements for each input dimension
|
||||
|
||||
#ifdef USE_OPENMP
|
||||
#pragma omp parallel for
|
||||
#endif
|
||||
for (int64_t i = 0; i < last_indices_dimension; ++i) {
|
||||
element_counts[i] = input_shape.SizeFromDimension(i + 1);
|
||||
}
|
||||
|
||||
int64_t err_index = 0;
|
||||
p.element_bytes = input_tensor->DataType()->Size();
|
||||
p.element_to_copy = input_shape.SizeFromDimension(last_indices_dimension);
|
||||
p.bytes_to_copy = p.element_bytes * p.element_to_copy;
|
||||
const auto* indices_data = indices_tensor->Data<Tind>();
|
||||
const int64_t offset_count = indices_shape.Size() / last_indices_dimension; // Times to copy
|
||||
p.element_offsets.assign(offset_count, 0LL);
|
||||
|
||||
if (input_tensor->DataType() == DataTypeImpl::GetType<std::string>()) {
|
||||
p.input_str_base = static_cast<const std::string*>(input_tensor->DataRaw());
|
||||
p.output_str_base = static_cast<std::string*>(output_tensor->MutableDataRaw());
|
||||
} else {
|
||||
p.input_base = static_cast<const uint8_t*>(input_tensor->DataRaw());
|
||||
p.output_base = static_cast<uint8_t*>(output_tensor->MutableDataRaw());
|
||||
}
|
||||
|
||||
#ifdef USE_OPENMP
|
||||
#pragma omp parallel for
|
||||
#endif
|
||||
for (int64_t i = 0; i < offset_count; ++i) {
|
||||
for (int64_t j = 0; j < last_indices_dimension; ++j) {
|
||||
auto index = *(indices_data + i * last_indices_dimension + j);
|
||||
auto upper_limit = input_shape[j];
|
||||
auto lower_limit = -upper_limit;
|
||||
if (index < lower_limit || index >= upper_limit) {
|
||||
err_index = index;
|
||||
}
|
||||
if (index < 0) {
|
||||
index += static_cast<Tind>(upper_limit);
|
||||
}
|
||||
p.element_offsets[i] += index * element_counts[j];
|
||||
}
|
||||
}
|
||||
|
||||
return err_index == 0 ? Status::OK()
|
||||
: ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid index found, index = ", err_index);
|
||||
}
|
||||
|
||||
template Status GatherNDBase::PrepareForCompute<int32_t>(OpKernelContext*, Prepare&) const;
|
||||
template Status GatherNDBase::PrepareForCompute<int64_t>(OpKernelContext*, Prepare&) const;
|
||||
|
||||
// Kernel entry point: prepares the copy plan for the indices element type in use
// (int32 or int64) and then copies either raw bytes or std::string elements.
Status GatherND::Compute(OpKernelContext* context) const {
  // Check the indices tensor before dereferencing it; PrepareForCompute repeats the
  // check, but only after the type dispatch below has already used the pointer.
  const auto* indices_tensor = context->Input<Tensor>(1);
  ORT_ENFORCE(indices_tensor != nullptr, "GatherND op: Input count mismatch");

  Prepare p;
  const bool int32_indices = indices_tensor->DataType() == DataTypeImpl::GetType<int32_t>();
  ORT_RETURN_IF_ERROR(int32_indices ? PrepareForCompute<int32_t>(context, p)
                                    : PrepareForCompute<int64_t>(context, p));

  // input_str_base is only set when the data tensor holds strings.
  return nullptr == p.input_str_base ? GatherNumber(p) : GatherString(p);
}
|
||||
|
||||
// Copies one contiguous slice of `bytes_to_copy` bytes per prepared offset from the
// input buffer to consecutive positions of the output buffer.
Status GatherND::GatherNumber(const Prepare& p) const {
  const int64_t slice_count = static_cast<int64_t>(p.element_offsets.size());
#ifdef USE_OPENMP
#pragma omp parallel for
#endif
  for (int64_t slice = 0; slice < slice_count; ++slice) {
    // Offsets are stored in elements; scale by the element size to get bytes.
    const uint8_t* source = p.input_base + p.element_offsets[slice] * p.element_bytes;
    uint8_t* destination = p.output_base + slice * p.bytes_to_copy;
    memcpy(destination, source, p.bytes_to_copy);
  }

  return Status::OK();
}
|
||||
|
||||
// String variant of the gather: copy-assigns `element_to_copy` strings per prepared
// offset from the input to consecutive positions of the output.
Status GatherND::GatherString(const Prepare& p) const {
  const int64_t slice_count = static_cast<int64_t>(p.element_offsets.size());
  const int64_t strings_per_slice = static_cast<int64_t>(p.element_to_copy);
#ifdef USE_OPENMP
#pragma omp parallel for
#endif
  for (int64_t slice = 0; slice < slice_count; ++slice) {
    const std::string* source = p.input_str_base + p.element_offsets[slice];
    std::string* destination = p.output_str_base + slice * p.element_to_copy;
    for (int64_t k = 0; k < strings_per_slice; ++k) {
      destination[k] = source[k];
    }
  }

  return Status::OK();
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
48
onnxruntime/core/providers/cpu/tensor/gather_nd.h
Normal file
48
onnxruntime/core/providers/cpu/tensor/gather_nd.h
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/platform/threadpool.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
class GatherNDBase {
|
||||
protected:
|
||||
struct Prepare {
|
||||
const uint8_t* input_base;
|
||||
const std::string* input_str_base;
|
||||
uint8_t* output_base;
|
||||
std::string* output_str_base;
|
||||
uint64_t bytes_to_copy;
|
||||
uint64_t element_bytes;
|
||||
uint64_t element_to_copy;
|
||||
std::vector<uint64_t> element_offsets;
|
||||
|
||||
Prepare() : input_base(nullptr),
|
||||
input_str_base(nullptr),
|
||||
output_base(nullptr),
|
||||
output_str_base(nullptr),
|
||||
bytes_to_copy(0),
|
||||
element_bytes(0),
|
||||
element_to_copy(0),
|
||||
element_offsets(0) {}
|
||||
}; // struct Prepare
|
||||
|
||||
template <typename Tind>
|
||||
Status PrepareForCompute(OpKernelContext* context, Prepare& p) const;
|
||||
}; // class GatherNDBase
|
||||
|
||||
class GatherND final : public OpKernel, protected GatherNDBase {
|
||||
public:
|
||||
explicit GatherND(const OpKernelInfo& info) : OpKernel(info) {}
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
|
||||
private:
|
||||
Status GatherNumber(const Prepare& p) const;
|
||||
Status GatherString(const Prepare& p) const;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -1,157 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
// Full-rank int32 index tuples into string tensors: each case yields a scalar output.
TEST(GatherNDOpTest, GatherND_scaler_string_int32) {
  OpTester matrix_2x2("GatherND", 1, onnxruntime::kMSDomain);
  matrix_2x2.AddInput<std::string>("data", {2, 2}, {"h", "k", "o", "z"});
  matrix_2x2.AddInput<int32_t>("indices", {2}, {0, 1});
  matrix_2x2.AddOutput<std::string>("output", {}, {"k"});
  matrix_2x2.Run();

  OpTester vector_6("GatherND", 1, onnxruntime::kMSDomain);
  vector_6.AddInput<std::string>("data", {6}, {"h", "k", "o", "z", "l", "t"});
  vector_6.AddInput<int32_t>("indices", {1}, {3});
  vector_6.AddOutput<std::string>("output", {}, {"z"});
  vector_6.Run();

  OpTester matrix_3x2("GatherND", 1, onnxruntime::kMSDomain);
  matrix_3x2.AddInput<std::string>("data", {3, 2}, {"h", "k", "o", "z", "l", "t"});
  matrix_3x2.AddInput<int32_t>("indices", {2}, {2, 1});
  matrix_3x2.AddOutput<std::string>("output", {}, {"t"});
  matrix_3x2.Run();
}
|
||||
|
||||
// int64 data, int64 indices: two full-rank index tuples pick the diagonal of a 2x2 matrix.
TEST(GatherNDOpTest, GatherND_matrice_int64_int64) {
  OpTester tester("GatherND", 1, onnxruntime::kMSDomain);
  tester.AddInput<int64_t>("data", {2, 2}, {0, 1, 2, 3});
  tester.AddInput<int64_t>("indices", {2, 2}, {0, 0, 1, 1});
  tester.AddOutput<int64_t>("output", {2}, {0, 3});
  tester.Run();
}
|
||||
|
||||
// String data, int64 indices: two full-rank index tuples pick the diagonal of a 2x2 matrix.
TEST(GatherNDOpTest, GatherND_matrice_string_int64) {
  OpTester tester("GatherND", 1, onnxruntime::kMSDomain);
  tester.AddInput<std::string>("data", {2, 2}, {"a", "b", "c", "d"});
  tester.AddInput<int64_t>("indices", {2, 2}, {0, 0, 1, 1});
  tester.AddOutput<std::string>("output", {2}, {"a", "d"});
  tester.Run();
}
|
||||
|
||||
// int64 data, int32 indices: two full-rank index tuples pick the diagonal of a 2x2 matrix.
TEST(GatherNDOpTest, GatherND_matrice_int64_int32) {
  OpTester tester("GatherND", 1, onnxruntime::kMSDomain);
  tester.AddInput<int64_t>("data", {2, 2}, {0, 1, 2, 3});
  tester.AddInput<int32_t>("indices", {2, 2}, {0, 0, 1, 1});
  tester.AddOutput<int64_t>("output", {2}, {0, 3});
  tester.Run();
}
|
||||
|
||||
// String data with int32 indices: a partial-rank gather on a 2x2x2 tensor and a
// full-rank gather on a 3x3 matrix.
TEST(GatherNDOpTest, GatherND_matrice_string_int32) {
  OpTester cube("GatherND", 1, onnxruntime::kMSDomain);
  cube.AddInput<std::string>("data", {2, 2, 2},
                             {"egg", "dance", "air", "bob", "terry", "smart", "laugh", "kite"});
  cube.AddInput<int32_t>("indices", {2, 1, 2}, {0, 1, 1, 0});
  cube.AddOutput<std::string>("output", {2, 1, 2}, {"air", "bob", "terry", "smart"});
  cube.Run();

  OpTester matrix("GatherND", 1, onnxruntime::kMSDomain);
  matrix.AddInput<std::string>("data", {3, 3},
                               {"egg", "dance", "air", "bob", "terry", "smart", "laugh", "kite", "hop"});
  matrix.AddInput<int32_t>("indices", {3, 2}, {2, 1, 1, 0, 0, 1});
  matrix.AddOutput<std::string>("output", {3}, {"kite", "bob", "dance"});
  matrix.Run();
}
|
||||
|
||||
// float data, int64 indices of length 1: whole rows of a 2x2 matrix are gathered.
TEST(GatherNDOpTest, GatherND_slice_float_int64_t) {
  OpTester tester("GatherND", 1, onnxruntime::kMSDomain);
  tester.AddInput<float>("data", {2, 2}, {0.0f, 0.1f, 0.2f, 0.3f});
  tester.AddInput<int64_t>("indices", {2, 1}, {1, 0});
  tester.AddOutput<float>("output", {2, 2}, {0.2f, 0.3f, 0.0f, 0.1f});
  tester.Run();
}
|
||||
|
||||
// double data, int32 indices of length 1: whole rows of a 2x2 matrix are gathered.
// Literals match the declared element types: double literals for the double tensors
// (previously float literals) and plain int literals for the int32 indices (previously `LL`).
TEST(GatherNDOpTest, GatherND_slice_double_int32_t) {
  OpTester test("GatherND", 1, onnxruntime::kMSDomain);
  test.AddInput<double>("data", {2, 2}, {0.0, 0.1, 0.2, 0.3});
  test.AddInput<int32_t>("indices", {2, 1}, {1, 0});
  test.AddOutput<double>("output", {2, 2}, {0.2, 0.3, 0.0, 0.1});
  test.Run();
}
|
||||
|
||||
// Rank-3 data gathered with index tuples of length 2, 3 and 1 respectively.
TEST(GatherNDOpTest, GatherND_3tensor_int64) {
  OpTester tuple_len2("GatherND", 1, onnxruntime::kMSDomain);
  tuple_len2.AddInput<int64_t>("data", {2, 2, 2}, {0, 1, 2, 3, 4, 5, 6, 7});
  tuple_len2.AddInput<int64_t>("indices", {2, 2}, {0, 1, 1, 0});
  tuple_len2.AddOutput<int64_t>("output", {2, 2}, {2, 3, 4, 5});
  tuple_len2.Run();

  OpTester tuple_len3("GatherND", 1, onnxruntime::kMSDomain);
  tuple_len3.AddInput<int8_t>("data", {2, 2, 2}, {0, 1, 2, 3, 4, 5, 6, 7});
  tuple_len3.AddInput<int32_t>("indices", {2, 3}, {0, 0, 1, 1, 0, 1});
  tuple_len3.AddOutput<int8_t>("output", {2}, {1, 5});
  tuple_len3.Run();

  OpTester tuple_len1("GatherND", 1, onnxruntime::kMSDomain);
  tuple_len1.AddInput<int16_t>("data", {2, 2, 2}, {0, 1, 2, 3, 4, 5, 6, 7});
  tuple_len1.AddInput<int64_t>("indices", {1, 1}, {1});
  tuple_len1.AddOutput<int16_t>("output", {1, 2, 2}, {4, 5, 6, 7});
  tuple_len1.Run();
}
|
||||
|
||||
// Rank-3 indices of full-rank tuples: the leading index dimensions shape the output.
TEST(GatherNDOpTest, GatherND_batched_index_int64) {
  OpTester tester("GatherND", 1, onnxruntime::kMSDomain);
  tester.AddInput<int64_t>("data", {2, 2}, {0, 1, 2, 3});
  tester.AddInput<int64_t>("indices", {2, 1, 2}, {0, 0, 0, 1});
  tester.AddOutput<int64_t>("output", {2, 1}, {0, 1});
  tester.Run();
}
|
||||
|
||||
// Same index layout as the batched int64 case, with bool data.
TEST(GatherNDOpTest, GatherND_batched_index_bool_int64) {
  OpTester tester("GatherND", 1, onnxruntime::kMSDomain);
  tester.AddInput<bool>("data", {2, 2}, {true, false, false, true});
  tester.AddInput<int64_t>("indices", {2, 1, 2}, {0, 0, 0, 1});
  tester.AddOutput<bool>("output", {2, 1}, {true, false});
  tester.Run();
}
|
||||
|
||||
// Rank-3 indices with tuples of length 1: each tuple gathers a full row.
TEST(GatherNDOpTest, GatherND_sliced_index_int64) {
  OpTester tester("GatherND", 1, onnxruntime::kMSDomain);
  tester.AddInput<int64_t>("data", {2, 2}, {0, 1, 2, 3});
  tester.AddInput<int64_t>("indices", {2, 1, 1}, {1, 0});
  tester.AddOutput<int64_t>("output", {2, 1, 2}, {2, 3, 0, 1});
  tester.Run();
}
|
||||
|
||||
// String rows gathered with rank-3 int32 indices whose tuples have length 1.
TEST(GatherNDOpTest, GatherND_sliced_index_string_int32) {
  OpTester tester("GatherND", 1, onnxruntime::kMSDomain);
  tester.AddInput<std::string>("data", {2, 2}, {"ab", "cde", "f", "ghi"});
  tester.AddInput<int32_t>("indices", {2, 1, 1}, {1, 0});
  tester.AddOutput<std::string>("output", {2, 1, 2}, {"f", "ghi", "ab", "cde"});
  tester.Run();
}
|
||||
|
||||
// Rank-3 data with rank-3 indices, using index tuples of length 2, 3 and 1 respectively.
TEST(GatherNDOpTest, GatherND_batched_3tensor_int64) {
  OpTester tuple_len2("GatherND", 1, onnxruntime::kMSDomain);
  tuple_len2.AddInput<uint32_t>("data", {2, 2, 2}, {0, 1, 2, 3, 4, 5, 6, 7});
  tuple_len2.AddInput<int64_t>("indices", {2, 2, 2}, {0, 1, 1, 0, 0, 0, 1, 1});
  tuple_len2.AddOutput<uint32_t>("output", {2, 2, 2}, {2, 3, 4, 5, 0, 1, 6, 7});
  tuple_len2.Run();

  OpTester tuple_len3("GatherND", 1, onnxruntime::kMSDomain);
  tuple_len3.AddInput<uint32_t>("data", {2, 2, 2}, {0, 1, 2, 3, 4, 5, 6, 7});
  tuple_len3.AddInput<int32_t>("indices", {2, 2, 3}, {0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0});
  tuple_len3.AddOutput<uint32_t>("output", {2, 2}, {1, 5, 3, 6});
  tuple_len3.Run();

  OpTester tuple_len1("GatherND", 1, onnxruntime::kMSDomain);
  tuple_len1.AddInput<int64_t>("data", {2, 2, 2}, {0, 1, 2, 3, 4, 5, 6, 7});
  tuple_len1.AddInput<int32_t>("indices", {2, 1, 1}, {1, 0});
  tuple_len1.AddOutput<int64_t>("output", {2, 1, 2, 2}, {4, 5, 6, 7, 0, 1, 2, 3});
  tuple_len1.Run();
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -406,8 +406,6 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
|
|||
{"range_int32_type_negative_delta_expanded", "not implemented yet"},
|
||||
{"det_2d", "not implemented yet"},
|
||||
{"det_nd", "not implemented yet"},
|
||||
{"gathernd_example_float32", "not implemented yet"},
|
||||
{"gathernd_example_int32", "not implemented yet"},
|
||||
{"resize_downsample_scales_cubic_A_n0p5_exclude_outside", "not implemented yet"},
|
||||
{"resize_downsample_scales_cubic_align_corners", "not implemented yet"},
|
||||
{"resize_downsample_scales_cubic", "not implemented yet"},
|
||||
|
|
|
|||
118
onnxruntime/test/providers/cpu/tensor/gather_nd_op_test.cc
Normal file
118
onnxruntime/test/providers/cpu/tensor/gather_nd_op_test.cc
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
template <typename T>
|
||||
static void RunTest(const std::vector<int64_t>& input_dims, const std::initializer_list<T>& input,
|
||||
const std::vector<int64_t>& indices_dims, const std::initializer_list<int64_t>& indices,
|
||||
const std::vector<int64_t>& output_dims, const std::initializer_list<T>& output) {
|
||||
// ONNX domain opset-11
|
||||
OpTester test1("GatherND", 11);
|
||||
test1.AddInput<T>("data", input_dims, input);
|
||||
test1.AddInput<int64_t>("indices", indices_dims, indices);
|
||||
test1.AddOutput<T>("output", output_dims, output);
|
||||
test1.Run();
|
||||
|
||||
#ifndef DISABLE_CONTRIB_OPS
|
||||
|
||||
// MSFT domain opset-1 (contrib op)
|
||||
OpTester test2("GatherND", 1, kMSDomain);
|
||||
test2.AddInput<T>("data", input_dims, input);
|
||||
test2.AddInput<int64_t>("indices", indices_dims, indices);
|
||||
test2.AddOutput<T>("output", output_dims, output);
|
||||
test2.Run();
|
||||
|
||||
#endif
|
||||
}
|
||||
|
||||
// String data: scalar results, element gathers, row gathers and negative indices.
TEST(GatherNDOpTest, string) {
  // full-rank index tuple -> scalar output
  RunTest<std::string>({2, 2}, {"h", "k", "o", "z"}, {2}, {0, 1}, {}, {"k"});
  RunTest<std::string>({6}, {"h", "k", "o", "z", "l", "t"}, {1}, {3}, {}, {"z"});
  RunTest<std::string>({3, 2}, {"h", "k", "o", "z", "l", "t"}, {2}, {2, 1}, {}, {"t"});

  // several full-rank index tuples -> one element each
  RunTest<std::string>({2, 2}, {"a", "b", "c", "d"}, {2, 2}, {0, 0, 1, 1}, {2}, {"a", "d"});
  RunTest<std::string>({3, 3}, {"egg", "dance", "air", "bob", "terry", "smart", "laugh", "kite", "hop"}, {3, 2},
                       {2, 1, 1, 0, 0, 1}, {3}, {"kite", "bob", "dance"});

  // partial-rank index tuples -> trailing slices
  RunTest<std::string>({2, 2, 2}, {"egg", "dance", "air", "bob", "terry", "smart", "laugh", "kite"}, {2, 1, 2},
                       {0, 1, 1, 0}, {2, 1, 2}, {"air", "bob", "terry", "smart"});
  RunTest<std::string>({2, 2}, {"ab", "cde", "f", "ghi"}, {2, 1, 1}, {1, 0}, {2, 1, 2}, {"f", "ghi", "ab", "cde"});

  // with negative indices
  RunTest<std::string>({2, 2}, {"ab", "cde", "f", "ghi"}, {2, 1, 1}, {-1, 0}, {2, 1, 2}, {"f", "ghi", "ab", "cde"});
}
|
||||
|
||||
// int64 data: element gathers, slice gathers, batched indices and negative indices.
TEST(GatherNDOpTest, int64_t) {
  // full-rank index tuples on a 2x2 matrix
  RunTest<int64_t>({2, 2}, {0, 1, 2, 3}, {2, 2}, {0, 0, 1, 1}, {2}, {0, 3});

  // index tuples of length 2 on a 2x2x2 tensor -> rows
  RunTest<int64_t>({2, 2, 2}, {0, 1, 2, 3, 4, 5, 6, 7}, {2, 2}, {0, 1, 1, 0}, {2, 2}, {2, 3, 4, 5});

  // rank-3 indices
  RunTest<int64_t>({2, 2}, {0, 1, 2, 3}, {2, 1, 2}, {0, 0, 0, 1}, {2, 1}, {0, 1});
  RunTest<int64_t>({2, 2}, {0, 1, 2, 3}, {2, 1, 1}, {1, 0}, {2, 1, 2}, {2, 3, 0, 1});
  RunTest<int64_t>({2, 2, 2}, {0, 1, 2, 3, 4, 5, 6, 7}, {2, 1, 1}, {1, 0}, {2, 1, 2, 2},
                   {4, 5, 6, 7, 0, 1, 2, 3});

  // with negative indices
  RunTest<int64_t>({2, 2, 2}, {0, 1, 2, 3, 4, 5, 6, 7}, {2, 1, 1}, {-1, 0}, {2, 1, 2, 2},
                   {4, 5, 6, 7, 0, 1, 2, 3});
}
|
||||
|
||||
// float data: index tuples of length 1 gather whole rows of a 2x2 matrix.
TEST(GatherNDOpTest, float) {
  RunTest<float>({2, 2}, {0.0f, 0.1f, 0.2f, 0.3f},
                 {2, 1}, {1, 0},
                 {2, 2}, {0.2f, 0.3f, 0.0f, 0.1f});
}
|
||||
|
||||
// double data: index tuples of length 1 gather whole rows of a 2x2 matrix.
TEST(GatherNDOpTest, double) {
  RunTest<double>({2, 2}, {0.0, 0.1, 0.2, 0.3},
                  {2, 1}, {1, 0},
                  {2, 2}, {0.2, 0.3, 0.0, 0.1});
}
|
||||
|
||||
// int8 data: full-rank index tuples on a 2x2x2 tensor select single elements.
TEST(GatherNDOpTest, int8_t) {
  RunTest<int8_t>({2, 2, 2}, {0, 1, 2, 3, 4, 5, 6, 7},
                  {2, 3}, {0, 0, 1, 1, 0, 1},
                  {2}, {1, 5});
}
|
||||
|
||||
// int16 data: a single index tuple of length 1 gathers a 2x2 sub-tensor.
TEST(GatherNDOpTest, int16_t) {
  RunTest<int16_t>({2, 2, 2}, {0, 1, 2, 3, 4, 5, 6, 7},
                   {1, 1}, {1},
                   {1, 2, 2}, {4, 5, 6, 7});
}
|
||||
|
||||
// uint32 data with rank-3 indices.
TEST(GatherNDOpTest, uint32_t) {
  // index tuples of length 2 -> rows
  RunTest<uint32_t>({2, 2, 2}, {0, 1, 2, 3, 4, 5, 6, 7},
                    {2, 2, 2}, {0, 1, 1, 0, 0, 0, 1, 1},
                    {2, 2, 2}, {2, 3, 4, 5, 0, 1, 6, 7});

  // index tuples of length 3 -> single elements
  RunTest<uint32_t>({2, 2, 2}, {0, 1, 2, 3, 4, 5, 6, 7},
                    {2, 2, 3}, {0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0},
                    {2, 2}, {1, 5, 3, 6});
}
|
||||
|
||||
// bool data: rank-3 indices of full-rank tuples on a 2x2 matrix.
TEST(GatherNDOpTest, bool) {
  RunTest<bool>({2, 2}, {true, false, false, true},
                {2, 1, 2}, {0, 0, 0, 1},
                {2, 1}, {true, false});
}
|
||||
|
||||
#ifndef DISABLE_CONTRIB_OPS
|
||||
|
||||
// The contrib spec of GatherND supports `int64` AND `int32` type for `indices`
|
||||
// The official spec only supports `int64`
|
||||
// This test covers `int32` indices just for the contrib kernel
|
||||
|
||||
// Exercises the MS-domain kernel with int32 indices.
TEST(GatherNDOpTest, ContribOpInt32Indices) {
  // MSFT domain opset-1 (contrib op)
  OpTester contrib_tester("GatherND", 1, kMSDomain);
  contrib_tester.AddInput<int64_t>("data", {2, 2, 2}, {0, 1, 2, 3, 4, 5, 6, 7});
  contrib_tester.AddInput<int32_t>("indices", {2, 3}, {0, 0, 1, 1, 0, 1});
  contrib_tester.AddOutput<int64_t>("output", {2}, {1, 5});
  contrib_tester.Run();
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -124,8 +124,6 @@ def create_backend_test(testname=None):
|
|||
'^test_range_int32_type_negative_delta_expanded_cpu',
|
||||
'^test_det_2d_cpu',
|
||||
'^test_det_nd_cpu',
|
||||
'^test_gathernd_example_float32_cpu',
|
||||
'^test_gathernd_example_int32_cpu',
|
||||
'^test_resize_downsample_scales_cubic_A_n0p5_exclude_outside_cpu',
|
||||
'^test_resize_downsample_scales_cubic_align_corners_cpu',
|
||||
'^test_resize_downsample_scales_cubic_cpu',
|
||||
|
|
|
|||
Loading…
Reference in a new issue