mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Add ScatterElements CPU kernel (#1796)
* Support ScatterElements CPU kernel * Nits and remove test exclusion * PR feedback * Fix bug * Fix test * Remove unused variable * PR comments
This commit is contained in:
parent
034aa80167
commit
d3cb2a5572
5 changed files with 114 additions and 34 deletions
|
|
@ -251,7 +251,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Asi
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Acosh);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Atanh);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Scan);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Scatter);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, Scatter);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, string, TfIdfVectorizer);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, TfIdfVectorizer);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t, TfIdfVectorizer);
|
||||
|
|
@ -317,6 +317,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, So
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Loop);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, DepthToSpace);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Det);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, ScatterElements);
|
||||
|
||||
void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
||||
static const BuildKernelCreateInfoFn function_table[] = {
|
||||
|
|
@ -553,7 +554,7 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Acosh)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Atanh)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Scan)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Scatter)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, Scatter)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, string, TfIdfVectorizer)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, TfIdfVectorizer)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t, TfIdfVectorizer)>,
|
||||
|
|
@ -619,6 +620,7 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Loop)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, DepthToSpace)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Det)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, ScatterElements)>,
|
||||
};
|
||||
|
||||
for (auto& function_table_entry : function_table) {
|
||||
|
|
|
|||
|
|
@ -10,10 +10,11 @@ namespace onnxruntime {
|
|||
|
||||
class Scatter final : public OpKernel {
|
||||
public:
|
||||
Scatter(const OpKernelInfo& info) : OpKernel(info) {
|
||||
explicit Scatter(const OpKernelInfo& info) : OpKernel(info) {
|
||||
ORT_ENFORCE(info.GetAttr<int64_t>("axis", &axis_).IsOK(),
|
||||
"Missing/Invalid 'axis' attribute value");
|
||||
}
|
||||
|
||||
~Scatter() = default;
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
|
||||
|
|
@ -21,9 +22,18 @@ class Scatter final : public OpKernel {
|
|||
int64_t axis_;
|
||||
};
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
|
||||
Scatter,
|
||||
9,
|
||||
9, 10,
|
||||
KernelDefBuilder()
|
||||
.MayInplace(0, 0)
|
||||
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
|
||||
.TypeConstraint("Tind", std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(), DataTypeImpl::GetTensorType<int64_t>()}),
|
||||
Scatter);
|
||||
|
||||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
ScatterElements,
|
||||
11,
|
||||
KernelDefBuilder()
|
||||
.MayInplace(0, 0)
|
||||
.TypeConstraint("T", DataTypeImpl::AllTensorTypes())
|
||||
|
|
@ -34,14 +44,25 @@ template <class Tin, class Tdata>
|
|||
Status CopyScatterData(const Tensor* data_input, const Tensor* indices_input, const Tensor* updates_input,
|
||||
const int64_t axis, Tensor* data_output) {
|
||||
const TensorShape& input_data_shape = data_input->Shape();
|
||||
const Tin* indices_data = indices_input->template Data<Tin>();
|
||||
const Tin* indices_data_raw= indices_input->template Data<Tin>();
|
||||
const auto num_indices = indices_input->Shape().Size();
|
||||
|
||||
std::vector<Tin> indices_data;
|
||||
indices_data.reserve(num_indices);
|
||||
|
||||
auto axis_dim_limit = input_data_shape[axis];
|
||||
|
||||
for (int64_t i = 0; i < num_indices; ++i) {
|
||||
Tin idx = indices_data[i];
|
||||
if (idx < 0 || idx >= input_data_shape[axis]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "indices element out of data bounds, idx=", idx,
|
||||
" data_dim=", input_data_shape[axis]);
|
||||
Tin idx = indices_data_raw[i];
|
||||
|
||||
if (idx < -axis_dim_limit || idx >= axis_dim_limit) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"indices element out of data bounds, idx=", idx,
|
||||
" must be within the inclusive range [", -axis_dim_limit,
|
||||
",", axis_dim_limit - 1, "]");
|
||||
}
|
||||
|
||||
indices_data.push_back(idx < 0 ? idx + static_cast<Tin>(axis_dim_limit) : idx);
|
||||
}
|
||||
|
||||
const auto input_elements = input_data_shape.Size();
|
||||
|
|
|
|||
|
|
@ -406,8 +406,6 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
|
|||
{"unique_sorted_axis_3d", "Unique not implemented yet"},
|
||||
{"unique_sorted_axis", "Unique not implemented yet"},
|
||||
{"unique_sorted_with_negative_axis", "Unique not implemented yet"},
|
||||
{"scatter_elements_with_axis", "not implemented yet"},
|
||||
{"scatter_elements_without_axis", "not implemented yet"},
|
||||
{"round", "not implemented yet"},
|
||||
{"gather_elements_1", "not implemented yet"},
|
||||
{"gather_elements_0", "not implemented yet"},
|
||||
|
|
|
|||
|
|
@ -7,10 +7,8 @@
|
|||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
const int Scatter_ver = 9;
|
||||
|
||||
TEST(ScatterOpTest, WithoutAxis) {
|
||||
OpTester test("Scatter", Scatter_ver);
|
||||
static void scatter_without_axis_tests(const char* op_name, int op_version) {
|
||||
OpTester test(op_name, op_version);
|
||||
|
||||
std::vector<float> input;
|
||||
input.resize(3 * 3);
|
||||
|
|
@ -32,8 +30,13 @@ TEST(ScatterOpTest, WithoutAxis) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(ScatterOpTest, WithAxis) {
|
||||
OpTester test("Scatter", Scatter_ver);
|
||||
TEST(Scatter, WithoutAxis) {
|
||||
scatter_without_axis_tests("Scatter", 9);
|
||||
scatter_without_axis_tests("ScatterElements", 11);
|
||||
}
|
||||
|
||||
static void scatter_with_axis_tests(const char* op_name, int op_version) {
|
||||
OpTester test(op_name, op_version);
|
||||
test.AddAttribute<int64_t>("axis", 1);
|
||||
|
||||
test.AddInput<float>("data", {1, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f});
|
||||
|
|
@ -43,8 +46,13 @@ TEST(ScatterOpTest, WithAxis) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(ScatterOpTest, WithAxisThreeDims) {
|
||||
OpTester test("Scatter", Scatter_ver);
|
||||
TEST(Scatter, WithAxis) {
|
||||
scatter_with_axis_tests("Scatter", 9);
|
||||
scatter_with_axis_tests("ScatterElements", 11);
|
||||
}
|
||||
|
||||
static void scatter_three_dim_with_axis_0(const char* op_name, int op_version) {
|
||||
OpTester test(op_name, op_version);
|
||||
test.AddAttribute<int64_t>("axis", 0);
|
||||
|
||||
test.AddInput<int64_t>("data", {1, 3, 3},
|
||||
|
|
@ -67,8 +75,13 @@ TEST(ScatterOpTest, WithAxisThreeDims) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(ScatterOpTest, ThreeDimsWithAxisGE_1) {
|
||||
OpTester test("Scatter", Scatter_ver);
|
||||
TEST(Scatter, ThreeDimsWithAxis_0) {
|
||||
scatter_three_dim_with_axis_0("Scatter", 9);
|
||||
scatter_three_dim_with_axis_0("ScatterElements", 11);
|
||||
}
|
||||
|
||||
static void scatter_three_dim_with_axis_2(const char* op_name, int op_version) {
|
||||
OpTester test(op_name, op_version);
|
||||
test.AddAttribute<int64_t>("axis", 2);
|
||||
|
||||
test.AddInput<int64_t>("data", {1, 3, 3},
|
||||
|
|
@ -92,8 +105,13 @@ TEST(ScatterOpTest, ThreeDimsWithAxisGE_1) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(ScatterOpTest, WithAxisStrings) {
|
||||
OpTester test("Scatter", Scatter_ver);
|
||||
TEST(Scatter, ThreeDimsWithAxis_2) {
|
||||
scatter_three_dim_with_axis_2("Scatter", 9);
|
||||
scatter_three_dim_with_axis_2("ScatterElements", 11);
|
||||
}
|
||||
|
||||
static void scatter_string(const char* op_name, int op_version) {
|
||||
OpTester test(op_name, op_version);
|
||||
test.AddAttribute<int64_t>("axis", 1);
|
||||
|
||||
test.AddInput<std::string>("data", {1, 5}, {"1.0f", "2.0f", "3.0f", "4.0f", "5.0f"});
|
||||
|
|
@ -103,8 +121,13 @@ TEST(ScatterOpTest, WithAxisStrings) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(ScatterOpTest, NegativeAxis) {
|
||||
OpTester test("Scatter", Scatter_ver);
|
||||
TEST(Scatter, String) {
|
||||
scatter_string("Scatter", 9);
|
||||
scatter_string("ScatterElements", 11);
|
||||
}
|
||||
|
||||
static void scatter_negative_axis(const char* op_name, int op_version) {
|
||||
OpTester test(op_name, op_version);
|
||||
test.AddAttribute<int64_t>("axis", -1);
|
||||
|
||||
test.AddInput<float>("data", {1, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f});
|
||||
|
|
@ -114,8 +137,13 @@ TEST(ScatterOpTest, NegativeAxis) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(ScatterOpTest, IndicesUpdatesDimsDonotMatch) {
|
||||
OpTester test("Scatter", Scatter_ver);
|
||||
TEST(Scatter, NegativeAxis) {
|
||||
scatter_negative_axis("Scatter", 9);
|
||||
scatter_negative_axis("ScatterElements", 11);
|
||||
}
|
||||
|
||||
static void scatter_indices_updates_dont_match(const char* op_name, int op_version) {
|
||||
OpTester test(op_name, op_version);
|
||||
test.AddAttribute<int64_t>("axis", 1);
|
||||
|
||||
test.AddInput<float>("data", {1, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f});
|
||||
|
|
@ -125,8 +153,13 @@ TEST(ScatterOpTest, IndicesUpdatesDimsDonotMatch) {
|
|||
test.Run(OpTester::ExpectResult::kExpectFailure, "Indices vs updates dimensions differs at position=1 3 vs 2");
|
||||
}
|
||||
|
||||
TEST(ScatterOpTest, ValidIndex) {
|
||||
OpTester test("Scatter", Scatter_ver);
|
||||
TEST(Scatter, IndicesUpdatesDontMatch) {
|
||||
scatter_indices_updates_dont_match("Scatter", 9);
|
||||
scatter_indices_updates_dont_match("ScatterElements", 11);
|
||||
}
|
||||
|
||||
static void scatter_valid_index(const char* op_name, int op_version) {
|
||||
OpTester test(op_name, op_version);
|
||||
test.AddAttribute<int64_t>("axis", 0);
|
||||
|
||||
test.AddInput<float>("data", {4, 2, 1}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f});
|
||||
|
|
@ -136,15 +169,42 @@ TEST(ScatterOpTest, ValidIndex) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(ScatterOpTest, InvalidIndex) {
|
||||
OpTester test("Scatter", Scatter_ver);
|
||||
TEST(Scatter, ValidAxis) {
|
||||
scatter_valid_index("Scatter", 9);
|
||||
scatter_valid_index("ScatterElements", 11);
|
||||
}
|
||||
|
||||
static void scatter_invalid_index(const char* op_name, int op_version) {
|
||||
OpTester test(op_name, op_version);
|
||||
test.AddAttribute<int64_t>("axis", 0);
|
||||
|
||||
test.AddInput<float>("data", {4, 2, 1}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f});
|
||||
test.AddInput<int64_t>("indices", {1, 1, 1}, {4});
|
||||
test.AddInput<float>("updates", {1, 1, 1}, {5.0f});
|
||||
test.AddOutput<float>("y", {4, 2, 1}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 5.0f, 0.0f});
|
||||
test.Run(OpTester::ExpectResult::kExpectFailure, "indices element out of data bounds, idx=4 data_dim=4");
|
||||
test.Run(OpTester::ExpectResult::kExpectFailure, "indices element out of data bounds, idx=4 must be within the inclusive range [-4,3]");
|
||||
}
|
||||
|
||||
TEST(Scatter, InvalidIndex) {
|
||||
scatter_invalid_index("Scatter", 9);
|
||||
scatter_invalid_index("ScatterElements", 11);
|
||||
}
|
||||
|
||||
static void scatter_valid_negative_index(const char* op_name, int op_version) {
|
||||
OpTester test(op_name, op_version);
|
||||
test.AddAttribute<int64_t>("axis", 0);
|
||||
|
||||
test.AddInput<float>("data", {4, 2, 1}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f});
|
||||
test.AddInput<int64_t>("indices", {1, 1, 1}, {-1});
|
||||
test.AddInput<float>("updates", {1, 1, 1}, {5.0f});
|
||||
test.AddOutput<float>("y", {4, 2, 1}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 5.0f, 0.0f});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(Scatter, ValidNegativeIndex) {
|
||||
scatter_valid_negative_index("Scatter", 9);
|
||||
scatter_valid_negative_index("ScatterElements", 11);
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -115,7 +115,6 @@ def create_backend_test(testname=None):
|
|||
'^test_dynamicquantizelinear_max_adjusted_expanded*',
|
||||
'^test_dynamicquantizelinear_min_adjusted_expanded*',
|
||||
'^test_gather_elements*',
|
||||
'^test_scatter_elements*',
|
||||
'^test_top_k*',
|
||||
'^test_unique_*',
|
||||
'^test_mod_float_mixed_sign_example_cpu.*', #onnxruntime::Mod::Compute fmod_ was false. fmod attribute must be true for float, float16 and double types
|
||||
|
|
|
|||
Loading…
Reference in a new issue