Parallelize Transpose (#16854)

It gives up to a 5.6% improvement for prompt processing and a 2.3% improvement for token generation in the LLaMA 7B case.
This commit is contained in:
Yi-Hong Lyu 2023-08-07 14:25:53 -07:00 committed by GitHub
parent 3c10f027de
commit e48dc3b281
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 22 additions and 21 deletions

View file

@ -1,8 +1,11 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/framework/copy.h"
#include "core/framework/element_type_lists.h"
#include "core/framework/transpose_helper.h"
#include "core/mlas/inc/mlas.h"
#include "core/providers/cpu/tensor/utils.h"
namespace onnxruntime {
@ -56,7 +59,8 @@ typename std::enable_if<has_mlas_transpose<T>::value, void>::type SimpleTranspos
// `input_shape_override` overrides the shape of `input` for compute purposes.
void TransposeSingleAxisOutwards(gsl::span<const size_t> permutations, const Tensor& input, Tensor& output,
size_t from, size_t to, const TensorShape* input_shape_override = nullptr) {
size_t from, size_t to, const TensorShape* input_shape_override = nullptr,
concurrency::ThreadPool* tp = nullptr) {
ORT_UNUSED_PARAMETER(permutations);
const auto& input_shape = input_shape_override ? *input_shape_override : input.Shape();
@ -100,25 +104,20 @@ void TransposeSingleAxisOutwards(gsl::span<const size_t> permutations, const Ten
break;
}
default: {
// we need to use memcpy for each block
for (int64_t l = 0; l < num_loops; ++l) {
uint8_t* output_for_first_writer = output_data;
TensorPitches src_strides(input_dims);
for (auto wwpl = 0; wwpl < writes_per_writer_per_loop; ++wwpl) {
uint8_t* output_for_current_writer = output_for_first_writer;
TensorPitches contig_dst_strides(output);
for (int64_t w = 0; w < num_writers; ++w) {
memcpy(output_for_current_writer, input_data, bytes_per_write);
// skip to output position for next writer
output_for_current_writer += (writes_per_writer_per_loop * bytes_per_write);
input_data += bytes_per_write;
}
output_for_first_writer += bytes_per_write;
}
output_data += writes_per_loop * bytes_per_write;
const auto dims = input_dims.size();
TensorShapeVector dst_strides(dims);
for (size_t dim = 0; dim < dims; ++dim) {
dst_strides[permutations[dim]] = contig_dst_strides[dim];
}
ORT_THROW_IF_ERROR(DispatchStridedCopy<element_type_lists::All>(tp,
output, 0, dst_strides,
input_shape,
input, 0, src_strides));
}
}
}
@ -233,9 +232,9 @@ void TransposeSingleAxisInwards(gsl::span<const size_t> permutations, const Tens
// `input_shape_override` overrides the shape of `input` for compute purposes.
void SingleAxisTranspose(gsl::span<const size_t> permutations, const Tensor& input, Tensor& output, size_t from,
size_t to, const TensorShape* input_shape_override) {
size_t to, const TensorShape* input_shape_override, concurrency::ThreadPool* tp) {
if (from > to) {
TransposeSingleAxisOutwards(permutations, input, output, from, to, input_shape_override);
TransposeSingleAxisOutwards(permutations, input, output, from, to, input_shape_override, tp);
} else {
TransposeSingleAxisInwards(permutations, input, output, from, to, input_shape_override);
}

View file

@ -35,11 +35,13 @@ We fall back to the default implementation in all other cases, and if the input
#include "core/common/inlined_containers.h"
#include "core/framework/tensor_shape.h"
#include "core/framework/tensor.h"
#include "core/platform/threadpool.h"
#include "core/common/gsl.h"
namespace onnxruntime {
bool IsTransposeMovingSingleAxis(gsl::span<const size_t> permutations, size_t& from, size_t& to);
void SingleAxisTranspose(gsl::span<const size_t> permutations, const Tensor& input, Tensor& output, size_t from,
size_t to, const TensorShape* input_shape_override = nullptr);
size_t to, const TensorShape* input_shape_override = nullptr,
concurrency::ThreadPool* tp = nullptr);
} // namespace onnxruntime

View file

@ -402,7 +402,7 @@ Status Transpose::Compute(OpKernelContext* ctx) const {
bool moving_single_axis = IsTransposeMovingSingleAxis(*p_perm, from, to);
if (moving_single_axis && !X.IsDataTypeString()) {
SingleAxisTranspose(*p_perm, X, Y, from, to);
SingleAxisTranspose(*p_perm, X, Y, from, to, nullptr, ctx->GetOperatorThreadPool());
} else {
// fall back to default implementation
status = DoUntypedTranspose(*p_perm, X, Y);