mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-25 22:26:24 +00:00
Rework Transpose as a generic type agnostic implementation (#561)
Make Transpose op impl generic and add std::string support. Un-templatize implementation functions that make use of memcpy(). Support all types per spec. Add string tests.
This commit is contained in:
parent
4bd8463228
commit
4c2b1c3018
3 changed files with 228 additions and 41 deletions
|
|
@ -15,7 +15,7 @@ namespace onnxruntime {
|
|||
|
||||
// ComputeOffset: compute offset into a tensor. This is essentially the dot-product of
|
||||
// index and stride, restricted to the specified number of axes.
|
||||
size_t ComputeOffset(const std::vector<int64_t>& index, const std::vector<size_t>& stride, int64_t num_axes) {
|
||||
static inline size_t ComputeOffset(const std::vector<int64_t>& index, const std::vector<size_t>& stride, int64_t num_axes) {
|
||||
size_t offset = 0;
|
||||
for (int64_t j = 0; j < num_axes; ++j) {
|
||||
offset += index[j] * stride[j];
|
||||
|
|
@ -25,7 +25,7 @@ size_t ComputeOffset(const std::vector<int64_t>& index, const std::vector<size_t
|
|||
|
||||
// IncrementIndex: Increment an index into a tensor (in lexicographic ordering), wrapping
|
||||
// around the specified upper_bound.
|
||||
void IncrementIndex(std::vector<int64_t>& index, const std::vector<int64_t>& upper_bound, int64_t num_axes) {
|
||||
static inline void IncrementIndex(std::vector<int64_t>& index, const std::vector<int64_t>& upper_bound, int64_t num_axes) {
|
||||
for (int64_t k = num_axes - 1; k >= 0; --k) {
|
||||
index[k]++;
|
||||
if (index[k] < upper_bound[k]) break;
|
||||
|
|
@ -33,13 +33,26 @@ void IncrementIndex(std::vector<int64_t>& index, const std::vector<int64_t>& upp
|
|||
}
|
||||
}
|
||||
|
||||
// DoTransposeSingleBlock: specialization of DoTranspose for the num_blocks=1 case.
|
||||
// copies source tensor to target, transposing elements.
|
||||
static inline void DoTransposeSingleBlock(size_t num_elts_in_block, const void* source, void* target,
|
||||
size_t element_size) {
|
||||
size_t blocksize = num_elts_in_block * element_size;
|
||||
// copy
|
||||
memcpy(target, source, blocksize);
|
||||
}
|
||||
|
||||
static inline void DoTransposeSingleBlock(size_t num_elts_in_block, const std::string* source, std::string* target) {
|
||||
const std::string* end = source + num_elts_in_block;
|
||||
std::copy(source, end, target);
|
||||
}
|
||||
|
||||
// DoTranspose: copies source tensor to target, transposing elements.
|
||||
// The stride vector indicates the transposition.
|
||||
template <typename T>
|
||||
static void DoTransposeImpl(int64_t num_axes, const std::vector<int64_t>& target_dims,
|
||||
size_t num_blocks, size_t num_elts_in_block, const std::vector<size_t>& stride,
|
||||
const T* source, T* target) {
|
||||
size_t blocksize = num_elts_in_block * sizeof(float);
|
||||
const uint8_t* source, uint8_t* target, size_t element_size) {
|
||||
size_t blocksize = num_elts_in_block * element_size;
|
||||
// index used to iterate over target iteration-space
|
||||
std::vector<int64_t> target_index(num_axes, 0);
|
||||
for (size_t i = 0; i < num_blocks; ++i) {
|
||||
|
|
@ -47,7 +60,25 @@ static void DoTransposeImpl(int64_t num_axes, const std::vector<int64_t>& target
|
|||
size_t source_offset = ComputeOffset(target_index, stride, num_axes);
|
||||
|
||||
// copy
|
||||
memcpy(target, source + source_offset, blocksize);
|
||||
memcpy(target, source + source_offset * element_size, blocksize);
|
||||
|
||||
// increment target_index:
|
||||
IncrementIndex(target_index, target_dims, num_axes);
|
||||
target += blocksize;
|
||||
}
|
||||
}
|
||||
|
||||
static void DoTransposeImpl(int64_t num_axes, const std::vector<int64_t>& target_dims,
|
||||
size_t num_blocks, size_t num_elts_in_block, const std::vector<size_t>& stride,
|
||||
const std::string* source, std::string* target) {
|
||||
// index used to iterate over target iteration-space
|
||||
std::vector<int64_t> target_index(num_axes, 0);
|
||||
for (size_t i = 0; i < num_blocks; ++i) {
|
||||
// convert target_index into an offset in source data
|
||||
size_t source_offset = ComputeOffset(target_index, stride, num_axes);
|
||||
|
||||
// copy
|
||||
DoTransposeSingleBlock(num_elts_in_block, source + source_offset, target);
|
||||
|
||||
// increment target_index:
|
||||
IncrementIndex(target_index, target_dims, num_axes);
|
||||
|
|
@ -55,12 +86,80 @@ static void DoTransposeImpl(int64_t num_axes, const std::vector<int64_t>& target
|
|||
}
|
||||
}
|
||||
|
||||
template <class T>
|
||||
inline void CopyPrim(uint8_t* target, const uint8_t* source) {
|
||||
*reinterpret_cast<T*>(target) = *reinterpret_cast<const T*>(source);
|
||||
}
|
||||
|
||||
// DoTransposeEltWise: specialization of DoTranspose for the num_elts_in_block=1 case.
|
||||
// copies source tensor to target, transposing elements.
|
||||
// The stride vector indicates the transposition.
|
||||
template <typename T>
|
||||
static void DoTransposeEltWise(int64_t num_axes, const std::vector<int64_t>& target_dims, size_t num_blocks,
|
||||
const std::vector<size_t>& stride, const T* source, T* target) {
|
||||
const std::vector<size_t>& stride, const uint8_t* source, uint8_t* target,
|
||||
size_t element_size) {
|
||||
// index used to iterate over target iteration-space
|
||||
std::vector<int64_t> target_index(num_axes, 0);
|
||||
|
||||
switch (element_size) {
|
||||
case sizeof(uint64_t):
|
||||
for (size_t i = 0; i < num_blocks; ++i) {
|
||||
// convert target_index into an offset in source data
|
||||
size_t source_offset = ComputeOffset(target_index, stride, num_axes);
|
||||
|
||||
// copy
|
||||
CopyPrim<uint64_t>(target, source + (source_offset * element_size));
|
||||
|
||||
// increment target_index:
|
||||
IncrementIndex(target_index, target_dims, num_axes);
|
||||
target += element_size;
|
||||
}
|
||||
break;
|
||||
case sizeof(uint32_t):
|
||||
for (size_t i = 0; i < num_blocks; ++i) {
|
||||
// convert target_index into an offset in source data
|
||||
size_t source_offset = ComputeOffset(target_index, stride, num_axes);
|
||||
|
||||
// copy
|
||||
CopyPrim<uint32_t>(target, source + (source_offset * element_size));
|
||||
|
||||
// increment target_index:
|
||||
IncrementIndex(target_index, target_dims, num_axes);
|
||||
target += element_size;
|
||||
}
|
||||
break;
|
||||
case sizeof(uint16_t):
|
||||
for (size_t i = 0; i < num_blocks; ++i) {
|
||||
// convert target_index into an offset in source data
|
||||
size_t source_offset = ComputeOffset(target_index, stride, num_axes);
|
||||
|
||||
// copy
|
||||
CopyPrim<uint16_t>(target, source + (source_offset * element_size));
|
||||
|
||||
// increment target_index:
|
||||
IncrementIndex(target_index, target_dims, num_axes);
|
||||
target += element_size;
|
||||
}
|
||||
break;
|
||||
case sizeof(uint8_t):
|
||||
for (size_t i = 0; i < num_blocks; ++i) {
|
||||
// convert target_index into an offset in source data
|
||||
size_t source_offset = ComputeOffset(target_index, stride, num_axes);
|
||||
|
||||
// copy
|
||||
*target = *(source + (source_offset * element_size));
|
||||
|
||||
// increment target_index:
|
||||
IncrementIndex(target_index, target_dims, num_axes);
|
||||
target += element_size;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
assert(false);
|
||||
}
|
||||
}
|
||||
|
||||
static void DoTransposeEltWise(int64_t num_axes, const std::vector<int64_t>& target_dims, size_t num_blocks,
|
||||
const std::vector<size_t>& stride, const std::string* source, std::string* target) {
|
||||
// index used to iterate over target iteration-space
|
||||
std::vector<int64_t> target_index(num_axes, 0);
|
||||
for (size_t i = 0; i < num_blocks; ++i) {
|
||||
|
|
@ -76,21 +175,14 @@ static void DoTransposeEltWise(int64_t num_axes, const std::vector<int64_t>& tar
|
|||
}
|
||||
}
|
||||
|
||||
// DoTransposeSingleBlock: specialization of DoTranspose for the num_blocks=1 case.
|
||||
// copies source tensor to target, transposing elements.
|
||||
template <typename T>
|
||||
static void DoTransposeSingleBlock(size_t num_elts_in_block, const T* source, T* target) {
|
||||
size_t blocksize = num_elts_in_block * sizeof(T);
|
||||
// copy
|
||||
memcpy(target, source, blocksize);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static Status DoTypedTranspose(const std::vector<int64_t>& permutations, const Tensor& input, Tensor& output) {
|
||||
static Status DoUntypedTranspose(const std::vector<int64_t>& permutations, const Tensor& input, Tensor& output) {
|
||||
const auto& input_shape = input.Shape();
|
||||
const auto& input_dims = input_shape.GetDims();
|
||||
auto rank = input_shape.NumDimensions();
|
||||
|
||||
const auto element_size = input.DataType()->Size();
|
||||
const bool is_string_type = input.DataType() == DataTypeImpl::GetType<std::string>();
|
||||
|
||||
std::vector<size_t> stride(rank);
|
||||
for (int i = 0; i < rank; i++) {
|
||||
size_t inpdim = permutations[i];
|
||||
|
|
@ -118,17 +210,31 @@ static Status DoTypedTranspose(const std::vector<int64_t>& permutations, const T
|
|||
}
|
||||
}
|
||||
|
||||
const T* input_data = input.Data<T>();
|
||||
T* output_data = output.MutableData<T>();
|
||||
|
||||
if (1 == prefix_blocksize)
|
||||
DoTransposeSingleBlock<T>(suffix_blocksize, input_data, output_data);
|
||||
else if (1 == suffix_blocksize)
|
||||
DoTransposeEltWise<T>(num_axes_in_prefix, output.Shape().GetDims(), prefix_blocksize, stride,
|
||||
input_data, output_data);
|
||||
else
|
||||
DoTransposeImpl<T>(num_axes_in_prefix, output.Shape().GetDims(), prefix_blocksize, suffix_blocksize, stride,
|
||||
input_data, output_data);
|
||||
if (is_string_type) {
|
||||
const std::string* input_data = input.template Data<std::string>();
|
||||
std::string* output_data = output.template MutableData<std::string>();
|
||||
if (1 == prefix_blocksize) {
|
||||
DoTransposeSingleBlock(suffix_blocksize, input_data, output_data);
|
||||
} else if (1 == suffix_blocksize) {
|
||||
DoTransposeEltWise(num_axes_in_prefix, output.Shape().GetDims(), prefix_blocksize, stride,
|
||||
input_data, output_data);
|
||||
} else {
|
||||
DoTransposeImpl(num_axes_in_prefix, output.Shape().GetDims(), prefix_blocksize, suffix_blocksize, stride,
|
||||
input_data, output_data);
|
||||
}
|
||||
} else {
|
||||
const uint8_t* input_data = reinterpret_cast<const uint8_t*>(input.DataRaw());
|
||||
uint8_t* output_data = reinterpret_cast<uint8_t*>(output.MutableDataRaw());
|
||||
if (1 == prefix_blocksize) {
|
||||
DoTransposeSingleBlock(suffix_blocksize, input_data, output_data, element_size);
|
||||
} else if (1 == suffix_blocksize) {
|
||||
DoTransposeEltWise(num_axes_in_prefix, output.Shape().GetDims(), prefix_blocksize, stride,
|
||||
input_data, output_data, element_size);
|
||||
} else {
|
||||
DoTransposeImpl(num_axes_in_prefix, output.Shape().GetDims(), prefix_blocksize, suffix_blocksize, stride,
|
||||
input_data, output_data, element_size);
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
@ -143,14 +249,13 @@ Status TransposeBase::DoTranspose(const std::vector<int64_t>& permutations, cons
|
|||
status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Mismatched data types between input and output Tensors. ",
|
||||
input_type, " != ", output_type);
|
||||
} else {
|
||||
DispatchOnTensorTypeWithReturn(input_type, status, DoTypedTranspose, permutations, input, output);
|
||||
status = DoUntypedTranspose(permutations, input, output);
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
template <>
|
||||
Status Transpose<float>::Compute(OpKernelContext* ctx) const {
|
||||
Status Transpose::Compute(OpKernelContext* ctx) const {
|
||||
// Get input and output:
|
||||
const Tensor* input_tensor_ptr = ctx->Input<Tensor>(0);
|
||||
ORT_ENFORCE(input_tensor_ptr != nullptr);
|
||||
|
|
@ -167,7 +272,7 @@ Status Transpose<float>::Compute(OpKernelContext* ctx) const {
|
|||
TensorShape output_shape{output_dims};
|
||||
Tensor& Y = *ctx->Output(0, output_shape);
|
||||
|
||||
DoTypedTranspose<float>(*p_perm, X, Y);
|
||||
DoUntypedTranspose(*p_perm, X, Y);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
@ -175,7 +280,7 @@ Status Transpose<float>::Compute(OpKernelContext* ctx) const {
|
|||
ONNX_CPU_OPERATOR_KERNEL(
|
||||
Transpose,
|
||||
1,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Transpose<float>);
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllTensorTypes()),
|
||||
Transpose);
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -65,7 +65,6 @@ class TransposeBase {
|
|||
std::vector<int64_t> perm_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class Transpose final : public OpKernel, public TransposeBase {
|
||||
public:
|
||||
Transpose(const OpKernelInfo& info) : OpKernel(info), TransposeBase(info) {}
|
||||
|
|
|
|||
|
|
@ -7,16 +7,17 @@
|
|||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
template <class T>
|
||||
void TransposeTest(std::vector<int64_t>& input_shape,
|
||||
std::vector<float>& input_vals,
|
||||
std::vector<T>& input_vals,
|
||||
std::vector<int64_t>* p_perm,
|
||||
std::vector<int64_t> expected_shape,
|
||||
std::initializer_list<float>& expected_vals) {
|
||||
std::initializer_list<T>& expected_vals) {
|
||||
OpTester test("Transpose");
|
||||
if (nullptr != p_perm)
|
||||
test.AddAttribute("perm", *p_perm);
|
||||
test.AddInput<float>("X", input_shape, input_vals);
|
||||
test.AddOutput<float>("Y", expected_shape, expected_vals);
|
||||
test.AddInput<T>("X", input_shape, input_vals);
|
||||
test.AddOutput<T>("Y", expected_shape, expected_vals);
|
||||
test.Run();
|
||||
}
|
||||
|
||||
|
|
@ -36,6 +37,21 @@ TEST(TransposeOpTest, TwoDimNoAttr) {
|
|||
TransposeTest(input_shape, input_vals, nullptr, expected_shape, expected_vals);
|
||||
}
|
||||
|
||||
TEST(TransposeOpTest, TwoDimNoAttrStr) {
|
||||
std::vector<int64_t> input_shape({2, 3});
|
||||
std::vector<std::string> input_vals = {
|
||||
"1", "2", "3",
|
||||
"4", "5", "6"};
|
||||
|
||||
std::vector<int64_t> expected_shape({3, 2});
|
||||
std::initializer_list<std::string> expected_vals = {
|
||||
"1", "4",
|
||||
"2", "5",
|
||||
"3", "6"};
|
||||
|
||||
TransposeTest(input_shape, input_vals, nullptr, expected_shape, expected_vals);
|
||||
}
|
||||
|
||||
// Test 2 dimensional transpose, with permutation attribute specified
|
||||
TEST(TransposeOpTest, TwoDim) {
|
||||
std::vector<int64_t> input_shape({2, 3});
|
||||
|
|
@ -53,6 +69,22 @@ TEST(TransposeOpTest, TwoDim) {
|
|||
TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals);
|
||||
}
|
||||
|
||||
TEST(TransposeOpTest, TwoDimStr) {
|
||||
std::vector<int64_t> input_shape({2, 3});
|
||||
std::vector<std::string> input_vals = {
|
||||
"1", "2", "3",
|
||||
"4", "5", "6"};
|
||||
|
||||
std::vector<int64_t> perm = {1, 0};
|
||||
std::vector<int64_t> expected_shape({3, 2});
|
||||
std::initializer_list<std::string> expected_vals = {
|
||||
"1", "4",
|
||||
"2", "5",
|
||||
"3", "6"};
|
||||
|
||||
TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals);
|
||||
}
|
||||
|
||||
// Test 3 dimensional transpose, with permutation attribute specified
|
||||
TEST(TransposeOpTest, ThreeDim) {
|
||||
std::vector<int64_t> input_shape({4, 2, 3});
|
||||
|
|
@ -105,5 +137,56 @@ TEST(TransposeOpTest, ThreeDim) {
|
|||
TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals);
|
||||
}
|
||||
|
||||
TEST(TransposeOpTest, ThreeDimStr) {
|
||||
std::vector<int64_t> input_shape({4, 2, 3});
|
||||
std::vector<std::string> input_vals = {
|
||||
"1", "2", "3",
|
||||
"4", "5", "6",
|
||||
|
||||
"1", "2", "3",
|
||||
"4", "5", "6",
|
||||
|
||||
"1", "2", "3",
|
||||
"4", "5", "6",
|
||||
|
||||
"1", "2", "3",
|
||||
"4", "5", "6"};
|
||||
|
||||
std::vector<int64_t> perm = {0, 2, 1};
|
||||
std::vector<int64_t> expected_shape({4, 3, 2});
|
||||
std::initializer_list<std::string> expected_vals = {
|
||||
"1",
|
||||
"4",
|
||||
"2",
|
||||
"5",
|
||||
"3",
|
||||
"6",
|
||||
|
||||
"1",
|
||||
"4",
|
||||
"2",
|
||||
"5",
|
||||
"3",
|
||||
"6",
|
||||
|
||||
"1",
|
||||
"4",
|
||||
"2",
|
||||
"5",
|
||||
"3",
|
||||
"6",
|
||||
|
||||
"1",
|
||||
"4",
|
||||
"2",
|
||||
"5",
|
||||
"3",
|
||||
"6"
|
||||
|
||||
};
|
||||
|
||||
TransposeTest(input_shape, input_vals, &perm, expected_shape, expected_vals);
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue