From 95b8941e9d534bdec2a445d7616e003ba3c953ee Mon Sep 17 00:00:00 2001 From: Ashwin Kumar <33531737+ashku-ms@users.noreply.github.com> Date: Tue, 15 Jan 2019 21:34:04 -0800 Subject: [PATCH] Fix Seg fault when repeats input contain a 0 (#336) * Fix Seg fault when repeats input contain a 0 * refine --- onnxruntime/core/providers/cpu/tensor/tile.cc | 11 +++++++++- .../test/providers/cpu/tensor/tile_op_test.cc | 20 +++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/cpu/tensor/tile.cc b/onnxruntime/core/providers/cpu/tensor/tile.cc index b478541b32..f4593d2a70 100644 --- a/onnxruntime/core/providers/cpu/tensor/tile.cc +++ b/onnxruntime/core/providers/cpu/tensor/tile.cc @@ -46,11 +46,20 @@ Status Tile::Compute(OpKernelContext* ctx) const { // Calculate the shape of the output tensor auto* repeats = repeats_tensor.template Data(); std::vector output_dims = input_tensor.Shape().GetDims(); - for (auto axis = 0; axis < input_tensor.Shape().NumDimensions(); axis++) + for (auto axis = 0; axis < input_tensor.Shape().NumDimensions(); axis++) { output_dims[axis] *= repeats[axis]; + } + TensorShape outputShape(output_dims); auto& output_tensor = *ctx->Output(0, outputShape); + // Repeat tensor input can have 0 as a valid value + // check if the computed outputshape size is 0 and + // return an empty tensor if so. + if (outputShape.Size() == 0) { + return Status::OK(); + } + auto* output = output_tensor.template MutableData(); auto* input = input_tensor.template Data(); diff --git a/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc b/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc index cd053f1921..09b0a0d6bf 100644 --- a/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc @@ -7,6 +7,26 @@ namespace onnxruntime { namespace test { +TEST(TensorOpTest, Tile1DWithZeroRepeats) { + OpTester test("Tile"); + + test.AddInput("input", {3}, {1.0f, 2.0f, 3.0f}); + test.AddInput("repeats", {1}, {0}); + test.AddOutput("output", {0}, {}); + test.Run(); +} + +TEST(TensorOpTest, Tile2DWithZeroRepeats) { + OpTester test("Tile"); + + test.AddInput("input", {2, 2}, + {11.0f, 12.0f, + 21.0f, 22.0f}); + test.AddInput("repeats", {2}, {2, 0}); + test.AddOutput("output", {4, 0}, {}); + test.Run(); +} + TEST(TensorOpTest, Tile1D) { OpTester test("Tile");