From c8f1da28c4975d19a89ac2fc45715bb4cd975f2d Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Mon, 25 Mar 2019 11:55:51 -0700 Subject: [PATCH] tile op: make implementation type-agnostic (and support a few more types) (#688) * Initial commit * PR feedback * PR feedback --- onnxruntime/core/providers/cpu/tensor/tile.cc | 123 +++++++++++---- onnxruntime/core/providers/cpu/tensor/tile.h | 1 - .../test/providers/cpu/tensor/tile_op_test.cc | 144 ++++++++++-------- 3 files changed, 173 insertions(+), 95 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/tile.cc b/onnxruntime/core/providers/cpu/tensor/tile.cc index f4593d2a70..833a2654d0 100644 --- a/onnxruntime/core/providers/cpu/tensor/tile.cc +++ b/onnxruntime/core/providers/cpu/tensor/tile.cc @@ -24,29 +24,82 @@ namespace onnxruntime { ONNX_CPU_OPERATOR_KERNEL( Tile, 6, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Tile); + KernelDefBuilder().TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + Tile); -template <> -Status Tile::Compute(OpKernelContext* ctx) const { +Status TileCoreForFixedSizeTypes(const Tensor& input_tensor, Tensor& output_tensor, const int64_t* repeats, TensorAxisCounters& input_counters, const TensorPitches& output_pitches, size_t element_size) { + const auto& input_shape = input_tensor.Shape().GetDims(); + const size_t dimension_count = input_shape.size(); + + const uint8_t* input = reinterpret_cast(input_tensor.DataRaw()); + uint8_t* output = reinterpret_cast(output_tensor.MutableDataRaw()); + + // some helper variables that will be used along the way + size_t block_size = 0; + int64_t num_repeats = 0; + const uint8_t* copy = nullptr; + const int64_t innermost_dim = input_shape[dimension_count - 1]; + + while (input_counters) { + // Copy the input data over + block_size = innermost_dim * element_size; + memcpy(output, input, block_size); + output += block_size; + input += block_size; + + // Tile data for the innermost axis + copy = output - block_size; + num_repeats = repeats[dimension_count - 1] - 1; + for (int64_t repeat = 0; repeat < num_repeats; ++repeat) { + memcpy(output, copy, block_size); + output += block_size; + } + + // Tile data for other axes + while (input_counters.Increment()) { + ptrdiff_t pitch = output_pitches[input_counters.Axis()] * input_shape[input_counters.Axis()]; + block_size = pitch * element_size; + copy = output - block_size; + num_repeats = repeats[input_counters.Axis()] - 1; + for (int64_t repeat = 0; repeat < num_repeats; ++repeat) { + memcpy(output, copy, block_size); + output += block_size; + } + } + } + return Status::OK(); +} + +Status Tile::Compute(OpKernelContext* ctx) const { const Tensor* tensor_pointer = ctx->Input(0); if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "Input count of Tile OP mismatch, the first one is empty"); const Tensor& input_tensor = *tensor_pointer; + const auto& input_shape = input_tensor.Shape(); + const size_t input_rank = input_shape.NumDimensions(); tensor_pointer = ctx->Input(1); if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "Input count of Tile OP mismatch, the second one is empty"); const Tensor& repeats_tensor = *tensor_pointer; - - size_t dimension_count = input_tensor.Shape().NumDimensions(); - + if (input_rank < 1) + return Status(ONNXRUNTIME, INVALID_ARGUMENT, "the tensor to be tiled using Tile OP must be atleast 1 dimensional"); if (repeats_tensor.Shape().NumDimensions() != 1) return Status(ONNXRUNTIME, INVALID_ARGUMENT, "'repeat' input tensor must be 1 dimensional"); - if (size_t(repeats_tensor.Shape().Size()) != input_tensor.Shape().NumDimensions()) + if (size_t(repeats_tensor.Shape().Size()) != input_rank) return Status(ONNXRUNTIME, INVALID_ARGUMENT, "'repeat' input tensor must have the same length as the 'input' tensor"); // Calculate the shape of the output tensor auto* repeats = repeats_tensor.template Data(); - std::vector output_dims = input_tensor.Shape().GetDims(); - for (auto axis = 0; axis < input_tensor.Shape().NumDimensions(); axis++) { + std::vector output_dims = input_shape.GetDims(); + for (auto axis = 0; axis < input_rank; axis++) { output_dims[axis] *= repeats[axis]; } @@ -60,32 +113,36 @@ Status Tile::Compute(OpKernelContext* ctx) const { return Status::OK(); } - auto* output = output_tensor.template MutableData(); - auto* input = input_tensor.template Data(); - - TensorPitches output_pitches(output_tensor); + const auto& dtype = input_tensor.DataType(); TensorAxisCounters input_counters(input_tensor); + TensorPitches output_pitches(output_tensor); - while (input_counters) { - // Copy the input data over - size_t input_pitch = input_tensor.Shape().GetDims().back(); - for (size_t i = 0; i < input_pitch; i++) - *output++ = *input++; + static_assert(sizeof(float) == sizeof(int32_t), "Float and Int32 are of different sizes"); + static_assert(sizeof(double) == sizeof(int64_t), "Double and Int64 are of different sizes"); - // Tile it for the innermost axis - const auto* copy = output - input_tensor.Shape()[dimension_count - 1]; - for (int64_t repeat = (repeats[dimension_count - 1] - 1) * input_pitch; repeat-- > 0;) - *output++ = *copy++; + if (dtype == DataTypeImpl::GetType() || + dtype == DataTypeImpl::GetType() || + dtype == DataTypeImpl::GetType()) + return TileCoreForFixedSizeTypes(input_tensor, output_tensor, repeats, input_counters, output_pitches, sizeof(float)); - // Tile it in the other axes - while (input_counters.Increment()) { - ptrdiff_t pitch = output_pitches[input_counters.Axis()] * input_tensor.Shape()[input_counters.Axis()]; - copy = output - pitch; - for (int64_t repeat = (repeats[input_counters.Axis()] - 1) * pitch; repeat-- > 0;) { - *output++ = *copy++; - } - } - } - return Status::OK(); + else if (dtype == DataTypeImpl::GetType() || + dtype == DataTypeImpl::GetType() || + dtype == DataTypeImpl::GetType()) + return TileCoreForFixedSizeTypes(input_tensor, output_tensor, repeats, input_counters, output_pitches, sizeof(double)); + + else if (dtype == DataTypeImpl::GetType() || + dtype == DataTypeImpl::GetType()) + return TileCoreForFixedSizeTypes(input_tensor, output_tensor, repeats, input_counters, output_pitches, sizeof(int8_t)); + + else if (dtype == DataTypeImpl::GetType() || + dtype == DataTypeImpl::GetType()) + return TileCoreForFixedSizeTypes(input_tensor, output_tensor, repeats, input_counters, output_pitches, sizeof(int16_t)); + + else if (dtype == DataTypeImpl::GetType()) + return TileCoreForFixedSizeTypes(input_tensor, output_tensor, repeats, input_counters, output_pitches, sizeof(bool)); + + // TODO: Support 'string' and 'float16' types for completeness + else + ORT_THROW("Tile doesn't have an implementation yet for the type: ", dtype); } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/tensor/tile.h b/onnxruntime/core/providers/cpu/tensor/tile.h index d7cf00097b..e97afb0714 100644 --- a/onnxruntime/core/providers/cpu/tensor/tile.h +++ b/onnxruntime/core/providers/cpu/tensor/tile.h @@ -7,7 +7,6 @@ namespace onnxruntime { -template struct Tile final : OpKernel { Tile(const OpKernelInfo& info) : OpKernel(info) { } diff --git a/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc b/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc index 09b0a0d6bf..a29dab9dd0 100644 --- a/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/tile_op_test.cc @@ -7,82 +7,104 @@ namespace onnxruntime { namespace test { -TEST(TensorOpTest, Tile1DWithZeroRepeats) { +template +void RunTest(std::initializer_list input, + std::initializer_list input_dims, + std::initializer_list repeat, + std::initializer_list repeat_dims, + std::initializer_list output, + std::initializer_list output_dims) { OpTester test("Tile"); - - test.AddInput("input", {3}, {1.0f, 2.0f, 3.0f}); - test.AddInput("repeats", {1}, {0}); - test.AddOutput("output", {0}, {}); + test.AddInput("input", input_dims, input); + test.AddInput("repeats", repeat_dims, repeat); + test.AddOutput("output", output_dims, output); test.Run(); } -TEST(TensorOpTest, Tile2DWithZeroRepeats) { - OpTester test("Tile"); +template +void RunTestWrapper() { + // Tile1DWithZeroRepeats + RunTest({1, 2, 3}, {3}, {0}, {1}, {}, {0}); - test.AddInput("input", {2, 2}, - {11.0f, 12.0f, - 21.0f, 22.0f}); - test.AddInput("repeats", {2}, {2, 0}); - test.AddOutput("output", {4, 0}, {}); - test.Run(); + // Tile2DWithZeroRepeats + RunTest({11, 12, 21, 22}, {2, 2}, {2, 0}, {2}, {}, {4, 0}); + + // Tile1D + RunTest({1, 2, 3}, {3}, {3}, {1}, {1, 2, 3, 1, 2, 3, 1, 2, 3}, {9}); + + // Tile2D_1Axis + RunTest({11, 12, 21, 22}, {2, 2}, {2, 1}, {2}, {11, 12, 21, 22, 11, 12, 21, 22}, {4, 2}); + + // Tile2D_2Axes + RunTest({11, 12, 21, 22}, {2, 2}, {2, 2}, {2}, {11, 12, 11, 12, 21, 22, 21, 22, 11, 12, 11, 12, 21, 22, 21, 22}, {4, 4}); + + // Tile3D + RunTest({111, 112, 113, 122, 123, 124}, {2, 1, 3}, {1, 2, 1}, {3}, {111, 112, 113, 111, 112, 113, 122, 123, 124, 122, 123, 124}, {2, 2, 3}); } -TEST(TensorOpTest, Tile1D) { - OpTester test("Tile"); +template <> +void RunTestWrapper() { + // Tile1DWithZeroRepeats + RunTest({true, false, true}, {3}, {0}, {1}, {}, {0}); - test.AddInput("input", {3}, {1.0f, 2.0f, 3.0f}); - test.AddInput("repeats", {1}, {3}); - test.AddOutput("output", {9}, {1.0f, 2.0f, 3.0f, 1.0f, 2.0f, 3.0f, 1.0f, 2.0f, 3.0f}); - test.Run(); + // Tile2DWithZeroRepeats + RunTest({true, false, true, false}, {2, 2}, {2, 0}, {2}, {}, {4, 0}); + + // Tile1D + RunTest({true, false, true}, {3}, {3}, {1}, {true, false, true, true, false, true, true, false, true}, {9}); + + // Tile2D_1Axis + RunTest({true, false, true, false}, {2, 2}, {2, 1}, {2}, {true, false, true, false, true, false, true, false}, {4, 2}); + + // Tile2D_2Axes + RunTest({true, false, true, false}, {2, 2}, {2, 2}, {2}, {true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false}, {4, 4}); + + // Tile3D + RunTest({true, false, true, false, true, false}, {2, 1, 3}, {1, 2, 1}, {3}, {true, false, true, true, false, true, false, true, false, false, true, false}, {2, 2, 3}); } -TEST(TensorOpTest, Tile2D_1Axis) { - OpTester test("Tile"); - - test.AddInput("input", {2, 2}, - {11.0f, 12.0f, - 21.0f, 22.0f}); - test.AddInput("repeats", {2}, {2, 1}); - test.AddOutput("output", {4, 2}, - {11.0f, 12.0f, - 21.0f, 22.0f, - 11.0f, 12.0f, - 21.0f, 22.0f}); - - test.Run(); +TEST(TensorOpTest, TileFloatType) { + RunTestWrapper(); } -TEST(TensorOpTest, Tile2D_2Axes) { - OpTester test("Tile"); - - test.AddInput("input", {2, 2}, - {11.0f, 12.0f, - 21.0f, 22.0f}); - test.AddInput("repeats", {2}, {2, 2}); - test.AddOutput("output", {4, 4}, - {11.0f, 12.0f, 11.0f, 12.0f, - 21.0f, 22.0f, 21.0f, 22.0f, - 11.0f, 12.0f, 11.0f, 12.0f, - 21.0f, 22.0f, 21.0f, 22.0f}); - - test.Run(); +TEST(TensorOpTest, TileDoubleType) { + RunTestWrapper(); } -TEST(TensorOpTest, Tile3D) { - OpTester test("Tile"); - - test.AddInput("input", {2, 1, 3}, - {111.0f, 112.0f, 113.0f, - 211.0f, 212.0f, 213.0f}); - test.AddInput("repeats", {3}, {1, 2, 1}); - test.AddOutput("output", {2, 2, 3}, - {111.0f, 112.0f, 113.0f, - 111.0f, 112.0f, 113.0f, - - 211.0f, 212.0f, 213.0f, - 211.0f, 212.0f, 213.0f}); - test.Run(); +TEST(TensorOpTest, TileInt8Type) { + RunTestWrapper(); } +TEST(TensorOpTest, TileUint8Type) { + RunTestWrapper(); +} + +TEST(TensorOpTest, TileInt16Type) { + RunTestWrapper(); +} + +TEST(TensorOpTest, TileUint16Type) { + RunTestWrapper(); +} + +TEST(TensorOpTest, TileInt32Type) { + RunTestWrapper(); +} + +TEST(TensorOpTest, TileUint32Type) { + RunTestWrapper(); +} + +TEST(TensorOpTest, TileInt64Type) { + RunTestWrapper(); +} + +TEST(TensorOpTest, TileUint64Type) { + RunTestWrapper(); +} + +TEST(TensorOpTest, TileBoolType) { + RunTestWrapper(); +} } // namespace test } // namespace onnxruntime