From e48dc3b281c60c905a021f70b609c389b129bbff Mon Sep 17 00:00:00 2001 From: Yi-Hong Lyu Date: Mon, 7 Aug 2023 14:25:53 -0700 Subject: [PATCH] Parallelize Transpose (#16854) It gives up to 5.6% improvement for prompt and 2.3% improvement for token generation in LLaMA 7B case. --- .../core/framework/transpose_helper.cc | 37 +++++++++---------- onnxruntime/core/framework/transpose_helper.h | 4 +- .../core/providers/cpu/tensor/transpose.cc | 2 +- 3 files changed, 22 insertions(+), 21 deletions(-) diff --git a/onnxruntime/core/framework/transpose_helper.cc b/onnxruntime/core/framework/transpose_helper.cc index a5535d919b..38f68215a0 100644 --- a/onnxruntime/core/framework/transpose_helper.cc +++ b/onnxruntime/core/framework/transpose_helper.cc @@ -1,8 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "core/framework/copy.h" +#include "core/framework/element_type_lists.h" #include "core/framework/transpose_helper.h" #include "core/mlas/inc/mlas.h" +#include "core/providers/cpu/tensor/utils.h" namespace onnxruntime { @@ -56,7 +59,8 @@ typename std::enable_if::value, void>::type SimpleTranspos // `input_shape_override` overrides the shape of `input` for compute purposes. void TransposeSingleAxisOutwards(gsl::span permutations, const Tensor& input, Tensor& output, - size_t from, size_t to, const TensorShape* input_shape_override = nullptr) { + size_t from, size_t to, const TensorShape* input_shape_override = nullptr, + concurrency::ThreadPool* tp = nullptr) { ORT_UNUSED_PARAMETER(permutations); const auto& input_shape = input_shape_override ? *input_shape_override : input.Shape(); @@ -100,25 +104,20 @@ void TransposeSingleAxisOutwards(gsl::span permutations, const Ten break; } default: { - // we need to use memcpy for each block - for (int64_t l = 0; l < num_loops; ++l) { - uint8_t* output_for_first_writer = output_data; + TensorPitches src_strides(input_dims); - for (auto wwpl = 0; wwpl < writes_per_writer_per_loop; ++wwpl) { - uint8_t* output_for_current_writer = output_for_first_writer; + TensorPitches contig_dst_strides(output); - for (int64_t w = 0; w < num_writers; ++w) { - memcpy(output_for_current_writer, input_data, bytes_per_write); - // skip to output position for next writer - output_for_current_writer += (writes_per_writer_per_loop * bytes_per_write); - input_data += bytes_per_write; - } - - output_for_first_writer += bytes_per_write; - } - - output_data += writes_per_loop * bytes_per_write; + const auto dims = input_dims.size(); + TensorShapeVector dst_strides(dims); + for (size_t dim = 0; dim < dims; ++dim) { + dst_strides[permutations[dim]] = contig_dst_strides[dim]; } + + ORT_THROW_IF_ERROR(DispatchStridedCopy(tp, + output, 0, dst_strides, + input_shape, + input, 0, src_strides)); } } } @@ -233,9 +232,9 @@ void TransposeSingleAxisInwards(gsl::span permutations, const Tens // `input_shape_override` overrides the shape of `input` for compute purposes. void SingleAxisTranspose(gsl::span permutations, const Tensor& input, Tensor& output, size_t from, - size_t to, const TensorShape* input_shape_override) { + size_t to, const TensorShape* input_shape_override, concurrency::ThreadPool* tp) { if (from > to) { - TransposeSingleAxisOutwards(permutations, input, output, from, to, input_shape_override); + TransposeSingleAxisOutwards(permutations, input, output, from, to, input_shape_override, tp); } else { TransposeSingleAxisInwards(permutations, input, output, from, to, input_shape_override); } diff --git a/onnxruntime/core/framework/transpose_helper.h b/onnxruntime/core/framework/transpose_helper.h index bb7a04d097..c34d5ef3f2 100644 --- a/onnxruntime/core/framework/transpose_helper.h +++ b/onnxruntime/core/framework/transpose_helper.h @@ -35,11 +35,13 @@ We fall back to the default implementation in all other cases, and if the input #include "core/common/inlined_containers.h" #include "core/framework/tensor_shape.h" #include "core/framework/tensor.h" +#include "core/platform/threadpool.h" #include "core/common/gsl.h" namespace onnxruntime { bool IsTransposeMovingSingleAxis(gsl::span permutations, size_t& from, size_t& to); void SingleAxisTranspose(gsl::span permutations, const Tensor& input, Tensor& output, size_t from, - size_t to, const TensorShape* input_shape_override = nullptr); + size_t to, const TensorShape* input_shape_override = nullptr, + concurrency::ThreadPool* tp = nullptr); } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/tensor/transpose.cc b/onnxruntime/core/providers/cpu/tensor/transpose.cc index 6292df05cb..277dccac35 100644 --- a/onnxruntime/core/providers/cpu/tensor/transpose.cc +++ b/onnxruntime/core/providers/cpu/tensor/transpose.cc @@ -402,7 +402,7 @@ Status Transpose::Compute(OpKernelContext* ctx) const { bool moving_single_axis = IsTransposeMovingSingleAxis(*p_perm, from, to); if (moving_single_axis && !X.IsDataTypeString()) { - SingleAxisTranspose(*p_perm, X, Y, from, to); + SingleAxisTranspose(*p_perm, X, Y, from, to, nullptr, ctx->GetOperatorThreadPool()); } else { // fall back to default implementation status = DoUntypedTranspose(*p_perm, X, Y);