diff --git a/onnxruntime/core/framework/transpose_helper.cc b/onnxruntime/core/framework/transpose_helper.cc new file mode 100644 index 0000000000..a5535d919b --- /dev/null +++ b/onnxruntime/core/framework/transpose_helper.cc @@ -0,0 +1,313 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/framework/transpose_helper.h" +#include "core/mlas/inc/mlas.h" + +namespace onnxruntime { + +template +struct has_mlas_transpose : std::false_type {}; + +template <> +struct has_mlas_transpose : std::true_type {}; + +template <> +struct has_mlas_transpose : std::true_type {}; + +// moving a single axis outwards where the read/write size is a power of 2 and between 8 and 64 bits. +template +typename std::enable_if::value, void>::type SimpleTransposeSingleAxisOutwards( + const T* input_data, T* output_data, int64_t num_loops, int64_t num_writers, int64_t writes_per_loop, + int64_t writes_per_writer_per_loop) { + const T* end; + for (int64_t l = 0; l < num_loops; ++l) { + T* output_for_first_writer = output_data; + + for (auto wwpl = 0; wwpl < writes_per_writer_per_loop; ++wwpl) { + T* output_for_current_writer = output_for_first_writer; + + end = input_data + num_writers; + for (; input_data != end;) { + *output_for_current_writer = *input_data++; + + // skip to output position for next writer + output_for_current_writer += writes_per_writer_per_loop; + } + + ++output_for_first_writer; + } + + output_data += writes_per_loop; + } +} + +template +typename std::enable_if::value, void>::type SimpleTransposeSingleAxisOutwards( + const T* input_data, T* output_data, int64_t num_loops, int64_t num_writers, int64_t writes_per_loop, + int64_t writes_per_writer_per_loop) { + for (int64_t l = 0; l < num_loops; ++l) { + MlasTranspose(input_data, output_data, static_cast(writes_per_writer_per_loop), + static_cast(num_writers)); + input_data += writes_per_loop; + output_data += writes_per_loop; + } +} + +// `input_shape_override` overrides the shape of `input` for compute purposes. +void TransposeSingleAxisOutwards(gsl::span permutations, const Tensor& input, Tensor& output, + size_t from, size_t to, const TensorShape* input_shape_override = nullptr) { + ORT_UNUSED_PARAMETER(permutations); + + const auto& input_shape = input_shape_override ? *input_shape_override : input.Shape(); + const auto& input_dims = input_shape.GetDims(); + + const auto element_size = input.DataType()->Size(); + + const auto* input_data = reinterpret_cast(input.DataRaw()); + auto* output_data = reinterpret_cast(output.MutableDataRaw()); + + auto num_loops = input_shape.SizeToDimension(to); + auto num_writers = input_dims[from]; + auto block_size = input_shape.SizeFromDimension(from + 1); + auto writes_per_loop = int64_t(input_shape.Size() / num_loops / block_size); + auto writes_per_writer_per_loop = int64_t(writes_per_loop / num_writers); + // TODO: check integer overflow + const size_t bytes_per_write = static_cast(block_size) * element_size; + + switch (bytes_per_write) { + case (sizeof(uint8_t)): { + SimpleTransposeSingleAxisOutwards(input_data, output_data, num_loops, num_writers, writes_per_loop, + writes_per_writer_per_loop); + break; + } + case (sizeof(uint16_t)): { + SimpleTransposeSingleAxisOutwards(reinterpret_cast(input_data), + reinterpret_cast(output_data), num_loops, num_writers, + writes_per_loop, writes_per_writer_per_loop); + break; + } + case (sizeof(uint32_t)): { + SimpleTransposeSingleAxisOutwards(reinterpret_cast(input_data), + reinterpret_cast(output_data), num_loops, num_writers, + writes_per_loop, writes_per_writer_per_loop); + break; + } + case (sizeof(uint64_t)): { + SimpleTransposeSingleAxisOutwards(reinterpret_cast(input_data), + reinterpret_cast(output_data), num_loops, num_writers, + writes_per_loop, writes_per_writer_per_loop); + break; + } + default: { + // we need to use memcpy for each block + for (int64_t l = 0; l < num_loops; ++l) { + uint8_t* output_for_first_writer = output_data; + + for (auto wwpl = 0; wwpl < writes_per_writer_per_loop; ++wwpl) { + uint8_t* output_for_current_writer = output_for_first_writer; + + for (int64_t w = 0; w < num_writers; ++w) { + memcpy(output_for_current_writer, input_data, bytes_per_write); + // skip to output position for next writer + output_for_current_writer += (writes_per_writer_per_loop * bytes_per_write); + input_data += bytes_per_write; + } + + output_for_first_writer += bytes_per_write; + } + + output_data += writes_per_loop * bytes_per_write; + } + } + } +} + +template +typename std::enable_if::value, void>::type SimpleTransposeSingleAxisInwards( + const T* input_data, T* output_data, int64_t num_loops, int64_t num_readers, int64_t reads_per_loop, + int64_t reads_per_reader_per_loop) { + T* end; + for (int64_t l = 0; l < num_loops; ++l) { + const T* input_for_first_reader = input_data; + + for (auto rrpl = 0; rrpl < reads_per_reader_per_loop; ++rrpl) { + const T* input_for_current_reader = input_for_first_reader; + + end = output_data + num_readers; + for (; output_data != end;) { + *output_data++ = *input_for_current_reader; + // skip to input position for next reader + input_for_current_reader += reads_per_reader_per_loop; + } + + ++input_for_first_reader; + } + + input_data += reads_per_loop; + } +} + +template +typename std::enable_if::value, void>::type SimpleTransposeSingleAxisInwards( + const T* input_data, T* output_data, int64_t num_loops, int64_t num_readers, int64_t reads_per_loop, + int64_t reads_per_reader_per_loop) { + for (int64_t l = 0; l < num_loops; ++l) { + MlasTranspose(input_data, output_data, static_cast(num_readers), + static_cast(reads_per_reader_per_loop)); + input_data += reads_per_loop; + output_data += reads_per_loop; + } +} + +// moving a single axis inwards where the read/write size is a power of 2 and between 8 and 64 bits. +// `input_shape_override` overrides the shape of `input` for compute purposes. +void TransposeSingleAxisInwards(gsl::span permutations, const Tensor& input, Tensor& output, + size_t from, size_t to, const TensorShape* input_shape_override = nullptr) { + ORT_UNUSED_PARAMETER(permutations); + + const auto& input_shape = input_shape_override ? *input_shape_override : input.Shape(); + const auto& input_dims = input_shape.GetDims(); + + const auto element_size = input.DataType()->Size(); + + const auto* input_data = reinterpret_cast(input.DataRaw()); + auto* output_data = reinterpret_cast(output.MutableDataRaw()); + + auto num_loops = input_shape.SizeToDimension(from); + auto num_readers = input_dims[from]; + auto block_size = input_shape.SizeFromDimension(to + 1); + auto reads_per_loop = int64_t(input_shape.Size() / num_loops / block_size); + auto reads_per_reader_per_loop = int64_t(reads_per_loop / num_readers); + // TODO: check integer overflow + const size_t bytes_per_read = static_cast(block_size) * element_size; + + switch (bytes_per_read) { + case (sizeof(uint8_t)): { + SimpleTransposeSingleAxisInwards(input_data, output_data, num_loops, num_readers, reads_per_loop, + reads_per_reader_per_loop); + break; + } + case (sizeof(uint16_t)): { + SimpleTransposeSingleAxisInwards(reinterpret_cast(input_data), + reinterpret_cast(output_data), num_loops, num_readers, reads_per_loop, + reads_per_reader_per_loop); + break; + } + case (sizeof(uint32_t)): { + SimpleTransposeSingleAxisInwards(reinterpret_cast(input_data), + reinterpret_cast(output_data), num_loops, num_readers, reads_per_loop, + reads_per_reader_per_loop); + break; + } + case (sizeof(uint64_t)): { + SimpleTransposeSingleAxisInwards(reinterpret_cast(input_data), + reinterpret_cast(output_data), num_loops, num_readers, reads_per_loop, + reads_per_reader_per_loop); + break; + } + default: { + // we need to use memcpy for each block + for (int64_t l = 0; l < num_loops; ++l) { + const uint8_t* input_for_first_reader = input_data; + + for (auto rrpl = 0; rrpl < reads_per_reader_per_loop; ++rrpl) { + const uint8_t* input_for_current_reader = input_for_first_reader; + + for (int64_t r = 0; r < num_readers; ++r) { + memcpy(output_data, input_for_current_reader, bytes_per_read); + output_data += bytes_per_read; + + // skip to input position for next reader + input_for_current_reader += (reads_per_reader_per_loop * bytes_per_read); + } + + input_for_first_reader += bytes_per_read; + } + + input_data += reads_per_loop * bytes_per_read; + } + } + } +} + +// `input_shape_override` overrides the shape of `input` for compute purposes. +void SingleAxisTranspose(gsl::span permutations, const Tensor& input, Tensor& output, size_t from, + size_t to, const TensorShape* input_shape_override) { + if (from > to) { + TransposeSingleAxisOutwards(permutations, input, output, from, to, input_shape_override); + } else { + TransposeSingleAxisInwards(permutations, input, output, from, to, input_shape_override); + } +} + +bool IsTransposeMovingSingleAxis(gsl::span permutations, size_t& from, size_t& to) { + // if a single axis moved to an outer dimension, the values should be one lower than the index until the slot the + // axis was moved from, and equal to the index after that. + // e.g. axis 3 moves out to 1 would be: 0, 3, 1, 2, 4 + auto check_moved_outwards = [&permutations](size_t cur, size_t moved_from) { + // we start processing with the slot after the moved one, so the expected value is one less than the index + size_t expected = cur - 1; + for (size_t end = permutations.size(); cur < end; ++cur) { + if (permutations[cur] != expected) { + return false; + } + + // we are at the slot the axis moved from, so do an additional increment before checking the next value + if (cur == moved_from) { + ++expected; + } + + ++expected; + } + + return true; + }; + + // if a single axis moved to an inner dimension, the values should be one higher than the index until the slot the + // axis was moved to, and equal to the index after that. + // e.g. axis 1 moves inwards to 3 would be: 0, 2, 3, 1, 4 + auto check_moved_inwards = [&permutations](size_t cur, size_t& moved_to) { + size_t started_at = cur; + size_t expected = cur + 1; + moved_to = std::numeric_limits::max(); + + for (size_t end = permutations.size(); cur < end; ++cur) { + if (permutations[cur] != expected) { + // if a single axis moved it must have come from the location we started at + if (started_at != permutations[cur]) { + return false; + } + + moved_to = cur; + } else { + ++expected; + } + } + + return moved_to != std::numeric_limits::max(); + }; + + bool single_axis_moved = false; + // check axis moving outwards (earlier entry in permutations) + for (size_t i = 0, end = permutations.size(); i < end; ++i) { + size_t axis = permutations[i]; + + if (axis != i) { + if (check_moved_outwards(i + 1, axis)) { + single_axis_moved = true; + to = i; + from = axis; + } else if (check_moved_inwards(i, to)) { + single_axis_moved = true; + from = i; + } + + break; + } + } + + return single_axis_moved; +} + +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/framework/transpose_helper.h b/onnxruntime/core/framework/transpose_helper.h new file mode 100644 index 0000000000..99e3dd9a5a --- /dev/null +++ b/onnxruntime/core/framework/transpose_helper.h @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +/* +This file contains optimizations for moving a single axis either inwards or outwards. + +If moving outwards we can use a single reader and multiple writers. The number of writers is equal to the value of +the axis being moved. + + e.g. if the input is NHWC with shape {N, 300, 300, 3}, we can transpose to NCHW by reading once and having + one writer for each of the 3 channels at a different offset in the output, updating the offset for each item + in the batch of N. + +Similarly if one axis is moving inwards we can use a single writer and multiple readers. The number of readers is equal +to the value of the axis being moved. + + e.g. if the input is NCHW with shape {N, 3, 300, 300}, we can transpose to NHWC by writing once using one reader for + each of the 3 channels at a different offset in the input, updating the read offset for each item in the batch + of N. + +This can be generalized for any input where only one axis is being moved, with the block size for each read/write +being dependent on which axis is moving, what direction it's moving in, and where it's moving to. + +We use simple pointer arithmetic if the size of each read/write is a power of 2 and between 8 and 64 bits. +We use memcpy if the block size is larger. + +We fall back to the default implementation in all other cases, and if the input is std::string. +*/ + +#include + +#include "core/common/common.h" +#include "core/common/inlined_containers.h" +#include "core/framework/tensor_shape.h" +#include "core/framework/tensor.h" + +#include "gsl/gsl" + +namespace onnxruntime { +bool IsTransposeMovingSingleAxis(gsl::span permutations, size_t& from, size_t& to); +void SingleAxisTranspose(gsl::span permutations, const Tensor& input, Tensor& output, size_t from, + size_t to, const TensorShape* input_shape_override = nullptr); +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/cpu/tensor/transpose.cc b/onnxruntime/core/providers/cpu/tensor/transpose.cc index f7c956daf5..849a783bd2 100644 --- a/onnxruntime/core/providers/cpu/tensor/transpose.cc +++ b/onnxruntime/core/providers/cpu/tensor/transpose.cc @@ -5,6 +5,7 @@ #include "core/framework/element_type_lists.h" #include "core/framework/utils.h" +#include "core/framework/transpose_helper.h" #include "core/framework/op_kernel_type_control_utils.h" #include "core/mlas/inc/mlas.h" #include "core/providers/op_kernel_type_control.h" @@ -323,345 +324,6 @@ static Status DoUntypedTranspose(const gsl::span& permutations, co return status; } -/* -Optimizations for moving a single axis either inwards or outwards. - -If moving outwards we can use a single reader and multiple writers. The number of writers is equal to the value of -the axis being moved. - - e.g. if the input is NHWC with shape {N, 300, 300, 3}, we can transpose to NCHW by reading once and having - one writer for each of the 3 channels at a different offset in the output, updating the offset for each item - in the batch of N. - -Similarly if one axis is moving inwards we can use a single writer and multiple readers. The number of readers is equal -to the value of the axis being moved. - - e.g. if the input is NCHW with shape {N, 3, 300, 300}, we can transpose to NHWC by writing once using one reader for - each of the 3 channels at a different offset in the input, updating the read offset for each item in the batch - of N. - -This can be generalized for any input where only one axis is being moved, with the block size for each read/write -being dependent on which axis is moving, what direction it's moving in, and where it's moving to. - -We use simple pointer arithmetic if the size of each read/write is a power of 2 and between 8 and 64 bits. -We use memcpy if the block size is larger. - -We fall back to the default implementation in all other cases, and if the input is std::string. -*/ - -namespace { - -template -struct has_mlas_transpose : std::false_type {}; - -template <> -struct has_mlas_transpose : std::true_type {}; - -template <> -struct has_mlas_transpose : std::true_type {}; - -// moving a single axis outwards where the read/write size is a power of 2 and between 8 and 64 bits. -template -typename std::enable_if::value, void>::type -SimpleTransposeSingleAxisOutwards(const T* input_data, T* output_data, - int64_t num_loops, int64_t num_writers, - int64_t writes_per_loop, int64_t writes_per_writer_per_loop) { - const T* end; - for (int64_t l = 0; l < num_loops; ++l) { - T* output_for_first_writer = output_data; - - for (auto wwpl = 0; wwpl < writes_per_writer_per_loop; ++wwpl) { - T* output_for_current_writer = output_for_first_writer; - - end = input_data + num_writers; - for (; input_data != end;) { - *output_for_current_writer = *input_data++; - - // skip to output position for next writer - output_for_current_writer += writes_per_writer_per_loop; - } - - ++output_for_first_writer; - } - - output_data += writes_per_loop; - } -} - -template -typename std::enable_if::value, void>::type -SimpleTransposeSingleAxisOutwards(const T* input_data, T* output_data, - int64_t num_loops, int64_t num_writers, - int64_t writes_per_loop, int64_t writes_per_writer_per_loop) { - for (int64_t l = 0; l < num_loops; ++l) { - MlasTranspose(input_data, - output_data, - static_cast(writes_per_writer_per_loop), - static_cast(num_writers)); - input_data += writes_per_loop; - output_data += writes_per_loop; - } -} - -// `input_shape_override` overrides the shape of `input` for compute purposes. -void TransposeSingleAxisOutwards(const gsl::span& permutations, const Tensor& input, Tensor& output, - int64_t from, int64_t to, const TensorShape* input_shape_override = nullptr) { - ORT_UNUSED_PARAMETER(permutations); - - const auto& input_shape = input_shape_override ? *input_shape_override : input.Shape(); - const auto& input_dims = input_shape.GetDims(); - - const auto element_size = input.DataType()->Size(); - - const auto* input_data = reinterpret_cast(input.DataRaw()); - auto* output_data = reinterpret_cast(output.MutableDataRaw()); - - auto num_loops = input_shape.SizeToDimension(to); - auto num_writers = input_dims[from]; - auto block_size = input_shape.SizeFromDimension(from + 1); - auto writes_per_loop = int64_t(input_shape.Size() / num_loops / block_size); - auto writes_per_writer_per_loop = int64_t(writes_per_loop / num_writers); - const int64_t bytes_per_write = block_size * element_size; - - switch (bytes_per_write) { - case (sizeof(uint8_t)): { - SimpleTransposeSingleAxisOutwards(input_data, output_data, - num_loops, num_writers, writes_per_loop, writes_per_writer_per_loop); - break; - } - case (sizeof(uint16_t)): { - SimpleTransposeSingleAxisOutwards(reinterpret_cast(input_data), - reinterpret_cast(output_data), - num_loops, num_writers, writes_per_loop, writes_per_writer_per_loop); - break; - } - case (sizeof(uint32_t)): { - SimpleTransposeSingleAxisOutwards(reinterpret_cast(input_data), - reinterpret_cast(output_data), - num_loops, num_writers, writes_per_loop, writes_per_writer_per_loop); - break; - } - case (sizeof(uint64_t)): { - SimpleTransposeSingleAxisOutwards(reinterpret_cast(input_data), - reinterpret_cast(output_data), - num_loops, num_writers, writes_per_loop, writes_per_writer_per_loop); - break; - } - default: { - // we need to use memcpy for each block - for (int64_t l = 0; l < num_loops; ++l) { - uint8_t* output_for_first_writer = output_data; - - for (auto wwpl = 0; wwpl < writes_per_writer_per_loop; ++wwpl) { - uint8_t* output_for_current_writer = output_for_first_writer; - - for (int64_t w = 0; w < num_writers; ++w) { - memcpy(output_for_current_writer, input_data, bytes_per_write); - // skip to output position for next writer - output_for_current_writer += (writes_per_writer_per_loop * bytes_per_write); - input_data += bytes_per_write; - } - - output_for_first_writer += bytes_per_write; - } - - output_data += writes_per_loop * bytes_per_write; - } - } - } -} - -template -typename std::enable_if::value, void>::type -SimpleTransposeSingleAxisInwards(const T* input_data, T* output_data, - int64_t num_loops, int64_t num_readers, - int64_t reads_per_loop, int64_t reads_per_reader_per_loop) { - T* end; - for (int64_t l = 0; l < num_loops; ++l) { - const T* input_for_first_reader = input_data; - - for (auto rrpl = 0; rrpl < reads_per_reader_per_loop; ++rrpl) { - const T* input_for_current_reader = input_for_first_reader; - - end = output_data + num_readers; - for (; output_data != end;) { - *output_data++ = *input_for_current_reader; - // skip to input position for next reader - input_for_current_reader += reads_per_reader_per_loop; - } - - ++input_for_first_reader; - } - - input_data += reads_per_loop; - } -} - -template -typename std::enable_if::value, void>::type -SimpleTransposeSingleAxisInwards(const T* input_data, T* output_data, - int64_t num_loops, int64_t num_readers, - int64_t reads_per_loop, int64_t reads_per_reader_per_loop) { - for (int64_t l = 0; l < num_loops; ++l) { - MlasTranspose(input_data, - output_data, - static_cast(num_readers), - static_cast(reads_per_reader_per_loop)); - input_data += reads_per_loop; - output_data += reads_per_loop; - } -} - -// moving a single axis inwards where the read/write size is a power of 2 and between 8 and 64 bits. -// `input_shape_override` overrides the shape of `input` for compute purposes. -void TransposeSingleAxisInwards(const gsl::span& permutations, const Tensor& input, Tensor& output, - int64_t from, int64_t to, const TensorShape* input_shape_override = nullptr) { - ORT_UNUSED_PARAMETER(permutations); - - const auto& input_shape = input_shape_override ? *input_shape_override : input.Shape(); - const auto& input_dims = input_shape.GetDims(); - - const auto element_size = input.DataType()->Size(); - - const auto* input_data = reinterpret_cast(input.DataRaw()); - auto* output_data = reinterpret_cast(output.MutableDataRaw()); - - auto num_loops = input_shape.SizeToDimension(from); - auto num_readers = input_dims[from]; - auto block_size = input_shape.SizeFromDimension(to + 1); - auto reads_per_loop = int64_t(input_shape.Size() / num_loops / block_size); - auto reads_per_reader_per_loop = int64_t(reads_per_loop / num_readers); - const int64_t bytes_per_read = block_size * element_size; - - switch (bytes_per_read) { - case (sizeof(uint8_t)): { - SimpleTransposeSingleAxisInwards(input_data, output_data, - num_loops, num_readers, reads_per_loop, reads_per_reader_per_loop); - break; - } - case (sizeof(uint16_t)): { - SimpleTransposeSingleAxisInwards(reinterpret_cast(input_data), - reinterpret_cast(output_data), - num_loops, num_readers, reads_per_loop, reads_per_reader_per_loop); - break; - } - case (sizeof(uint32_t)): { - SimpleTransposeSingleAxisInwards(reinterpret_cast(input_data), - reinterpret_cast(output_data), - num_loops, num_readers, reads_per_loop, reads_per_reader_per_loop); - break; - } - case (sizeof(uint64_t)): { - SimpleTransposeSingleAxisInwards(reinterpret_cast(input_data), - reinterpret_cast(output_data), - num_loops, num_readers, reads_per_loop, reads_per_reader_per_loop); - break; - } - default: { - // we need to use memcpy for each block - for (int64_t l = 0; l < num_loops; ++l) { - const uint8_t* input_for_first_reader = input_data; - - for (auto rrpl = 0; rrpl < reads_per_reader_per_loop; ++rrpl) { - const uint8_t* input_for_current_reader = input_for_first_reader; - - for (int64_t r = 0; r < num_readers; ++r) { - memcpy(output_data, input_for_current_reader, bytes_per_read); - output_data += bytes_per_read; - - // skip to input position for next reader - input_for_current_reader += (reads_per_reader_per_loop * bytes_per_read); - } - - input_for_first_reader += bytes_per_read; - } - - input_data += reads_per_loop * bytes_per_read; - } - } - } -} - -// `input_shape_override` overrides the shape of `input` for compute purposes. -void SingleAxisTranspose(const gsl::span& permutations, const Tensor& input, Tensor& output, - size_t from, size_t to, const TensorShape* input_shape_override = nullptr) { - if (from > to) { - TransposeSingleAxisOutwards(permutations, input, output, from, to, input_shape_override); - } else { - TransposeSingleAxisInwards(permutations, input, output, from, to, input_shape_override); - } -} - -bool IsMovingSingleAxis(const gsl::span& permutations, size_t& from, size_t& to) { - // if a single axis moved to an outer dimension, the values should be one lower than the index until the slot the - // axis was moved from, and equal to the index after that. - // e.g. axis 3 moves out to 1 would be: 0, 3, 1, 2, 4 - auto check_moved_outwards = [&permutations](size_t cur, size_t moved_from) { - // we start processing with the slot after the moved one, so the expected value is one less than the index - size_t expected = cur - 1; - for (size_t end = permutations.size(); cur < end; ++cur) { - if (permutations[cur] != expected) { - return false; - } - - // we are at the slot the axis moved from, so do an additional increment before checking the next value - if (cur == moved_from) { - ++expected; - } - - ++expected; - } - - return true; - }; - - // if a single axis moved to an inner dimension, the values should be one higher than the index until the slot the - // axis was moved to, and equal to the index after that. - // e.g. axis 1 moves inwards to 3 would be: 0, 2, 3, 1, 4 - auto check_moved_inwards = [&permutations](size_t cur, size_t& moved_to) { - size_t started_at = cur; - size_t expected = cur + 1; - moved_to = std::numeric_limits::max(); - - for (size_t end = permutations.size(); cur < end; ++cur) { - if (permutations[cur] != expected) { - // if a single axis moved it must have come from the location we started at - if (started_at != permutations[cur]) { - return false; - } - - moved_to = cur; - } else { - ++expected; - } - } - - return moved_to != std::numeric_limits::max(); - }; - - bool single_axis_moved = false; - // check axis moving outwards (earlier entry in permutations) - for (size_t i = 0, end = permutations.size(); i < end; ++i) { - size_t axis = permutations[i]; - - if (axis != i) { - if (check_moved_outwards(i + 1, axis)) { - single_axis_moved = true; - to = i; - from = axis; - } else if (check_moved_inwards(i, to)) { - single_axis_moved = true; - from = i; - } - - break; - } - } - - return single_axis_moved; -} - -} // namespace bool IsTransposeReshape(const gsl::span& perm, gsl::span input_dims) { // As long as the dims with values > 1 stay in the same order, it's a reshape. @@ -698,7 +360,7 @@ Status TransposeBase::DoTranspose(const gsl::span& permutations, c } size_t from = 0, to = 0; - bool moving_single_axis = IsMovingSingleAxis(permutations, from, to); + bool moving_single_axis = IsTransposeMovingSingleAxis(permutations, from, to); if (moving_single_axis && !input.IsDataTypeString()) { SingleAxisTranspose(permutations, input, output, from, to, input_shape_override); @@ -740,7 +402,7 @@ Status Transpose::Compute(OpKernelContext* ctx) const { } size_t from = 0, to = 0; - bool moving_single_axis = IsMovingSingleAxis(*p_perm, from, to); + bool moving_single_axis = IsTransposeMovingSingleAxis(*p_perm, from, to); if (moving_single_axis && !X.IsDataTypeString()) { SingleAxisTranspose(*p_perm, X, Y, from, to); diff --git a/onnxruntime/test/framework/transpose_test.cc b/onnxruntime/test/framework/transpose_test.cc new file mode 100644 index 0000000000..bb4da421fd --- /dev/null +++ b/onnxruntime/test/framework/transpose_test.cc @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "core/framework/transpose_helper.h" + +namespace onnxruntime { +namespace test { + +TEST(TransposeIsMovingSingleAxis, test1) { + std::array perm{1, 2, 3, 0}; + size_t from = 0, to = 0; + ASSERT_TRUE(IsTransposeMovingSingleAxis(perm, from, to)); + ASSERT_EQ(from, static_cast(0)); + ASSERT_EQ(to, static_cast(3)); +} + +TEST(TransposeIsMovingSingleAxis, test2) { + std::array perm{0, 2, 3, 1}; + size_t from = 0, to = 0; + ASSERT_TRUE(IsTransposeMovingSingleAxis(perm, from, to)); + ASSERT_EQ(from, static_cast(1)); + ASSERT_EQ(to, static_cast(3)); +} + +TEST(TransposeIsMovingSingleAxis, test3) { + std::array perm{3, 1, 0, 2}; + size_t from = 0, to = 0; + ASSERT_FALSE(IsTransposeMovingSingleAxis(perm, from, to)); +} + +TEST(TransposeIsMovingSingleAxis, test4) { + std::array perm{0, 1, 2, 3}; + size_t from = 0, to = 0; + ASSERT_FALSE(IsTransposeMovingSingleAxis(perm, from, to)); +} +} // namespace test +} // namespace onnxruntime