Add ScatterElements CPU kernel (#1796)

* Support ScatterElements CPU kernel

* Nits and remove test exclusion

* PR feedback

* Fix bug

* Fix test

* Remove unused variable

* PR comments
This commit is contained in:
Hariharan Seshadri 2019-09-24 17:11:01 -07:00 committed by GitHub
parent 034aa80167
commit d3cb2a5572
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 114 additions and 34 deletions

View file

@ -251,7 +251,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Asi
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Acosh);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Atanh);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Scan);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Scatter);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, Scatter);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, string, TfIdfVectorizer);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, TfIdfVectorizer);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t, TfIdfVectorizer);
@ -317,6 +317,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, So
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Loop);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, DepthToSpace);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Det);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, ScatterElements);
void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
static const BuildKernelCreateInfoFn function_table[] = {
@ -553,7 +554,7 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Acosh)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Atanh)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Scan)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Scatter)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, Scatter)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, string, TfIdfVectorizer)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, TfIdfVectorizer)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t, TfIdfVectorizer)>,
@ -619,6 +620,7 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Loop)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, DepthToSpace)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Det)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, ScatterElements)>,
};
for (auto& function_table_entry : function_table) {

View file

@ -10,10 +10,11 @@ namespace onnxruntime {
class Scatter final : public OpKernel {
public:
Scatter(const OpKernelInfo& info) : OpKernel(info) {
explicit Scatter(const OpKernelInfo& info) : OpKernel(info) {
ORT_ENFORCE(info.GetAttr<int64_t>("axis", &axis_).IsOK(),
"Missing/Invalid 'axis' attribute value");
}
~Scatter() = default;
Status Compute(OpKernelContext* context) const override;
@ -21,9 +22,18 @@ class Scatter final : public OpKernel {
int64_t axis_;
};
ONNX_CPU_OPERATOR_KERNEL(
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
Scatter,
9,
9, 10,
KernelDefBuilder()
.MayInplace(0, 0)
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
.TypeConstraint("Tind", std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(), DataTypeImpl::GetTensorType<int64_t>()}),
Scatter);
ONNX_CPU_OPERATOR_KERNEL(
ScatterElements,
11,
KernelDefBuilder()
.MayInplace(0, 0)
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
@ -34,14 +44,25 @@ template <class Tin, class Tdata>
Status CopyScatterData(const Tensor* data_input, const Tensor* indices_input, const Tensor* updates_input,
const int64_t axis, Tensor* data_output) {
const TensorShape& input_data_shape = data_input->Shape();
const Tin* indices_data = indices_input->template Data<Tin>();
const Tin* indices_data_raw= indices_input->template Data<Tin>();
const auto num_indices = indices_input->Shape().Size();
std::vector<Tin> indices_data;
indices_data.reserve(num_indices);
auto axis_dim_limit = input_data_shape[axis];
for (int64_t i = 0; i < num_indices; ++i) {
Tin idx = indices_data[i];
if (idx < 0 || idx >= input_data_shape[axis]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "indices element out of data bounds, idx=", idx,
" data_dim=", input_data_shape[axis]);
Tin idx = indices_data_raw[i];
if (idx < -axis_dim_limit || idx >= axis_dim_limit) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"indices element out of data bounds, idx=", idx,
" must be within the inclusive range [", -axis_dim_limit,
",", axis_dim_limit - 1, "]");
}
indices_data.push_back(idx < 0 ? idx + static_cast<Tin>(axis_dim_limit) : idx);
}
const auto input_elements = input_data_shape.Size();

View file

@ -406,8 +406,6 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
{"unique_sorted_axis_3d", "Unique not implemented yet"},
{"unique_sorted_axis", "Unique not implemented yet"},
{"unique_sorted_with_negative_axis", "Unique not implemented yet"},
{"scatter_elements_with_axis", "not implemented yet"},
{"scatter_elements_without_axis", "not implemented yet"},
{"round", "not implemented yet"},
{"gather_elements_1", "not implemented yet"},
{"gather_elements_0", "not implemented yet"},

View file

@ -7,10 +7,8 @@
namespace onnxruntime {
namespace test {
const int Scatter_ver = 9;
TEST(ScatterOpTest, WithoutAxis) {
OpTester test("Scatter", Scatter_ver);
static void scatter_without_axis_tests(const char* op_name, int op_version) {
OpTester test(op_name, op_version);
std::vector<float> input;
input.resize(3 * 3);
@ -32,8 +30,13 @@ TEST(ScatterOpTest, WithoutAxis) {
test.Run();
}
TEST(ScatterOpTest, WithAxis) {
OpTester test("Scatter", Scatter_ver);
TEST(Scatter, WithoutAxis) {
scatter_without_axis_tests("Scatter", 9);
scatter_without_axis_tests("ScatterElements", 11);
}
static void scatter_with_axis_tests(const char* op_name, int op_version) {
OpTester test(op_name, op_version);
test.AddAttribute<int64_t>("axis", 1);
test.AddInput<float>("data", {1, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f});
@ -43,8 +46,13 @@ TEST(ScatterOpTest, WithAxis) {
test.Run();
}
TEST(ScatterOpTest, WithAxisThreeDims) {
OpTester test("Scatter", Scatter_ver);
TEST(Scatter, WithAxis) {
scatter_with_axis_tests("Scatter", 9);
scatter_with_axis_tests("ScatterElements", 11);
}
static void scatter_three_dim_with_axis_0(const char* op_name, int op_version) {
OpTester test(op_name, op_version);
test.AddAttribute<int64_t>("axis", 0);
test.AddInput<int64_t>("data", {1, 3, 3},
@ -67,8 +75,13 @@ TEST(ScatterOpTest, WithAxisThreeDims) {
test.Run();
}
TEST(ScatterOpTest, ThreeDimsWithAxisGE_1) {
OpTester test("Scatter", Scatter_ver);
TEST(Scatter, ThreeDimsWithAxis_0) {
scatter_three_dim_with_axis_0("Scatter", 9);
scatter_three_dim_with_axis_0("ScatterElements", 11);
}
static void scatter_three_dim_with_axis_2(const char* op_name, int op_version) {
OpTester test(op_name, op_version);
test.AddAttribute<int64_t>("axis", 2);
test.AddInput<int64_t>("data", {1, 3, 3},
@ -92,8 +105,13 @@ TEST(ScatterOpTest, ThreeDimsWithAxisGE_1) {
test.Run();
}
TEST(ScatterOpTest, WithAxisStrings) {
OpTester test("Scatter", Scatter_ver);
TEST(Scatter, ThreeDimsWithAxis_2) {
scatter_three_dim_with_axis_2("Scatter", 9);
scatter_three_dim_with_axis_2("ScatterElements", 11);
}
static void scatter_string(const char* op_name, int op_version) {
OpTester test(op_name, op_version);
test.AddAttribute<int64_t>("axis", 1);
test.AddInput<std::string>("data", {1, 5}, {"1.0f", "2.0f", "3.0f", "4.0f", "5.0f"});
@ -103,8 +121,13 @@ TEST(ScatterOpTest, WithAxisStrings) {
test.Run();
}
TEST(ScatterOpTest, NegativeAxis) {
OpTester test("Scatter", Scatter_ver);
TEST(Scatter, String) {
scatter_string("Scatter", 9);
scatter_string("ScatterElements", 11);
}
static void scatter_negative_axis(const char* op_name, int op_version) {
OpTester test(op_name, op_version);
test.AddAttribute<int64_t>("axis", -1);
test.AddInput<float>("data", {1, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f});
@ -114,8 +137,13 @@ TEST(ScatterOpTest, NegativeAxis) {
test.Run();
}
TEST(ScatterOpTest, IndicesUpdatesDimsDonotMatch) {
OpTester test("Scatter", Scatter_ver);
TEST(Scatter, NegativeAxis) {
scatter_negative_axis("Scatter", 9);
scatter_negative_axis("ScatterElements", 11);
}
static void scatter_indices_updates_dont_match(const char* op_name, int op_version) {
OpTester test(op_name, op_version);
test.AddAttribute<int64_t>("axis", 1);
test.AddInput<float>("data", {1, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f});
@ -125,8 +153,13 @@ TEST(ScatterOpTest, IndicesUpdatesDimsDonotMatch) {
test.Run(OpTester::ExpectResult::kExpectFailure, "Indices vs updates dimensions differs at position=1 3 vs 2");
}
TEST(ScatterOpTest, ValidIndex) {
OpTester test("Scatter", Scatter_ver);
TEST(Scatter, IndicesUpdatesDontMatch) {
scatter_indices_updates_dont_match("Scatter", 9);
scatter_indices_updates_dont_match("ScatterElements", 11);
}
static void scatter_valid_index(const char* op_name, int op_version) {
OpTester test(op_name, op_version);
test.AddAttribute<int64_t>("axis", 0);
test.AddInput<float>("data", {4, 2, 1}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f});
@ -136,15 +169,42 @@ TEST(ScatterOpTest, ValidIndex) {
test.Run();
}
TEST(ScatterOpTest, InvalidIndex) {
OpTester test("Scatter", Scatter_ver);
TEST(Scatter, ValidAxis) {
scatter_valid_index("Scatter", 9);
scatter_valid_index("ScatterElements", 11);
}
static void scatter_invalid_index(const char* op_name, int op_version) {
OpTester test(op_name, op_version);
test.AddAttribute<int64_t>("axis", 0);
test.AddInput<float>("data", {4, 2, 1}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f});
test.AddInput<int64_t>("indices", {1, 1, 1}, {4});
test.AddInput<float>("updates", {1, 1, 1}, {5.0f});
test.AddOutput<float>("y", {4, 2, 1}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 5.0f, 0.0f});
test.Run(OpTester::ExpectResult::kExpectFailure, "indices element out of data bounds, idx=4 data_dim=4");
test.Run(OpTester::ExpectResult::kExpectFailure, "indices element out of data bounds, idx=4 must be within the inclusive range [-4,3]");
}
TEST(Scatter, InvalidIndex) {
scatter_invalid_index("Scatter", 9);
scatter_invalid_index("ScatterElements", 11);
}
static void scatter_valid_negative_index(const char* op_name, int op_version) {
OpTester test(op_name, op_version);
test.AddAttribute<int64_t>("axis", 0);
test.AddInput<float>("data", {4, 2, 1}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f});
test.AddInput<int64_t>("indices", {1, 1, 1}, {-1});
test.AddInput<float>("updates", {1, 1, 1}, {5.0f});
test.AddOutput<float>("y", {4, 2, 1}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 5.0f, 0.0f});
test.Run();
}
TEST(Scatter, ValidNegativeIndex) {
scatter_valid_negative_index("Scatter", 9);
scatter_valid_negative_index("ScatterElements", 11);
}
} // namespace test
} // namespace onnxruntime

View file

@ -115,7 +115,6 @@ def create_backend_test(testname=None):
'^test_dynamicquantizelinear_max_adjusted_expanded*',
'^test_dynamicquantizelinear_min_adjusted_expanded*',
'^test_gather_elements*',
'^test_scatter_elements*',
'^test_top_k*',
'^test_unique_*',
'^test_mod_float_mixed_sign_example_cpu.*', #onnxruntime::Mod::Compute fmod_ was false. fmod attribute must be true for float, float16 and double types