Support int32_t for Split op (#563)

* Support int32_t for Split op

* Support int32_t for Split op
This commit is contained in:
David Fan 2019-03-07 00:13:11 -08:00 committed by GitHub
parent af9c554dd3
commit b68079fe5d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 72 additions and 44 deletions

View file

@ -17,6 +17,7 @@ ONNX_CPU_OPERATOR_KERNEL(
std::vector<MLDataType>{
DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<double>(),
DataTypeImpl::GetTensorType<int32_t>(),
}),
Split);
@ -28,6 +29,8 @@ Status Split::Compute(OpKernelContext* context) const {
if (data_type == DataTypeImpl::GetType<float>())
status = ComputeImpl<float>(*context, input);
else if (data_type == DataTypeImpl::GetType<int32_t>())
status = ComputeImpl<int32_t>(*context, input);
else if (data_type == DataTypeImpl::GetType<double>()) {
/* Need to update CopyMatrix to support double...
status = ComputeImpl<double>(*context, input); */

View file

@ -7,12 +7,15 @@
namespace onnxruntime {
namespace test {
using ShapeAndData = std::pair<const std::vector<int64_t>, const std::vector<float>>;
template<class T> using ShapeAndData = std::pair<const std::vector<int64_t>, const std::vector<T>>;
using ShapeAndFloatData = ShapeAndData<float>;
using ShapeAndInt32Data = ShapeAndData<int32_t>;
using ExpectResult = OpTester::ExpectResult;
void RunTest(int64_t axis, const std::vector<int64_t> split_sizes, const ShapeAndData& input,
const std::vector<ShapeAndData>& outputs,
bool expect_failure = false, const std::string& err_msg = {}) {
template<typename T> void RunTest(int64_t axis, const std::vector<int64_t> split_sizes, const ShapeAndData<T>& input,
const std::vector<ShapeAndData<T>>& outputs,
bool expect_failure = false, const std::string& err_msg = {}) {
OpTester test("Split");
test.AddAttribute("axis", axis);
@ -20,7 +23,7 @@ void RunTest(int64_t axis, const std::vector<int64_t> split_sizes, const ShapeAn
if (!split_sizes.empty())
test.AddAttribute("split", split_sizes);
test.AddInput<float>("input", input.first, input.second);
test.AddInput<T>("input", input.first, input.second);
int i = 0;
for (auto& output : outputs) {
@ -28,7 +31,7 @@ void RunTest(int64_t axis, const std::vector<int64_t> split_sizes, const ShapeAn
auto& data = output.second;
std::ostringstream oss;
oss << "output" << i++;
test.AddOutput<float>(oss.str().c_str(), shape, data);
test.AddOutput<T>(oss.str().c_str(), shape, data);
}
test.Run(expect_failure ? ExpectResult::kExpectFailure : ExpectResult::kExpectSuccess, err_msg);
@ -36,10 +39,10 @@ void RunTest(int64_t axis, const std::vector<int64_t> split_sizes, const ShapeAn
TEST(SplitOperatorTest, Axis0EqualSplit) {
const int64_t axis = 0;
std::vector<ShapeAndData> outputs;
std::vector<ShapeAndFloatData> outputs;
// input shape and data
ShapeAndData input = {{4, 2}, // shape
ShapeAndFloatData input = {{4, 2}, // shape
{1.f, 2.f,
3.f, 4.f,
5.f, 6.f,
@ -53,15 +56,37 @@ TEST(SplitOperatorTest, Axis0EqualSplit) {
{5.f, 6.f,
7.f, 8.f}});
RunTest(axis, {}, input, outputs);
RunTest<float>(axis, {}, input, outputs);
}
TEST(SplitOperatorTest, Axis0EqualSplitInt32) {
const int64_t axis = 0;
std::vector<ShapeAndInt32Data> outputs;
// input shape and data
ShapeAndInt32Data input = {{4, 2}, // shape
{1, 2,
3, 4,
5, 6,
7, 8}};
outputs.push_back({{2, 2},
{1, 2,
3, 4}});
outputs.push_back({{2, 2},
{5, 6,
7, 8}});
RunTest<int32_t>(axis, {}, input, outputs);
}
TEST(SplitOperatorTest, Axis0UnequalSplit) {
const int64_t axis = 0;
std::vector<ShapeAndData> outputs;
std::vector<ShapeAndFloatData> outputs;
// input shape and data
ShapeAndData input = {{4, 2}, // shape
ShapeAndFloatData input = {{4, 2}, // shape
{1.f, 2.f,
3.f, 4.f,
5.f, 6.f,
@ -76,15 +101,15 @@ TEST(SplitOperatorTest, Axis0UnequalSplit) {
5.f, 6.f,
7.f, 8.f}});
RunTest(axis, splits, input, outputs);
RunTest<float>(axis, splits, input, outputs);
}
TEST(SplitOperatorTest, Axis1EqualSplit) {
const int64_t axis = 1;
std::vector<ShapeAndData> outputs;
std::vector<ShapeAndFloatData> outputs;
// input shape and data
ShapeAndData input = {{2, 4},
ShapeAndFloatData input = {{2, 4},
{1.f, 2.f, 3.f, 4.f,
5.f, 6.f, 7.f, 8.f}};
@ -96,15 +121,15 @@ TEST(SplitOperatorTest, Axis1EqualSplit) {
{3.f, 4.f,
7.f, 8.f}});
RunTest(axis, {}, input, outputs);
RunTest<float>(axis, {}, input, outputs);
}
TEST(SplitOperatorTest, Axis1UnequalSplit) {
const int64_t axis = 1;
std::vector<ShapeAndData> outputs;
std::vector<ShapeAndFloatData> outputs;
// input shape and data
ShapeAndData input = {{2, 4},
ShapeAndFloatData input = {{2, 4},
{1.f, 2.f, 3.f, 4.f,
5.f, 6.f, 7.f, 8.f}};
@ -118,10 +143,10 @@ TEST(SplitOperatorTest, Axis1UnequalSplit) {
{4.f,
8.f}});
RunTest(axis, splits, input, outputs);
RunTest<float>(axis, splits, input, outputs);
}
ShapeAndData CreateInput(std::vector<int64_t> shape) {
ShapeAndFloatData CreateInput(std::vector<int64_t> shape) {
auto size = TensorShape(shape).Size();
float i = 0.f, increment = 1.f;
@ -129,16 +154,16 @@ ShapeAndData CreateInput(std::vector<int64_t> shape) {
std::vector<float> data;
std::generate_n(std::back_inserter(data), size, [&]() { return i += increment; });
ShapeAndData input = {shape, data};
ShapeAndFloatData input = {shape, data};
return input;
}
TEST(SplitOperatorTest, Axis2EqualSplit) {
const int64_t axis = 2;
std::vector<ShapeAndData> outputs;
std::vector<ShapeAndFloatData> outputs;
ShapeAndData input = CreateInput({2, 2, 6});
ShapeAndFloatData input = CreateInput({2, 2, 6});
outputs.push_back({{2, 2, 2},
{1.f, 2.f,
@ -161,14 +186,14 @@ TEST(SplitOperatorTest, Axis2EqualSplit) {
17.f, 18.f,
23.f, 24.f}});
RunTest(axis, {}, input, outputs);
RunTest<float>(axis, {}, input, outputs);
}
TEST(SplitOperatorTest, Axis2UnequalSplit) {
const int64_t axis = 2;
std::vector<ShapeAndData> outputs;
std::vector<ShapeAndFloatData> outputs;
ShapeAndData input = CreateInput({2, 2, 6});
ShapeAndFloatData input = CreateInput({2, 2, 6});
std::vector<int64_t> splits{1, 2, 3};
@ -193,15 +218,15 @@ TEST(SplitOperatorTest, Axis2UnequalSplit) {
16.f, 17.f, 18.f,
22.f, 23.f, 24.f}});
RunTest(axis, splits, input, outputs);
RunTest<float>(axis, splits, input, outputs);
}
// test a split of a dimension that has leading and trailing dimensions
TEST(SplitOperatorTest, Axis1SplitMiddleDimensionEqually) {
const int64_t axis = 1;
std::vector<ShapeAndData> outputs;
std::vector<ShapeAndFloatData> outputs;
ShapeAndData input = CreateInput({2, 4, 4});
ShapeAndFloatData input = CreateInput({2, 4, 4});
outputs.push_back({{2, 2, 4},
{1.f, 2.f, 3.f, 4.f,
@ -217,15 +242,15 @@ TEST(SplitOperatorTest, Axis1SplitMiddleDimensionEqually) {
25.f, 26.f, 27.f, 28.f,
29.f, 30.f, 31.f, 32.f}});
RunTest(axis, {}, input, outputs);
RunTest<float>(axis, {}, input, outputs);
}
// test a split of a dimension that has leading and trailing dimensions
TEST(SplitOperatorTest, Axis1SplitMiddleDimensionUnequally) {
const int64_t axis = 1;
std::vector<ShapeAndData> outputs;
std::vector<ShapeAndFloatData> outputs;
ShapeAndData input = CreateInput({2, 4, 4});
ShapeAndFloatData input = CreateInput({2, 4, 4});
std::vector<int64_t> splits{1, 3};
@ -243,15 +268,15 @@ TEST(SplitOperatorTest, Axis1SplitMiddleDimensionUnequally) {
25.f, 26.f, 27.f, 28.f,
29.f, 30.f, 31.f, 32.f}});
RunTest(axis, splits, input, outputs);
RunTest<float>(axis, splits, input, outputs);
}
TEST(SplitOperatorTest, NegativeAxis) {
const int64_t axis = -1; // split last axis equally
std::vector<ShapeAndData> outputs;
std::vector<ShapeAndFloatData> outputs;
// input shape and data
ShapeAndData input = {{2, 4},
ShapeAndFloatData input = {{2, 4},
{1.f, 2.f, 3.f, 4.f,
5.f, 6.f, 7.f, 8.f}};
@ -263,15 +288,15 @@ TEST(SplitOperatorTest, NegativeAxis) {
{3.f, 4.f,
7.f, 8.f}});
RunTest(axis, {}, input, outputs);
RunTest<float>(axis, {}, input, outputs);
}
TEST(SplitOperatorTest, InvalidAxis) {
const int64_t axis = 2;
std::vector<ShapeAndData> outputs;
std::vector<ShapeAndFloatData> outputs;
// input shape and data
ShapeAndData input = {{4, 2}, // shape
ShapeAndFloatData input = {{4, 2}, // shape
{1.f, 2.f,
3.f, 4.f,
5.f, 6.f,
@ -279,16 +304,16 @@ TEST(SplitOperatorTest, InvalidAxis) {
outputs.push_back({{1}, {0.f}});
RunTest(axis, {}, input, outputs, true, "Invalid value of attribute 'axis'");
RunTest<float>(axis, {}, input, outputs, true, "Invalid value of attribute 'axis'");
}
// sum of values in splits is too small
TEST(SplitOperatorTest, SplitAttributeSumTooSmall) {
const int64_t axis = 0;
std::vector<ShapeAndData> outputs;
std::vector<ShapeAndFloatData> outputs;
// input shape and data
ShapeAndData input = {{4, 2}, // shape
ShapeAndFloatData input = {{4, 2}, // shape
{1.f, 2.f,
3.f, 4.f,
5.f, 6.f,
@ -299,15 +324,15 @@ TEST(SplitOperatorTest, SplitAttributeSumTooSmall) {
outputs.push_back({{1, 2}, {1.f, 2.f}});
outputs.push_back({{2, 2}, {3.f, 4.f, 5.f, 6.f}});
RunTest(axis, splits, input, outputs, true, "Cannot split using values in 'split' attribute");
RunTest<float>(axis, splits, input, outputs, true, "Cannot split using values in 'split' attribute");
}
TEST(SplitOperatorTest, InvalidValueInSplitAttribute) {
const int64_t axis = 0;
std::vector<ShapeAndData> outputs;
std::vector<ShapeAndFloatData> outputs;
// input shape and data
ShapeAndData input = {{4, 2}, // shape
ShapeAndFloatData input = {{4, 2}, // shape
{1.f, 2.f,
3.f, 4.f,
5.f, 6.f,
@ -317,7 +342,7 @@ TEST(SplitOperatorTest, InvalidValueInSplitAttribute) {
outputs.push_back({{1, 2}, {1.f, 2.f}});
outputs.push_back({{3, 2}, {3.f, 4.f, 5.f, 6.f, 7.f, 8.f}});
RunTest(axis, splits, input, outputs, true, "Invalid value in 'split' attribute");
RunTest<float>(axis, splits, input, outputs, true, "Invalid value in 'split' attribute");
}
/*