Move some of the transpose kernel code to onnxruntime_framework.lib (#11380)

* Move some of the tranpose kernel code to onnxruntime_framework.lib

* Fix C4244 warnings in the tranpose code

* Rename IsMovingSingleAxis to IsTransposeMovingSingleAxis
This commit is contained in:
Changming Sun 2022-05-03 14:03:50 -07:00 committed by GitHub
parent 308b605047
commit 253c8b41ed
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 399 additions and 341 deletions

View file

@ -0,0 +1,313 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/framework/transpose_helper.h"
#include "core/mlas/inc/mlas.h"
namespace onnxruntime {
template <typename T>
struct has_mlas_transpose : std::false_type {};
template <>
struct has_mlas_transpose<uint8_t> : std::true_type {};
template <>
struct has_mlas_transpose<uint32_t> : std::true_type {};
// moving a single axis outwards where the read/write size is a power of 2 and between 8 and 64 bits.
template <typename T>
typename std::enable_if<!has_mlas_transpose<T>::value, void>::type SimpleTransposeSingleAxisOutwards(
const T* input_data, T* output_data, int64_t num_loops, int64_t num_writers, int64_t writes_per_loop,
int64_t writes_per_writer_per_loop) {
const T* end;
for (int64_t l = 0; l < num_loops; ++l) {
T* output_for_first_writer = output_data;
for (auto wwpl = 0; wwpl < writes_per_writer_per_loop; ++wwpl) {
T* output_for_current_writer = output_for_first_writer;
end = input_data + num_writers;
for (; input_data != end;) {
*output_for_current_writer = *input_data++;
// skip to output position for next writer
output_for_current_writer += writes_per_writer_per_loop;
}
++output_for_first_writer;
}
output_data += writes_per_loop;
}
}
template <typename T>
typename std::enable_if<has_mlas_transpose<T>::value, void>::type SimpleTransposeSingleAxisOutwards(
const T* input_data, T* output_data, int64_t num_loops, int64_t num_writers, int64_t writes_per_loop,
int64_t writes_per_writer_per_loop) {
for (int64_t l = 0; l < num_loops; ++l) {
MlasTranspose(input_data, output_data, static_cast<size_t>(writes_per_writer_per_loop),
static_cast<size_t>(num_writers));
input_data += writes_per_loop;
output_data += writes_per_loop;
}
}
// `input_shape_override` overrides the shape of `input` for compute purposes.
void TransposeSingleAxisOutwards(gsl::span<const size_t> permutations, const Tensor& input, Tensor& output,
size_t from, size_t to, const TensorShape* input_shape_override = nullptr) {
ORT_UNUSED_PARAMETER(permutations);
const auto& input_shape = input_shape_override ? *input_shape_override : input.Shape();
const auto& input_dims = input_shape.GetDims();
const auto element_size = input.DataType()->Size();
const auto* input_data = reinterpret_cast<const uint8_t*>(input.DataRaw());
auto* output_data = reinterpret_cast<uint8_t*>(output.MutableDataRaw());
auto num_loops = input_shape.SizeToDimension(to);
auto num_writers = input_dims[from];
auto block_size = input_shape.SizeFromDimension(from + 1);
auto writes_per_loop = int64_t(input_shape.Size() / num_loops / block_size);
auto writes_per_writer_per_loop = int64_t(writes_per_loop / num_writers);
// TODO: check integer overflow
const size_t bytes_per_write = static_cast<size_t>(block_size) * element_size;
switch (bytes_per_write) {
case (sizeof(uint8_t)): {
SimpleTransposeSingleAxisOutwards(input_data, output_data, num_loops, num_writers, writes_per_loop,
writes_per_writer_per_loop);
break;
}
case (sizeof(uint16_t)): {
SimpleTransposeSingleAxisOutwards(reinterpret_cast<const uint16_t*>(input_data),
reinterpret_cast<uint16_t*>(output_data), num_loops, num_writers,
writes_per_loop, writes_per_writer_per_loop);
break;
}
case (sizeof(uint32_t)): {
SimpleTransposeSingleAxisOutwards(reinterpret_cast<const uint32_t*>(input_data),
reinterpret_cast<uint32_t*>(output_data), num_loops, num_writers,
writes_per_loop, writes_per_writer_per_loop);
break;
}
case (sizeof(uint64_t)): {
SimpleTransposeSingleAxisOutwards(reinterpret_cast<const uint64_t*>(input_data),
reinterpret_cast<uint64_t*>(output_data), num_loops, num_writers,
writes_per_loop, writes_per_writer_per_loop);
break;
}
default: {
// we need to use memcpy for each block
for (int64_t l = 0; l < num_loops; ++l) {
uint8_t* output_for_first_writer = output_data;
for (auto wwpl = 0; wwpl < writes_per_writer_per_loop; ++wwpl) {
uint8_t* output_for_current_writer = output_for_first_writer;
for (int64_t w = 0; w < num_writers; ++w) {
memcpy(output_for_current_writer, input_data, bytes_per_write);
// skip to output position for next writer
output_for_current_writer += (writes_per_writer_per_loop * bytes_per_write);
input_data += bytes_per_write;
}
output_for_first_writer += bytes_per_write;
}
output_data += writes_per_loop * bytes_per_write;
}
}
}
}
template <typename T>
typename std::enable_if<!has_mlas_transpose<T>::value, void>::type SimpleTransposeSingleAxisInwards(
const T* input_data, T* output_data, int64_t num_loops, int64_t num_readers, int64_t reads_per_loop,
int64_t reads_per_reader_per_loop) {
T* end;
for (int64_t l = 0; l < num_loops; ++l) {
const T* input_for_first_reader = input_data;
for (auto rrpl = 0; rrpl < reads_per_reader_per_loop; ++rrpl) {
const T* input_for_current_reader = input_for_first_reader;
end = output_data + num_readers;
for (; output_data != end;) {
*output_data++ = *input_for_current_reader;
// skip to input position for next reader
input_for_current_reader += reads_per_reader_per_loop;
}
++input_for_first_reader;
}
input_data += reads_per_loop;
}
}
template <typename T>
typename std::enable_if<has_mlas_transpose<T>::value, void>::type SimpleTransposeSingleAxisInwards(
const T* input_data, T* output_data, int64_t num_loops, int64_t num_readers, int64_t reads_per_loop,
int64_t reads_per_reader_per_loop) {
for (int64_t l = 0; l < num_loops; ++l) {
MlasTranspose(input_data, output_data, static_cast<size_t>(num_readers),
static_cast<size_t>(reads_per_reader_per_loop));
input_data += reads_per_loop;
output_data += reads_per_loop;
}
}
// moving a single axis inwards where the read/write size is a power of 2 and between 8 and 64 bits.
// `input_shape_override` overrides the shape of `input` for compute purposes.
void TransposeSingleAxisInwards(gsl::span<const size_t> permutations, const Tensor& input, Tensor& output,
size_t from, size_t to, const TensorShape* input_shape_override = nullptr) {
ORT_UNUSED_PARAMETER(permutations);
const auto& input_shape = input_shape_override ? *input_shape_override : input.Shape();
const auto& input_dims = input_shape.GetDims();
const auto element_size = input.DataType()->Size();
const auto* input_data = reinterpret_cast<const uint8_t*>(input.DataRaw());
auto* output_data = reinterpret_cast<uint8_t*>(output.MutableDataRaw());
auto num_loops = input_shape.SizeToDimension(from);
auto num_readers = input_dims[from];
auto block_size = input_shape.SizeFromDimension(to + 1);
auto reads_per_loop = int64_t(input_shape.Size() / num_loops / block_size);
auto reads_per_reader_per_loop = int64_t(reads_per_loop / num_readers);
// TODO: check integer overflow
const size_t bytes_per_read = static_cast<size_t>(block_size) * element_size;
switch (bytes_per_read) {
case (sizeof(uint8_t)): {
SimpleTransposeSingleAxisInwards(input_data, output_data, num_loops, num_readers, reads_per_loop,
reads_per_reader_per_loop);
break;
}
case (sizeof(uint16_t)): {
SimpleTransposeSingleAxisInwards(reinterpret_cast<const uint16_t*>(input_data),
reinterpret_cast<uint16_t*>(output_data), num_loops, num_readers, reads_per_loop,
reads_per_reader_per_loop);
break;
}
case (sizeof(uint32_t)): {
SimpleTransposeSingleAxisInwards(reinterpret_cast<const uint32_t*>(input_data),
reinterpret_cast<uint32_t*>(output_data), num_loops, num_readers, reads_per_loop,
reads_per_reader_per_loop);
break;
}
case (sizeof(uint64_t)): {
SimpleTransposeSingleAxisInwards(reinterpret_cast<const uint64_t*>(input_data),
reinterpret_cast<uint64_t*>(output_data), num_loops, num_readers, reads_per_loop,
reads_per_reader_per_loop);
break;
}
default: {
// we need to use memcpy for each block
for (int64_t l = 0; l < num_loops; ++l) {
const uint8_t* input_for_first_reader = input_data;
for (auto rrpl = 0; rrpl < reads_per_reader_per_loop; ++rrpl) {
const uint8_t* input_for_current_reader = input_for_first_reader;
for (int64_t r = 0; r < num_readers; ++r) {
memcpy(output_data, input_for_current_reader, bytes_per_read);
output_data += bytes_per_read;
// skip to input position for next reader
input_for_current_reader += (reads_per_reader_per_loop * bytes_per_read);
}
input_for_first_reader += bytes_per_read;
}
input_data += reads_per_loop * bytes_per_read;
}
}
}
}
// `input_shape_override` overrides the shape of `input` for compute purposes.
void SingleAxisTranspose(gsl::span<const size_t> permutations, const Tensor& input, Tensor& output, size_t from,
size_t to, const TensorShape* input_shape_override) {
if (from > to) {
TransposeSingleAxisOutwards(permutations, input, output, from, to, input_shape_override);
} else {
TransposeSingleAxisInwards(permutations, input, output, from, to, input_shape_override);
}
}
bool IsTransposeMovingSingleAxis(gsl::span<const size_t> permutations, size_t& from, size_t& to) {
// if a single axis moved to an outer dimension, the values should be one lower than the index until the slot the
// axis was moved from, and equal to the index after that.
// e.g. axis 3 moves out to 1 would be: 0, 3, 1, 2, 4
auto check_moved_outwards = [&permutations](size_t cur, size_t moved_from) {
// we start processing with the slot after the moved one, so the expected value is one less than the index
size_t expected = cur - 1;
for (size_t end = permutations.size(); cur < end; ++cur) {
if (permutations[cur] != expected) {
return false;
}
// we are at the slot the axis moved from, so do an additional increment before checking the next value
if (cur == moved_from) {
++expected;
}
++expected;
}
return true;
};
// if a single axis moved to an inner dimension, the values should be one higher than the index until the slot the
// axis was moved to, and equal to the index after that.
// e.g. axis 1 moves inwards to 3 would be: 0, 2, 3, 1, 4
auto check_moved_inwards = [&permutations](size_t cur, size_t& moved_to) {
size_t started_at = cur;
size_t expected = cur + 1;
moved_to = std::numeric_limits<size_t>::max();
for (size_t end = permutations.size(); cur < end; ++cur) {
if (permutations[cur] != expected) {
// if a single axis moved it must have come from the location we started at
if (started_at != permutations[cur]) {
return false;
}
moved_to = cur;
} else {
++expected;
}
}
return moved_to != std::numeric_limits<size_t>::max();
};
bool single_axis_moved = false;
// check axis moving outwards (earlier entry in permutations)
for (size_t i = 0, end = permutations.size(); i < end; ++i) {
size_t axis = permutations[i];
if (axis != i) {
if (check_moved_outwards(i + 1, axis)) {
single_axis_moved = true;
to = i;
from = axis;
} else if (check_moved_inwards(i, to)) {
single_axis_moved = true;
from = i;
}
break;
}
}
return single_axis_moved;
}
} // namespace onnxruntime

View file

@ -0,0 +1,45 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
/*
This file contains optimizations for moving a single axis either inwards or outwards.
If moving outwards we can use a single reader and multiple writers. The number of writers is equal to the value of
the axis being moved.
e.g. if the input is NHWC with shape {N, 300, 300, 3}, we can transpose to NCHW by reading once and having
one writer for each of the 3 channels at a different offset in the output, updating the offset for each item
in the batch of N.
Similarly if one axis is moving inwards we can use a single writer and multiple readers. The number of readers is equal
to the value of the axis being moved.
e.g. if the input is NCHW with shape {N, 3, 300, 300}, we can transpose to NHWC by writing once using one reader for
each of the 3 channels at a different offset in the input, updating the read offset for each item in the batch
of N.
This can be generalized for any input where only one axis is being moved, with the block size for each read/write
being dependent on which axis is moving, what direction it's moving in, and where it's moving to.
We use simple pointer arithmetic if the size of each read/write is a power of 2 and between 8 and 64 bits.
We use memcpy if the block size is larger.
We fall back to the default implementation in all other cases, and if the input is std::string.
*/
#include <sstream>
#include "core/common/common.h"
#include "core/common/inlined_containers.h"
#include "core/framework/tensor_shape.h"
#include "core/framework/tensor.h"
#include "gsl/gsl"
namespace onnxruntime {
bool IsTransposeMovingSingleAxis(gsl::span<const size_t> permutations, size_t& from, size_t& to);
void SingleAxisTranspose(gsl::span<const size_t> permutations, const Tensor& input, Tensor& output, size_t from,
size_t to, const TensorShape* input_shape_override = nullptr);
} // namespace onnxruntime

View file

@ -5,6 +5,7 @@
#include "core/framework/element_type_lists.h"
#include "core/framework/utils.h"
#include "core/framework/transpose_helper.h"
#include "core/framework/op_kernel_type_control_utils.h"
#include "core/mlas/inc/mlas.h"
#include "core/providers/op_kernel_type_control.h"
@ -323,345 +324,6 @@ static Status DoUntypedTranspose(const gsl::span<const size_t>& permutations, co
return status;
}
/*
Optimizations for moving a single axis either inwards or outwards.
If moving outwards we can use a single reader and multiple writers. The number of writers is equal to the value of
the axis being moved.
e.g. if the input is NHWC with shape {N, 300, 300, 3}, we can transpose to NCHW by reading once and having
one writer for each of the 3 channels at a different offset in the output, updating the offset for each item
in the batch of N.
Similarly if one axis is moving inwards we can use a single writer and multiple readers. The number of readers is equal
to the value of the axis being moved.
e.g. if the input is NCHW with shape {N, 3, 300, 300}, we can transpose to NHWC by writing once using one reader for
each of the 3 channels at a different offset in the input, updating the read offset for each item in the batch
of N.
This can be generalized for any input where only one axis is being moved, with the block size for each read/write
being dependent on which axis is moving, what direction it's moving in, and where it's moving to.
We use simple pointer arithmetic if the size of each read/write is a power of 2 and between 8 and 64 bits.
We use memcpy if the block size is larger.
We fall back to the default implementation in all other cases, and if the input is std::string.
*/
namespace {
template <typename T>
struct has_mlas_transpose : std::false_type {};
template <>
struct has_mlas_transpose<uint8_t> : std::true_type {};
template <>
struct has_mlas_transpose<uint32_t> : std::true_type {};
// moving a single axis outwards where the read/write size is a power of 2 and between 8 and 64 bits.
template <typename T>
typename std::enable_if<!has_mlas_transpose<T>::value, void>::type
SimpleTransposeSingleAxisOutwards(const T* input_data, T* output_data,
int64_t num_loops, int64_t num_writers,
int64_t writes_per_loop, int64_t writes_per_writer_per_loop) {
const T* end;
for (int64_t l = 0; l < num_loops; ++l) {
T* output_for_first_writer = output_data;
for (auto wwpl = 0; wwpl < writes_per_writer_per_loop; ++wwpl) {
T* output_for_current_writer = output_for_first_writer;
end = input_data + num_writers;
for (; input_data != end;) {
*output_for_current_writer = *input_data++;
// skip to output position for next writer
output_for_current_writer += writes_per_writer_per_loop;
}
++output_for_first_writer;
}
output_data += writes_per_loop;
}
}
template <typename T>
typename std::enable_if<has_mlas_transpose<T>::value, void>::type
SimpleTransposeSingleAxisOutwards(const T* input_data, T* output_data,
int64_t num_loops, int64_t num_writers,
int64_t writes_per_loop, int64_t writes_per_writer_per_loop) {
for (int64_t l = 0; l < num_loops; ++l) {
MlasTranspose(input_data,
output_data,
static_cast<size_t>(writes_per_writer_per_loop),
static_cast<size_t>(num_writers));
input_data += writes_per_loop;
output_data += writes_per_loop;
}
}
// `input_shape_override` overrides the shape of `input` for compute purposes.
void TransposeSingleAxisOutwards(const gsl::span<const size_t>& permutations, const Tensor& input, Tensor& output,
int64_t from, int64_t to, const TensorShape* input_shape_override = nullptr) {
ORT_UNUSED_PARAMETER(permutations);
const auto& input_shape = input_shape_override ? *input_shape_override : input.Shape();
const auto& input_dims = input_shape.GetDims();
const auto element_size = input.DataType()->Size();
const auto* input_data = reinterpret_cast<const uint8_t*>(input.DataRaw());
auto* output_data = reinterpret_cast<uint8_t*>(output.MutableDataRaw());
auto num_loops = input_shape.SizeToDimension(to);
auto num_writers = input_dims[from];
auto block_size = input_shape.SizeFromDimension(from + 1);
auto writes_per_loop = int64_t(input_shape.Size() / num_loops / block_size);
auto writes_per_writer_per_loop = int64_t(writes_per_loop / num_writers);
const int64_t bytes_per_write = block_size * element_size;
switch (bytes_per_write) {
case (sizeof(uint8_t)): {
SimpleTransposeSingleAxisOutwards(input_data, output_data,
num_loops, num_writers, writes_per_loop, writes_per_writer_per_loop);
break;
}
case (sizeof(uint16_t)): {
SimpleTransposeSingleAxisOutwards(reinterpret_cast<const uint16_t*>(input_data),
reinterpret_cast<uint16_t*>(output_data),
num_loops, num_writers, writes_per_loop, writes_per_writer_per_loop);
break;
}
case (sizeof(uint32_t)): {
SimpleTransposeSingleAxisOutwards(reinterpret_cast<const uint32_t*>(input_data),
reinterpret_cast<uint32_t*>(output_data),
num_loops, num_writers, writes_per_loop, writes_per_writer_per_loop);
break;
}
case (sizeof(uint64_t)): {
SimpleTransposeSingleAxisOutwards(reinterpret_cast<const uint64_t*>(input_data),
reinterpret_cast<uint64_t*>(output_data),
num_loops, num_writers, writes_per_loop, writes_per_writer_per_loop);
break;
}
default: {
// we need to use memcpy for each block
for (int64_t l = 0; l < num_loops; ++l) {
uint8_t* output_for_first_writer = output_data;
for (auto wwpl = 0; wwpl < writes_per_writer_per_loop; ++wwpl) {
uint8_t* output_for_current_writer = output_for_first_writer;
for (int64_t w = 0; w < num_writers; ++w) {
memcpy(output_for_current_writer, input_data, bytes_per_write);
// skip to output position for next writer
output_for_current_writer += (writes_per_writer_per_loop * bytes_per_write);
input_data += bytes_per_write;
}
output_for_first_writer += bytes_per_write;
}
output_data += writes_per_loop * bytes_per_write;
}
}
}
}
template <typename T>
typename std::enable_if<!has_mlas_transpose<T>::value, void>::type
SimpleTransposeSingleAxisInwards(const T* input_data, T* output_data,
int64_t num_loops, int64_t num_readers,
int64_t reads_per_loop, int64_t reads_per_reader_per_loop) {
T* end;
for (int64_t l = 0; l < num_loops; ++l) {
const T* input_for_first_reader = input_data;
for (auto rrpl = 0; rrpl < reads_per_reader_per_loop; ++rrpl) {
const T* input_for_current_reader = input_for_first_reader;
end = output_data + num_readers;
for (; output_data != end;) {
*output_data++ = *input_for_current_reader;
// skip to input position for next reader
input_for_current_reader += reads_per_reader_per_loop;
}
++input_for_first_reader;
}
input_data += reads_per_loop;
}
}
template <typename T>
typename std::enable_if<has_mlas_transpose<T>::value, void>::type
SimpleTransposeSingleAxisInwards(const T* input_data, T* output_data,
int64_t num_loops, int64_t num_readers,
int64_t reads_per_loop, int64_t reads_per_reader_per_loop) {
for (int64_t l = 0; l < num_loops; ++l) {
MlasTranspose(input_data,
output_data,
static_cast<size_t>(num_readers),
static_cast<size_t>(reads_per_reader_per_loop));
input_data += reads_per_loop;
output_data += reads_per_loop;
}
}
// moving a single axis inwards where the read/write size is a power of 2 and between 8 and 64 bits.
// `input_shape_override` overrides the shape of `input` for compute purposes.
void TransposeSingleAxisInwards(const gsl::span<const size_t>& permutations, const Tensor& input, Tensor& output,
int64_t from, int64_t to, const TensorShape* input_shape_override = nullptr) {
ORT_UNUSED_PARAMETER(permutations);
const auto& input_shape = input_shape_override ? *input_shape_override : input.Shape();
const auto& input_dims = input_shape.GetDims();
const auto element_size = input.DataType()->Size();
const auto* input_data = reinterpret_cast<const uint8_t*>(input.DataRaw());
auto* output_data = reinterpret_cast<uint8_t*>(output.MutableDataRaw());
auto num_loops = input_shape.SizeToDimension(from);
auto num_readers = input_dims[from];
auto block_size = input_shape.SizeFromDimension(to + 1);
auto reads_per_loop = int64_t(input_shape.Size() / num_loops / block_size);
auto reads_per_reader_per_loop = int64_t(reads_per_loop / num_readers);
const int64_t bytes_per_read = block_size * element_size;
switch (bytes_per_read) {
case (sizeof(uint8_t)): {
SimpleTransposeSingleAxisInwards(input_data, output_data,
num_loops, num_readers, reads_per_loop, reads_per_reader_per_loop);
break;
}
case (sizeof(uint16_t)): {
SimpleTransposeSingleAxisInwards(reinterpret_cast<const uint16_t*>(input_data),
reinterpret_cast<uint16_t*>(output_data),
num_loops, num_readers, reads_per_loop, reads_per_reader_per_loop);
break;
}
case (sizeof(uint32_t)): {
SimpleTransposeSingleAxisInwards(reinterpret_cast<const uint32_t*>(input_data),
reinterpret_cast<uint32_t*>(output_data),
num_loops, num_readers, reads_per_loop, reads_per_reader_per_loop);
break;
}
case (sizeof(uint64_t)): {
SimpleTransposeSingleAxisInwards(reinterpret_cast<const uint64_t*>(input_data),
reinterpret_cast<uint64_t*>(output_data),
num_loops, num_readers, reads_per_loop, reads_per_reader_per_loop);
break;
}
default: {
// we need to use memcpy for each block
for (int64_t l = 0; l < num_loops; ++l) {
const uint8_t* input_for_first_reader = input_data;
for (auto rrpl = 0; rrpl < reads_per_reader_per_loop; ++rrpl) {
const uint8_t* input_for_current_reader = input_for_first_reader;
for (int64_t r = 0; r < num_readers; ++r) {
memcpy(output_data, input_for_current_reader, bytes_per_read);
output_data += bytes_per_read;
// skip to input position for next reader
input_for_current_reader += (reads_per_reader_per_loop * bytes_per_read);
}
input_for_first_reader += bytes_per_read;
}
input_data += reads_per_loop * bytes_per_read;
}
}
}
}
// `input_shape_override` overrides the shape of `input` for compute purposes.
void SingleAxisTranspose(const gsl::span<const size_t>& permutations, const Tensor& input, Tensor& output,
size_t from, size_t to, const TensorShape* input_shape_override = nullptr) {
if (from > to) {
TransposeSingleAxisOutwards(permutations, input, output, from, to, input_shape_override);
} else {
TransposeSingleAxisInwards(permutations, input, output, from, to, input_shape_override);
}
}
bool IsMovingSingleAxis(const gsl::span<const size_t>& permutations, size_t& from, size_t& to) {
// if a single axis moved to an outer dimension, the values should be one lower than the index until the slot the
// axis was moved from, and equal to the index after that.
// e.g. axis 3 moves out to 1 would be: 0, 3, 1, 2, 4
auto check_moved_outwards = [&permutations](size_t cur, size_t moved_from) {
// we start processing with the slot after the moved one, so the expected value is one less than the index
size_t expected = cur - 1;
for (size_t end = permutations.size(); cur < end; ++cur) {
if (permutations[cur] != expected) {
return false;
}
// we are at the slot the axis moved from, so do an additional increment before checking the next value
if (cur == moved_from) {
++expected;
}
++expected;
}
return true;
};
// if a single axis moved to an inner dimension, the values should be one higher than the index until the slot the
// axis was moved to, and equal to the index after that.
// e.g. axis 1 moves inwards to 3 would be: 0, 2, 3, 1, 4
auto check_moved_inwards = [&permutations](size_t cur, size_t& moved_to) {
size_t started_at = cur;
size_t expected = cur + 1;
moved_to = std::numeric_limits<size_t>::max();
for (size_t end = permutations.size(); cur < end; ++cur) {
if (permutations[cur] != expected) {
// if a single axis moved it must have come from the location we started at
if (started_at != permutations[cur]) {
return false;
}
moved_to = cur;
} else {
++expected;
}
}
return moved_to != std::numeric_limits<size_t>::max();
};
bool single_axis_moved = false;
// check axis moving outwards (earlier entry in permutations)
for (size_t i = 0, end = permutations.size(); i < end; ++i) {
size_t axis = permutations[i];
if (axis != i) {
if (check_moved_outwards(i + 1, axis)) {
single_axis_moved = true;
to = i;
from = axis;
} else if (check_moved_inwards(i, to)) {
single_axis_moved = true;
from = i;
}
break;
}
}
return single_axis_moved;
}
} // namespace
bool IsTransposeReshape(const gsl::span<const size_t>& perm, gsl::span<const int64_t> input_dims) {
// As long as the dims with values > 1 stay in the same order, it's a reshape.
@ -698,7 +360,7 @@ Status TransposeBase::DoTranspose(const gsl::span<const size_t>& permutations, c
}
size_t from = 0, to = 0;
bool moving_single_axis = IsMovingSingleAxis(permutations, from, to);
bool moving_single_axis = IsTransposeMovingSingleAxis(permutations, from, to);
if (moving_single_axis && !input.IsDataTypeString()) {
SingleAxisTranspose(permutations, input, output, from, to, input_shape_override);
@ -740,7 +402,7 @@ Status Transpose::Compute(OpKernelContext* ctx) const {
}
size_t from = 0, to = 0;
bool moving_single_axis = IsMovingSingleAxis(*p_perm, from, to);
bool moving_single_axis = IsTransposeMovingSingleAxis(*p_perm, from, to);
if (moving_single_axis && !X.IsDataTypeString()) {
SingleAxisTranspose(*p_perm, X, Y, from, to);

View file

@ -0,0 +1,38 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "gtest/gtest.h"
#include "core/framework/transpose_helper.h"
namespace onnxruntime {
namespace test {
TEST(TransposeIsMovingSingleAxis, test1) {
std::array<size_t, 4> perm{1, 2, 3, 0};
size_t from = 0, to = 0;
ASSERT_TRUE(IsTransposeMovingSingleAxis(perm, from, to));
ASSERT_EQ(from, static_cast<size_t>(0));
ASSERT_EQ(to, static_cast<size_t>(3));
}
TEST(TransposeIsMovingSingleAxis, test2) {
std::array<size_t, 4> perm{0, 2, 3, 1};
size_t from = 0, to = 0;
ASSERT_TRUE(IsTransposeMovingSingleAxis(perm, from, to));
ASSERT_EQ(from, static_cast<size_t>(1));
ASSERT_EQ(to, static_cast<size_t>(3));
}
TEST(TransposeIsMovingSingleAxis, test3) {
std::array<size_t, 4> perm{3, 1, 0, 2};
size_t from = 0, to = 0;
ASSERT_FALSE(IsTransposeMovingSingleAxis(perm, from, to));
}
TEST(TransposeIsMovingSingleAxis, test4) {
std::array<size_t, 4> perm{0, 1, 2, 3};
size_t from = 0, to = 0;
ASSERT_FALSE(IsTransposeMovingSingleAxis(perm, from, to));
}
} // namespace test
} // namespace onnxruntime