mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-25 22:26:24 +00:00
Support int32_t for Split op (#563)
* Support int32_t for Split op * Support int32_t for Split op
This commit is contained in:
parent
af9c554dd3
commit
b68079fe5d
2 changed files with 72 additions and 44 deletions
|
|
@ -17,6 +17,7 @@ ONNX_CPU_OPERATOR_KERNEL(
|
|||
std::vector<MLDataType>{
|
||||
DataTypeImpl::GetTensorType<float>(),
|
||||
DataTypeImpl::GetTensorType<double>(),
|
||||
DataTypeImpl::GetTensorType<int32_t>(),
|
||||
}),
|
||||
Split);
|
||||
|
||||
|
|
@ -28,6 +29,8 @@ Status Split::Compute(OpKernelContext* context) const {
|
|||
|
||||
if (data_type == DataTypeImpl::GetType<float>())
|
||||
status = ComputeImpl<float>(*context, input);
|
||||
else if (data_type == DataTypeImpl::GetType<int32_t>())
|
||||
status = ComputeImpl<int32_t>(*context, input);
|
||||
else if (data_type == DataTypeImpl::GetType<double>()) {
|
||||
/* Need to update CopyMatrix to support double...
|
||||
status = ComputeImpl<double>(*context, input); */
|
||||
|
|
|
|||
|
|
@ -7,12 +7,15 @@
|
|||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
using ShapeAndData = std::pair<const std::vector<int64_t>, const std::vector<float>>;
|
||||
template<class T> using ShapeAndData = std::pair<const std::vector<int64_t>, const std::vector<T>>;
|
||||
|
||||
using ShapeAndFloatData = ShapeAndData<float>;
|
||||
using ShapeAndInt32Data = ShapeAndData<int32_t>;
|
||||
using ExpectResult = OpTester::ExpectResult;
|
||||
|
||||
void RunTest(int64_t axis, const std::vector<int64_t> split_sizes, const ShapeAndData& input,
|
||||
const std::vector<ShapeAndData>& outputs,
|
||||
bool expect_failure = false, const std::string& err_msg = {}) {
|
||||
template<typename T> void RunTest(int64_t axis, const std::vector<int64_t> split_sizes, const ShapeAndData<T>& input,
|
||||
const std::vector<ShapeAndData<T>>& outputs,
|
||||
bool expect_failure = false, const std::string& err_msg = {}) {
|
||||
OpTester test("Split");
|
||||
|
||||
test.AddAttribute("axis", axis);
|
||||
|
|
@ -20,7 +23,7 @@ void RunTest(int64_t axis, const std::vector<int64_t> split_sizes, const ShapeAn
|
|||
if (!split_sizes.empty())
|
||||
test.AddAttribute("split", split_sizes);
|
||||
|
||||
test.AddInput<float>("input", input.first, input.second);
|
||||
test.AddInput<T>("input", input.first, input.second);
|
||||
|
||||
int i = 0;
|
||||
for (auto& output : outputs) {
|
||||
|
|
@ -28,7 +31,7 @@ void RunTest(int64_t axis, const std::vector<int64_t> split_sizes, const ShapeAn
|
|||
auto& data = output.second;
|
||||
std::ostringstream oss;
|
||||
oss << "output" << i++;
|
||||
test.AddOutput<float>(oss.str().c_str(), shape, data);
|
||||
test.AddOutput<T>(oss.str().c_str(), shape, data);
|
||||
}
|
||||
|
||||
test.Run(expect_failure ? ExpectResult::kExpectFailure : ExpectResult::kExpectSuccess, err_msg);
|
||||
|
|
@ -36,10 +39,10 @@ void RunTest(int64_t axis, const std::vector<int64_t> split_sizes, const ShapeAn
|
|||
|
||||
TEST(SplitOperatorTest, Axis0EqualSplit) {
|
||||
const int64_t axis = 0;
|
||||
std::vector<ShapeAndData> outputs;
|
||||
std::vector<ShapeAndFloatData> outputs;
|
||||
|
||||
// input shape and data
|
||||
ShapeAndData input = {{4, 2}, // shape
|
||||
ShapeAndFloatData input = {{4, 2}, // shape
|
||||
{1.f, 2.f,
|
||||
3.f, 4.f,
|
||||
5.f, 6.f,
|
||||
|
|
@ -53,15 +56,37 @@ TEST(SplitOperatorTest, Axis0EqualSplit) {
|
|||
{5.f, 6.f,
|
||||
7.f, 8.f}});
|
||||
|
||||
RunTest(axis, {}, input, outputs);
|
||||
RunTest<float>(axis, {}, input, outputs);
|
||||
}
|
||||
|
||||
TEST(SplitOperatorTest, Axis0EqualSplitInt32) {
|
||||
const int64_t axis = 0;
|
||||
std::vector<ShapeAndInt32Data> outputs;
|
||||
|
||||
// input shape and data
|
||||
ShapeAndInt32Data input = {{4, 2}, // shape
|
||||
{1, 2,
|
||||
3, 4,
|
||||
5, 6,
|
||||
7, 8}};
|
||||
|
||||
outputs.push_back({{2, 2},
|
||||
{1, 2,
|
||||
3, 4}});
|
||||
|
||||
outputs.push_back({{2, 2},
|
||||
{5, 6,
|
||||
7, 8}});
|
||||
|
||||
RunTest<int32_t>(axis, {}, input, outputs);
|
||||
}
|
||||
|
||||
TEST(SplitOperatorTest, Axis0UnequalSplit) {
|
||||
const int64_t axis = 0;
|
||||
std::vector<ShapeAndData> outputs;
|
||||
std::vector<ShapeAndFloatData> outputs;
|
||||
|
||||
// input shape and data
|
||||
ShapeAndData input = {{4, 2}, // shape
|
||||
ShapeAndFloatData input = {{4, 2}, // shape
|
||||
{1.f, 2.f,
|
||||
3.f, 4.f,
|
||||
5.f, 6.f,
|
||||
|
|
@ -76,15 +101,15 @@ TEST(SplitOperatorTest, Axis0UnequalSplit) {
|
|||
5.f, 6.f,
|
||||
7.f, 8.f}});
|
||||
|
||||
RunTest(axis, splits, input, outputs);
|
||||
RunTest<float>(axis, splits, input, outputs);
|
||||
}
|
||||
|
||||
TEST(SplitOperatorTest, Axis1EqualSplit) {
|
||||
const int64_t axis = 1;
|
||||
std::vector<ShapeAndData> outputs;
|
||||
std::vector<ShapeAndFloatData> outputs;
|
||||
|
||||
// input shape and data
|
||||
ShapeAndData input = {{2, 4},
|
||||
ShapeAndFloatData input = {{2, 4},
|
||||
{1.f, 2.f, 3.f, 4.f,
|
||||
5.f, 6.f, 7.f, 8.f}};
|
||||
|
||||
|
|
@ -96,15 +121,15 @@ TEST(SplitOperatorTest, Axis1EqualSplit) {
|
|||
{3.f, 4.f,
|
||||
7.f, 8.f}});
|
||||
|
||||
RunTest(axis, {}, input, outputs);
|
||||
RunTest<float>(axis, {}, input, outputs);
|
||||
}
|
||||
|
||||
TEST(SplitOperatorTest, Axis1UnequalSplit) {
|
||||
const int64_t axis = 1;
|
||||
std::vector<ShapeAndData> outputs;
|
||||
std::vector<ShapeAndFloatData> outputs;
|
||||
|
||||
// input shape and data
|
||||
ShapeAndData input = {{2, 4},
|
||||
ShapeAndFloatData input = {{2, 4},
|
||||
{1.f, 2.f, 3.f, 4.f,
|
||||
5.f, 6.f, 7.f, 8.f}};
|
||||
|
||||
|
|
@ -118,10 +143,10 @@ TEST(SplitOperatorTest, Axis1UnequalSplit) {
|
|||
{4.f,
|
||||
8.f}});
|
||||
|
||||
RunTest(axis, splits, input, outputs);
|
||||
RunTest<float>(axis, splits, input, outputs);
|
||||
}
|
||||
|
||||
ShapeAndData CreateInput(std::vector<int64_t> shape) {
|
||||
ShapeAndFloatData CreateInput(std::vector<int64_t> shape) {
|
||||
auto size = TensorShape(shape).Size();
|
||||
|
||||
float i = 0.f, increment = 1.f;
|
||||
|
|
@ -129,16 +154,16 @@ ShapeAndData CreateInput(std::vector<int64_t> shape) {
|
|||
std::vector<float> data;
|
||||
std::generate_n(std::back_inserter(data), size, [&]() { return i += increment; });
|
||||
|
||||
ShapeAndData input = {shape, data};
|
||||
ShapeAndFloatData input = {shape, data};
|
||||
|
||||
return input;
|
||||
}
|
||||
|
||||
TEST(SplitOperatorTest, Axis2EqualSplit) {
|
||||
const int64_t axis = 2;
|
||||
std::vector<ShapeAndData> outputs;
|
||||
std::vector<ShapeAndFloatData> outputs;
|
||||
|
||||
ShapeAndData input = CreateInput({2, 2, 6});
|
||||
ShapeAndFloatData input = CreateInput({2, 2, 6});
|
||||
|
||||
outputs.push_back({{2, 2, 2},
|
||||
{1.f, 2.f,
|
||||
|
|
@ -161,14 +186,14 @@ TEST(SplitOperatorTest, Axis2EqualSplit) {
|
|||
17.f, 18.f,
|
||||
23.f, 24.f}});
|
||||
|
||||
RunTest(axis, {}, input, outputs);
|
||||
RunTest<float>(axis, {}, input, outputs);
|
||||
}
|
||||
|
||||
TEST(SplitOperatorTest, Axis2UnequalSplit) {
|
||||
const int64_t axis = 2;
|
||||
std::vector<ShapeAndData> outputs;
|
||||
std::vector<ShapeAndFloatData> outputs;
|
||||
|
||||
ShapeAndData input = CreateInput({2, 2, 6});
|
||||
ShapeAndFloatData input = CreateInput({2, 2, 6});
|
||||
|
||||
std::vector<int64_t> splits{1, 2, 3};
|
||||
|
||||
|
|
@ -193,15 +218,15 @@ TEST(SplitOperatorTest, Axis2UnequalSplit) {
|
|||
16.f, 17.f, 18.f,
|
||||
22.f, 23.f, 24.f}});
|
||||
|
||||
RunTest(axis, splits, input, outputs);
|
||||
RunTest<float>(axis, splits, input, outputs);
|
||||
}
|
||||
|
||||
// test a split of a dimension that has leading and trailing dimensions
|
||||
TEST(SplitOperatorTest, Axis1SplitMiddleDimensionEqually) {
|
||||
const int64_t axis = 1;
|
||||
std::vector<ShapeAndData> outputs;
|
||||
std::vector<ShapeAndFloatData> outputs;
|
||||
|
||||
ShapeAndData input = CreateInput({2, 4, 4});
|
||||
ShapeAndFloatData input = CreateInput({2, 4, 4});
|
||||
|
||||
outputs.push_back({{2, 2, 4},
|
||||
{1.f, 2.f, 3.f, 4.f,
|
||||
|
|
@ -217,15 +242,15 @@ TEST(SplitOperatorTest, Axis1SplitMiddleDimensionEqually) {
|
|||
25.f, 26.f, 27.f, 28.f,
|
||||
29.f, 30.f, 31.f, 32.f}});
|
||||
|
||||
RunTest(axis, {}, input, outputs);
|
||||
RunTest<float>(axis, {}, input, outputs);
|
||||
}
|
||||
|
||||
// test a split of a dimension that has leading and trailing dimensions
|
||||
TEST(SplitOperatorTest, Axis1SplitMiddleDimensionUnequally) {
|
||||
const int64_t axis = 1;
|
||||
std::vector<ShapeAndData> outputs;
|
||||
std::vector<ShapeAndFloatData> outputs;
|
||||
|
||||
ShapeAndData input = CreateInput({2, 4, 4});
|
||||
ShapeAndFloatData input = CreateInput({2, 4, 4});
|
||||
|
||||
std::vector<int64_t> splits{1, 3};
|
||||
|
||||
|
|
@ -243,15 +268,15 @@ TEST(SplitOperatorTest, Axis1SplitMiddleDimensionUnequally) {
|
|||
25.f, 26.f, 27.f, 28.f,
|
||||
29.f, 30.f, 31.f, 32.f}});
|
||||
|
||||
RunTest(axis, splits, input, outputs);
|
||||
RunTest<float>(axis, splits, input, outputs);
|
||||
}
|
||||
|
||||
TEST(SplitOperatorTest, NegativeAxis) {
|
||||
const int64_t axis = -1; // split last axis equally
|
||||
std::vector<ShapeAndData> outputs;
|
||||
std::vector<ShapeAndFloatData> outputs;
|
||||
|
||||
// input shape and data
|
||||
ShapeAndData input = {{2, 4},
|
||||
ShapeAndFloatData input = {{2, 4},
|
||||
{1.f, 2.f, 3.f, 4.f,
|
||||
5.f, 6.f, 7.f, 8.f}};
|
||||
|
||||
|
|
@ -263,15 +288,15 @@ TEST(SplitOperatorTest, NegativeAxis) {
|
|||
{3.f, 4.f,
|
||||
7.f, 8.f}});
|
||||
|
||||
RunTest(axis, {}, input, outputs);
|
||||
RunTest<float>(axis, {}, input, outputs);
|
||||
}
|
||||
|
||||
TEST(SplitOperatorTest, InvalidAxis) {
|
||||
const int64_t axis = 2;
|
||||
std::vector<ShapeAndData> outputs;
|
||||
std::vector<ShapeAndFloatData> outputs;
|
||||
|
||||
// input shape and data
|
||||
ShapeAndData input = {{4, 2}, // shape
|
||||
ShapeAndFloatData input = {{4, 2}, // shape
|
||||
{1.f, 2.f,
|
||||
3.f, 4.f,
|
||||
5.f, 6.f,
|
||||
|
|
@ -279,16 +304,16 @@ TEST(SplitOperatorTest, InvalidAxis) {
|
|||
|
||||
outputs.push_back({{1}, {0.f}});
|
||||
|
||||
RunTest(axis, {}, input, outputs, true, "Invalid value of attribute 'axis'");
|
||||
RunTest<float>(axis, {}, input, outputs, true, "Invalid value of attribute 'axis'");
|
||||
}
|
||||
|
||||
// sum of values in splits is too small
|
||||
TEST(SplitOperatorTest, SplitAttributeSumTooSmall) {
|
||||
const int64_t axis = 0;
|
||||
std::vector<ShapeAndData> outputs;
|
||||
std::vector<ShapeAndFloatData> outputs;
|
||||
|
||||
// input shape and data
|
||||
ShapeAndData input = {{4, 2}, // shape
|
||||
ShapeAndFloatData input = {{4, 2}, // shape
|
||||
{1.f, 2.f,
|
||||
3.f, 4.f,
|
||||
5.f, 6.f,
|
||||
|
|
@ -299,15 +324,15 @@ TEST(SplitOperatorTest, SplitAttributeSumTooSmall) {
|
|||
outputs.push_back({{1, 2}, {1.f, 2.f}});
|
||||
outputs.push_back({{2, 2}, {3.f, 4.f, 5.f, 6.f}});
|
||||
|
||||
RunTest(axis, splits, input, outputs, true, "Cannot split using values in 'split' attribute");
|
||||
RunTest<float>(axis, splits, input, outputs, true, "Cannot split using values in 'split' attribute");
|
||||
}
|
||||
|
||||
TEST(SplitOperatorTest, InvalidValueInSplitAttribute) {
|
||||
const int64_t axis = 0;
|
||||
std::vector<ShapeAndData> outputs;
|
||||
std::vector<ShapeAndFloatData> outputs;
|
||||
|
||||
// input shape and data
|
||||
ShapeAndData input = {{4, 2}, // shape
|
||||
ShapeAndFloatData input = {{4, 2}, // shape
|
||||
{1.f, 2.f,
|
||||
3.f, 4.f,
|
||||
5.f, 6.f,
|
||||
|
|
@ -317,7 +342,7 @@ TEST(SplitOperatorTest, InvalidValueInSplitAttribute) {
|
|||
outputs.push_back({{1, 2}, {1.f, 2.f}});
|
||||
outputs.push_back({{3, 2}, {3.f, 4.f, 5.f, 6.f, 7.f, 8.f}});
|
||||
|
||||
RunTest(axis, splits, input, outputs, true, "Invalid value in 'split' attribute");
|
||||
RunTest<float>(axis, splits, input, outputs, true, "Invalid value in 'split' attribute");
|
||||
}
|
||||
|
||||
/*
|
||||
|
|
|
|||
Loading…
Reference in a new issue