From b68079fe5d334ab525cc80a901efdc5e4dee4833 Mon Sep 17 00:00:00 2001 From: David Fan <30608893+jiafatom@users.noreply.github.com> Date: Thu, 7 Mar 2019 00:13:11 -0800 Subject: [PATCH] Support int32_t for Split op (#563) * Support int32_t for Split op * Support int32_t for Split op --- .../core/providers/cpu/tensor/split.cc | 3 + .../providers/cpu/tensor/split_op_test.cc | 113 +++++++++++------- 2 files changed, 72 insertions(+), 44 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/split.cc b/onnxruntime/core/providers/cpu/tensor/split.cc index 2e99564712..d4e655a7f2 100644 --- a/onnxruntime/core/providers/cpu/tensor/split.cc +++ b/onnxruntime/core/providers/cpu/tensor/split.cc @@ -17,6 +17,7 @@ ONNX_CPU_OPERATOR_KERNEL( std::vector{ DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), }), Split); @@ -28,6 +29,8 @@ Status Split::Compute(OpKernelContext* context) const { if (data_type == DataTypeImpl::GetType()) status = ComputeImpl(*context, input); + else if (data_type == DataTypeImpl::GetType()) + status = ComputeImpl(*context, input); else if (data_type == DataTypeImpl::GetType()) { /* Need to update CopyMatrix to support double... status = ComputeImpl(*context, input); */ diff --git a/onnxruntime/test/providers/cpu/tensor/split_op_test.cc b/onnxruntime/test/providers/cpu/tensor/split_op_test.cc index c2a4d07365..a7c4f35d82 100644 --- a/onnxruntime/test/providers/cpu/tensor/split_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/split_op_test.cc @@ -7,12 +7,15 @@ namespace onnxruntime { namespace test { -using ShapeAndData = std::pair, const std::vector>; +template using ShapeAndData = std::pair, const std::vector>; + +using ShapeAndFloatData = ShapeAndData; +using ShapeAndInt32Data = ShapeAndData; using ExpectResult = OpTester::ExpectResult; -void RunTest(int64_t axis, const std::vector split_sizes, const ShapeAndData& input, - const std::vector& outputs, - bool expect_failure = false, const std::string& err_msg = {}) { +template void RunTest(int64_t axis, const std::vector split_sizes, const ShapeAndData& input, + const std::vector>& outputs, + bool expect_failure = false, const std::string& err_msg = {}) { OpTester test("Split"); test.AddAttribute("axis", axis); @@ -20,7 +23,7 @@ void RunTest(int64_t axis, const std::vector split_sizes, const ShapeAn if (!split_sizes.empty()) test.AddAttribute("split", split_sizes); - test.AddInput("input", input.first, input.second); + test.AddInput("input", input.first, input.second); int i = 0; for (auto& output : outputs) { @@ -28,7 +31,7 @@ void RunTest(int64_t axis, const std::vector split_sizes, const ShapeAn auto& data = output.second; std::ostringstream oss; oss << "output" << i++; - test.AddOutput(oss.str().c_str(), shape, data); + test.AddOutput(oss.str().c_str(), shape, data); } test.Run(expect_failure ? ExpectResult::kExpectFailure : ExpectResult::kExpectSuccess, err_msg); @@ -36,10 +39,10 @@ void RunTest(int64_t axis, const std::vector split_sizes, const ShapeAn TEST(SplitOperatorTest, Axis0EqualSplit) { const int64_t axis = 0; - std::vector outputs; + std::vector outputs; // input shape and data - ShapeAndData input = {{4, 2}, // shape + ShapeAndFloatData input = {{4, 2}, // shape {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, @@ -53,15 +56,37 @@ TEST(SplitOperatorTest, Axis0EqualSplit) { {5.f, 6.f, 7.f, 8.f}}); - RunTest(axis, {}, input, outputs); + RunTest(axis, {}, input, outputs); +} + +TEST(SplitOperatorTest, Axis0EqualSplitInt32) { + const int64_t axis = 0; + std::vector outputs; + + // input shape and data + ShapeAndInt32Data input = {{4, 2}, // shape + {1, 2, + 3, 4, + 5, 6, + 7, 8}}; + + outputs.push_back({{2, 2}, + {1, 2, + 3, 4}}); + + outputs.push_back({{2, 2}, + {5, 6, + 7, 8}}); + + RunTest(axis, {}, input, outputs); } TEST(SplitOperatorTest, Axis0UnequalSplit) { const int64_t axis = 0; - std::vector outputs; + std::vector outputs; // input shape and data - ShapeAndData input = {{4, 2}, // shape + ShapeAndFloatData input = {{4, 2}, // shape {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, @@ -76,15 +101,15 @@ TEST(SplitOperatorTest, Axis0UnequalSplit) { 5.f, 6.f, 7.f, 8.f}}); - RunTest(axis, splits, input, outputs); + RunTest(axis, splits, input, outputs); } TEST(SplitOperatorTest, Axis1EqualSplit) { const int64_t axis = 1; - std::vector outputs; + std::vector outputs; // input shape and data - ShapeAndData input = {{2, 4}, + ShapeAndFloatData input = {{2, 4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}}; @@ -96,15 +121,15 @@ TEST(SplitOperatorTest, Axis1EqualSplit) { {3.f, 4.f, 7.f, 8.f}}); - RunTest(axis, {}, input, outputs); + RunTest(axis, {}, input, outputs); } TEST(SplitOperatorTest, Axis1UnequalSplit) { const int64_t axis = 1; - std::vector outputs; + std::vector outputs; // input shape and data - ShapeAndData input = {{2, 4}, + ShapeAndFloatData input = {{2, 4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}}; @@ -118,10 +143,10 @@ TEST(SplitOperatorTest, Axis1UnequalSplit) { {4.f, 8.f}}); - RunTest(axis, splits, input, outputs); + RunTest(axis, splits, input, outputs); } -ShapeAndData CreateInput(std::vector shape) { +ShapeAndFloatData CreateInput(std::vector shape) { auto size = TensorShape(shape).Size(); float i = 0.f, increment = 1.f; @@ -129,16 +154,16 @@ ShapeAndData CreateInput(std::vector shape) { std::vector data; std::generate_n(std::back_inserter(data), size, [&]() { return i += increment; }); - ShapeAndData input = {shape, data}; + ShapeAndFloatData input = {shape, data}; return input; } TEST(SplitOperatorTest, Axis2EqualSplit) { const int64_t axis = 2; - std::vector outputs; + std::vector outputs; - ShapeAndData input = CreateInput({2, 2, 6}); + ShapeAndFloatData input = CreateInput({2, 2, 6}); outputs.push_back({{2, 2, 2}, {1.f, 2.f, @@ -161,14 +186,14 @@ TEST(SplitOperatorTest, Axis2EqualSplit) { 17.f, 18.f, 23.f, 24.f}}); - RunTest(axis, {}, input, outputs); + RunTest(axis, {}, input, outputs); } TEST(SplitOperatorTest, Axis2UnequalSplit) { const int64_t axis = 2; - std::vector outputs; + std::vector outputs; - ShapeAndData input = CreateInput({2, 2, 6}); + ShapeAndFloatData input = CreateInput({2, 2, 6}); std::vector splits{1, 2, 3}; @@ -193,15 +218,15 @@ TEST(SplitOperatorTest, Axis2UnequalSplit) { 16.f, 17.f, 18.f, 22.f, 23.f, 24.f}}); - RunTest(axis, splits, input, outputs); + RunTest(axis, splits, input, outputs); } // test a split of a dimension that has leading and trailing dimensions TEST(SplitOperatorTest, Axis1SplitMiddleDimensionEqually) { const int64_t axis = 1; - std::vector outputs; + std::vector outputs; - ShapeAndData input = CreateInput({2, 4, 4}); + ShapeAndFloatData input = CreateInput({2, 4, 4}); outputs.push_back({{2, 2, 4}, {1.f, 2.f, 3.f, 4.f, @@ -217,15 +242,15 @@ TEST(SplitOperatorTest, Axis1SplitMiddleDimensionEqually) { 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f}}); - RunTest(axis, {}, input, outputs); + RunTest(axis, {}, input, outputs); } // test a split of a dimension that has leading and trailing dimensions TEST(SplitOperatorTest, Axis1SplitMiddleDimensionUnequally) { const int64_t axis = 1; - std::vector outputs; + std::vector outputs; - ShapeAndData input = CreateInput({2, 4, 4}); + ShapeAndFloatData input = CreateInput({2, 4, 4}); std::vector splits{1, 3}; @@ -243,15 +268,15 @@ TEST(SplitOperatorTest, Axis1SplitMiddleDimensionUnequally) { 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f}}); - RunTest(axis, splits, input, outputs); + RunTest(axis, splits, input, outputs); } TEST(SplitOperatorTest, NegativeAxis) { const int64_t axis = -1; // split last axis equally - std::vector outputs; + std::vector outputs; // input shape and data - ShapeAndData input = {{2, 4}, + ShapeAndFloatData input = {{2, 4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}}; @@ -263,15 +288,15 @@ TEST(SplitOperatorTest, NegativeAxis) { {3.f, 4.f, 7.f, 8.f}}); - RunTest(axis, {}, input, outputs); + RunTest(axis, {}, input, outputs); } TEST(SplitOperatorTest, InvalidAxis) { const int64_t axis = 2; - std::vector outputs; + std::vector outputs; // input shape and data - ShapeAndData input = {{4, 2}, // shape + ShapeAndFloatData input = {{4, 2}, // shape {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, @@ -279,16 +304,16 @@ TEST(SplitOperatorTest, InvalidAxis) { outputs.push_back({{1}, {0.f}}); - RunTest(axis, {}, input, outputs, true, "Invalid value of attribute 'axis'"); + RunTest(axis, {}, input, outputs, true, "Invalid value of attribute 'axis'"); } // sum of values in splits is too small TEST(SplitOperatorTest, SplitAttributeSumTooSmall) { const int64_t axis = 0; - std::vector outputs; + std::vector outputs; // input shape and data - ShapeAndData input = {{4, 2}, // shape + ShapeAndFloatData input = {{4, 2}, // shape {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, @@ -299,15 +324,15 @@ TEST(SplitOperatorTest, SplitAttributeSumTooSmall) { outputs.push_back({{1, 2}, {1.f, 2.f}}); outputs.push_back({{2, 2}, {3.f, 4.f, 5.f, 6.f}}); - RunTest(axis, splits, input, outputs, true, "Cannot split using values in 'split' attribute"); + RunTest(axis, splits, input, outputs, true, "Cannot split using values in 'split' attribute"); } TEST(SplitOperatorTest, InvalidValueInSplitAttribute) { const int64_t axis = 0; - std::vector outputs; + std::vector outputs; // input shape and data - ShapeAndData input = {{4, 2}, // shape + ShapeAndFloatData input = {{4, 2}, // shape {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, @@ -317,7 +342,7 @@ TEST(SplitOperatorTest, InvalidValueInSplitAttribute) { outputs.push_back({{1, 2}, {1.f, 2.f}}); outputs.push_back({{3, 2}, {3.f, 4.f, 5.f, 6.f, 7.f, 8.f}}); - RunTest(axis, splits, input, outputs, true, "Invalid value in 'split' attribute"); + RunTest(axis, splits, input, outputs, true, "Invalid value in 'split' attribute"); } /*