Parallel Transpose_BSNH_to_BNSH (#19406)

Achieved a 1.098x speedup in MultiHeadAttention and a 1.021x end-to-end
speedup in the OCR model by parallelizing the Transpose_BSNH_to_BNSH
operation.
This commit is contained in:
Yi-Hong Lyu 2024-02-29 10:31:57 -08:00 committed by GitHub
parent 937cdd651e
commit ec0e4d3b65
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -58,11 +58,12 @@ Status Reshape_BSD_to_BSNH(Tensor* qkv,
// Transpose Q/K/V from BxSxNxH to BxNxSxH.
// Writes `qkv` into `qkv_transposed` with axes 1 and 2 exchanged (permutation
// {0, 2, 1, 3}) by delegating to SingleAxisTranspose, then returns Status::OK().
// `tp` is an optional thread pool (defaults to nullptr) that is forwarded to
// SingleAxisTranspose.
// NOTE(review): this listing is a rendered diff, so the tail of the signature
// and the SingleAxisTranspose call each appear twice -- presumably the
// pre-change line followed by its post-change replacement; confirm against
// the repository before treating this text as compilable source.
Status Transpose_BSNH_to_BNSH(const Tensor* qkv,
OrtValue& qkv_transposed) {
OrtValue& qkv_transposed,
concurrency::ThreadPool* tp = nullptr) {
std::vector<size_t> permutations({0, 2, 1, 3});
gsl::span<const size_t> permutations_span{permutations};
// Source axis 2 is moved to position 1, matching the permutation above.
size_t from = 2, to = 1;
SingleAxisTranspose(permutations_span, *qkv, *qkv_transposed.GetMutable<Tensor>(), from, to);
SingleAxisTranspose(permutations_span, *qkv, *qkv_transposed.GetMutable<Tensor>(), from, to, nullptr, tp);
return Status::OK();
}
@ -143,7 +144,8 @@ Status AddBiasTranspose(const Tensor* qkv, // Input: Q/K/V dat
ORT_RETURN_IF_ERROR(Reshape_BSD_to_BSNH(qkv_with_bias.GetMutable<Tensor>(), batch_size, sequence_length, num_heads, head_size));
// Transpose Q from BxSxNxH to BxNxSxH
ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH(qkv_with_bias.GetMutable<Tensor>(), qkv_with_bias_transposed));
auto tp = context->GetOperatorThreadPool();
ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH(qkv_with_bias.GetMutable<Tensor>(), qkv_with_bias_transposed, tp));
return Status::OK();
}