mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Parallelize Transpose (#16854)
This gives up to a 5.6% improvement for prompt processing and a 2.3% improvement for token generation in the LLaMA 7B case.
This commit is contained in:
parent
3c10f027de
commit
e48dc3b281
3 changed files with 22 additions and 21 deletions
|
|
@ -1,8 +1,11 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/framework/copy.h"
|
||||
#include "core/framework/element_type_lists.h"
|
||||
#include "core/framework/transpose_helper.h"
|
||||
#include "core/mlas/inc/mlas.h"
|
||||
#include "core/providers/cpu/tensor/utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
|
|
@ -56,7 +59,8 @@ typename std::enable_if<has_mlas_transpose<T>::value, void>::type SimpleTranspos
|
|||
|
||||
// `input_shape_override` overrides the shape of `input` for compute purposes.
|
||||
void TransposeSingleAxisOutwards(gsl::span<const size_t> permutations, const Tensor& input, Tensor& output,
|
||||
size_t from, size_t to, const TensorShape* input_shape_override = nullptr) {
|
||||
size_t from, size_t to, const TensorShape* input_shape_override = nullptr,
|
||||
concurrency::ThreadPool* tp = nullptr) {
|
||||
ORT_UNUSED_PARAMETER(permutations);
|
||||
|
||||
const auto& input_shape = input_shape_override ? *input_shape_override : input.Shape();
|
||||
|
|
@ -100,25 +104,20 @@ void TransposeSingleAxisOutwards(gsl::span<const size_t> permutations, const Ten
|
|||
break;
|
||||
}
|
||||
default: {
|
||||
// we need to use memcpy for each block
|
||||
for (int64_t l = 0; l < num_loops; ++l) {
|
||||
uint8_t* output_for_first_writer = output_data;
|
||||
TensorPitches src_strides(input_dims);
|
||||
|
||||
for (auto wwpl = 0; wwpl < writes_per_writer_per_loop; ++wwpl) {
|
||||
uint8_t* output_for_current_writer = output_for_first_writer;
|
||||
TensorPitches contig_dst_strides(output);
|
||||
|
||||
for (int64_t w = 0; w < num_writers; ++w) {
|
||||
memcpy(output_for_current_writer, input_data, bytes_per_write);
|
||||
// skip to output position for next writer
|
||||
output_for_current_writer += (writes_per_writer_per_loop * bytes_per_write);
|
||||
input_data += bytes_per_write;
|
||||
}
|
||||
|
||||
output_for_first_writer += bytes_per_write;
|
||||
}
|
||||
|
||||
output_data += writes_per_loop * bytes_per_write;
|
||||
const auto dims = input_dims.size();
|
||||
TensorShapeVector dst_strides(dims);
|
||||
for (size_t dim = 0; dim < dims; ++dim) {
|
||||
dst_strides[permutations[dim]] = contig_dst_strides[dim];
|
||||
}
|
||||
|
||||
ORT_THROW_IF_ERROR(DispatchStridedCopy<element_type_lists::All>(tp,
|
||||
output, 0, dst_strides,
|
||||
input_shape,
|
||||
input, 0, src_strides));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -233,9 +232,9 @@ void TransposeSingleAxisInwards(gsl::span<const size_t> permutations, const Tens
|
|||
|
||||
// `input_shape_override` overrides the shape of `input` for compute purposes.
|
||||
void SingleAxisTranspose(gsl::span<const size_t> permutations, const Tensor& input, Tensor& output, size_t from,
|
||||
size_t to, const TensorShape* input_shape_override) {
|
||||
size_t to, const TensorShape* input_shape_override, concurrency::ThreadPool* tp) {
|
||||
if (from > to) {
|
||||
TransposeSingleAxisOutwards(permutations, input, output, from, to, input_shape_override);
|
||||
TransposeSingleAxisOutwards(permutations, input, output, from, to, input_shape_override, tp);
|
||||
} else {
|
||||
TransposeSingleAxisInwards(permutations, input, output, from, to, input_shape_override);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -35,11 +35,13 @@ We fall back to the default implementation in all other cases, and if the input
|
|||
#include "core/common/inlined_containers.h"
|
||||
#include "core/framework/tensor_shape.h"
|
||||
#include "core/framework/tensor.h"
|
||||
#include "core/platform/threadpool.h"
|
||||
|
||||
#include "core/common/gsl.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
bool IsTransposeMovingSingleAxis(gsl::span<const size_t> permutations, size_t& from, size_t& to);
|
||||
void SingleAxisTranspose(gsl::span<const size_t> permutations, const Tensor& input, Tensor& output, size_t from,
|
||||
size_t to, const TensorShape* input_shape_override = nullptr);
|
||||
size_t to, const TensorShape* input_shape_override = nullptr,
|
||||
concurrency::ThreadPool* tp = nullptr);
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -402,7 +402,7 @@ Status Transpose::Compute(OpKernelContext* ctx) const {
|
|||
bool moving_single_axis = IsTransposeMovingSingleAxis(*p_perm, from, to);
|
||||
|
||||
if (moving_single_axis && !X.IsDataTypeString()) {
|
||||
SingleAxisTranspose(*p_perm, X, Y, from, to);
|
||||
SingleAxisTranspose(*p_perm, X, Y, from, to, nullptr, ctx->GetOperatorThreadPool());
|
||||
} else {
|
||||
// fall back to default implementation
|
||||
status = DoUntypedTranspose(*p_perm, X, Y);
|
||||
|
|
|
|||
Loading…
Reference in a new issue