Fix Seg fault when repeats input contain a 0 (#336)

* Fix Seg fault when repeats input contain a 0

* refine
This commit is contained in:
Ashwin Kumar 2019-01-15 21:34:04 -08:00 committed by GitHub
parent f678f58750
commit 95b8941e9d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 30 additions and 1 deletions

View file

@ -46,11 +46,20 @@ Status Tile<float>::Compute(OpKernelContext* ctx) const {
// Calculate the shape of the output tensor
auto* repeats = repeats_tensor.template Data<int64_t>();
std::vector<int64_t> output_dims = input_tensor.Shape().GetDims();
for (auto axis = 0; axis < input_tensor.Shape().NumDimensions(); axis++)
for (auto axis = 0; axis < input_tensor.Shape().NumDimensions(); axis++) {
output_dims[axis] *= repeats[axis];
}
TensorShape outputShape(output_dims);
auto& output_tensor = *ctx->Output(0, outputShape);
// Repeat tensor input can have 0 as a valid value
// check if the computed outputshape size is 0 and
// return an empty tensor if so.
if (outputShape.Size() == 0) {
return Status::OK();
}
auto* output = output_tensor.template MutableData<float>();
auto* input = input_tensor.template Data<float>();

View file

@ -7,6 +7,26 @@
namespace onnxruntime {
namespace test {
TEST(TensorOpTest, Tile1DWithZeroRepeats) {
OpTester test("Tile");
test.AddInput<float>("input", {3}, {1.0f, 2.0f, 3.0f});
test.AddInput<int64_t>("repeats", {1}, {0});
test.AddOutput<float>("output", {0}, {});
test.Run();
}
TEST(TensorOpTest, Tile2DWithZeroRepeats) {
OpTester test("Tile");
test.AddInput<float>("input", {2, 2},
{11.0f, 12.0f,
21.0f, 22.0f});
test.AddInput<int64_t>("repeats", {2}, {2, 0});
test.AddOutput<float>("output", {4, 0}, {});
test.Run();
}
TEST(TensorOpTest, Tile1D) {
OpTester test("Tile");