Remove memset when no mask is present (#19823)

Improves end-to-end OCR model speed by a factor of 1.034 by eliminating the
unnecessary memset when no mask is present.
This commit is contained in:
Yi-Hong Lyu 2024-03-07 13:54:16 -08:00 committed by GitHub
parent 3dfce2f1cd
commit 33578cc76e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -140,17 +140,6 @@ class AttentionCPUBase : public AttentionBase {
if (mask_data != nullptr) {
PrepareMask(mask_index, mask_index_dims, mask_data,
causal, batch_size, sequence_length, past_sequence_length, mask_filter_value_);
} else { // no any mask
const int memset_loop_len = batch_size * num_heads_;
const double memset_cost = static_cast<double>(sequence_length) * total_sequence_length;
ThreadPool::TryParallelFor(tp, memset_loop_len, memset_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
for (std::ptrdiff_t i = begin; i != end; ++i) {
const int output_offset = static_cast<int>(i) * sequence_length * total_sequence_length;
T* output = attention_probs + output_offset;
memset(output, 0, static_cast<size_t>(sequence_length) * total_sequence_length * sizeof(T));
}
});
}
const int loop_len = batch_size * num_heads_;
@ -188,7 +177,7 @@ class AttentionCPUBase : public AttentionBase {
// B: K' (B x N x) T x H (B x N x) H x T H x T
// C: attention_probs (B x N x) S x T (B x N x) S x T S x T
math::Gemm<T, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_sequence_length, head_size, alpha,
Q + q_input_chunk_length * i, k, 1.0,
Q + q_input_chunk_length * i, k, mask_data != nullptr ? 1.0f : 0.0f,
output, nullptr);
if (relative_position_bias_data != nullptr) {