mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
Parallel Transpose_BSNH_to_BNSH (#19406)
Achieved a speedup of 1.098 in MultiHeadAttention and an end-to-end speedup of 1.021 in the OCR model through parallelization of the Transpose_BSNH_to_BNSH operation.
This commit is contained in:
parent
937cdd651e
commit
ec0e4d3b65
1 changed files with 5 additions and 3 deletions
|
|
@ -58,11 +58,12 @@ Status Reshape_BSD_to_BSNH(Tensor* qkv,
|
|||
|
||||
// Transpose Q/K/V from BxSxNxH to BxNxSxH
|
||||
Status Transpose_BSNH_to_BNSH(const Tensor* qkv,
|
||||
OrtValue& qkv_transposed) {
|
||||
OrtValue& qkv_transposed,
|
||||
concurrency::ThreadPool* tp = nullptr) {
|
||||
std::vector<size_t> permutations({0, 2, 1, 3});
|
||||
gsl::span<const size_t> permutations_span{permutations};
|
||||
size_t from = 2, to = 1;
|
||||
SingleAxisTranspose(permutations_span, *qkv, *qkv_transposed.GetMutable<Tensor>(), from, to);
|
||||
SingleAxisTranspose(permutations_span, *qkv, *qkv_transposed.GetMutable<Tensor>(), from, to, nullptr, tp);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
@ -143,7 +144,8 @@ Status AddBiasTranspose(const Tensor* qkv, // Input: Q/K/V dat
|
|||
ORT_RETURN_IF_ERROR(Reshape_BSD_to_BSNH(qkv_with_bias.GetMutable<Tensor>(), batch_size, sequence_length, num_heads, head_size));
|
||||
|
||||
// Transpose Q from BxSxNxH to BxNxSxH
|
||||
ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH(qkv_with_bias.GetMutable<Tensor>(), qkv_with_bias_transposed));
|
||||
auto tp = context->GetOperatorThreadPool();
|
||||
ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH(qkv_with_bias.GetMutable<Tensor>(), qkv_with_bias_transposed, tp));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue