From 4523b3d06c19018e72354f7c7837658bbae01966 Mon Sep 17 00:00:00 2001
From: Liqun Fu
Date: Thu, 16 Jan 2025 18:54:44 -0800
Subject: [PATCH] profile init code

Signed-off-by: Liqun Fu
---
 .../contrib_ops/cpu/bert/gqa_attention_base.h | 39 ++++++++++++----
 .../cpu/bert/group_query_attention.cc         | 46 ++++++++++++++++++-
 .../test/python/transformers/test_gqa_cpu.py  |  6 +++
 3 files changed, 80 insertions(+), 11 deletions(-)

diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
index ccaeb6654e..a73a1a4d38 100644
--- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
+++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
@@ -87,18 +87,36 @@ class GQAAttentionBase {
     bool past_present_share_buffer = past_key_data == present_key_data && past_value_data == present_value_data;
 
     const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K;
-    ComputeAttentionProbs(static_cast<T*>(attention_probs), Q, k, seqlens_k->Data<int32_t>(), batch_size,
-                          sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data,
-                          present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator);
-
+    {
+      std::chrono::high_resolution_clock::time_point time_point;
+      if (profiler_->IsEnabled()) {
+        time_point = profiler_->Start();
+      }
+      ComputeAttentionProbs(static_cast<T*>(attention_probs), Q, k, seqlens_k->Data<int32_t>(), batch_size,
+                            sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data,
+                            present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator);
+      if (profiler_->IsEnabled()) {
+        std::string eventName = context->GetNodeName() + "_" + "ComputeAttentionProbs";
+        profiler_->EndTimeAndRecordEvent(onnxruntime::profiling::KERNEL_EVENT, eventName, time_point);
+      }
+    }
     // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v)
     const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V;
-    ComputeVxAttentionScore(output->MutableData<T>(), static_cast<T*>(attention_probs), v,
-                            seqlens_k->Data<int32_t>(),
-                            batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size,
-                            hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv,
-                            is_prompt, tp, allocator);
-
+    {
+      std::chrono::high_resolution_clock::time_point time_point;
+      if (profiler_->IsEnabled()) {
+        time_point = profiler_->Start();
+      }
+      ComputeVxAttentionScore(output->MutableData<T>(), static_cast<T*>(attention_probs), v,
+                              seqlens_k->Data<int32_t>(),
+                              batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size,
+                              hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv,
+                              is_prompt, tp, allocator);
+      if (profiler_->IsEnabled()) {
+        std::string eventName = context->GetNodeName() + "_" + "ComputeVxAttentionScore";
+        profiler_->EndTimeAndRecordEvent(onnxruntime::profiling::KERNEL_EVENT, eventName, time_point);
+      }
+    }
     return Status::OK();
   }
 
@@ -123,6 +141,7 @@ class GQAAttentionBase {
                              const bool is_prompt,            // whether it is prompt
                              ThreadPool* tp,                  // thread pool
                              AllocatorPtr allocator) const {  // allocator for temporary buffer
+    const ptrdiff_t packed_batch_stride = packed_qkv ? SafeInt<ptrdiff_t>(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size : SafeInt<ptrdiff_t>(0);
diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
index 8f662cd388..003df4680c 100644
--- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
@@ -16,6 +16,32 @@
 #include <memory>
 #include <vector>
 
+// https://github.com/microsoft/onnxruntime/blob/b9493adbe88c4681fcae71774ec3685d1390bd46/onnxruntime/core/mlas/lib/sqnbitgemm.cpp
+#include <chrono>
+#include "core/common/profiler.h"
+class ProfilerWrapper {
+ public:
+  ProfilerWrapper() {
+    profiler_ = std::make_unique<onnxruntime::profiling::Profiler>();
+    profiler_->StartProfiling("profile.json");
+  }
+
+  ~ProfilerWrapper() {
+    if (profiler_) {
+      profiler_->EndProfiling();
+    }
+  }
+
+  onnxruntime::profiling::Profiler* operator->() {
+    return profiler_.get();
+  }
+
+ private:
+  std::unique_ptr<onnxruntime::profiling::Profiler> profiler_;
+};
+
+static ProfilerWrapper profiler_;
+
 using onnxruntime::concurrency::ThreadPool;
 
 namespace onnxruntime {
@@ -112,6 +138,11 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
   T* q_rotary = Q.GetMutable<Tensor>()->MutableData<T>();
   T* k_rotary = packed_qkv ? nullptr : K.GetMutable<Tensor>()->MutableData<T>();
   if (do_rotary_) {
+    std::chrono::high_resolution_clock::time_point time_point;
+    if (profiler_->IsEnabled()) {
+      time_point = profiler_->Start();
+    }
+
     // Initialize rotary parameters
     rotary_embedding_helper::RotaryParameters rotary_params = {};
     rotary_params.batch_size = batch_size;
@@ -189,13 +220,26 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
                                                  v_input, v_rotary));
     }
+    if (profiler_->IsEnabled()) {
+      std::string eventName = this->Node().Name() + "_" + "rotary";
+      profiler_->EndTimeAndRecordEvent(onnxruntime::profiling::KERNEL_EVENT, eventName, time_point);
+    }
   }
 
   ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
+  std::chrono::high_resolution_clock::time_point time_point;
+  if (profiler_->IsEnabled()) {
+    time_point = profiler_->Start();
+  }
   // Compute the attention score and apply the score to V
-  return ApplyAttention(q_rotary, packed_qkv ? nullptr : k_rotary, packed_qkv ? nullptr : V.Get<Tensor>().Data<T>(),
+  auto ret = ApplyAttention(q_rotary, packed_qkv ? nullptr : k_rotary, packed_qkv ? nullptr : V.Get<Tensor>().Data<T>(),
                         past_key, past_value, output, present_k, present_v, seqlens_k, parameters, allocator, context);
+  if (profiler_->IsEnabled()) {
+    std::string eventName = this->Node().Name() + "_" + "ApplyAttention";
+    profiler_->EndTimeAndRecordEvent(onnxruntime::profiling::KERNEL_EVENT, eventName, time_point);
+  }
+  return ret;
 }
 
 }  // namespace contrib
 }  // namespace onnxruntime
diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py
index 77b4b326bf..b9a81f0c5d 100644
--- a/onnxruntime/test/python/transformers/test_gqa_cpu.py
+++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py
@@ -714,6 +714,7 @@ def gqa_prompt_func(
         "total_sequence_length": torch.tensor([config.q_sequence_length], dtype=torch.int32).detach().cpu().numpy(),
     }
     sess_options = SessionOptions()
+    sess_options.enable_profiling = True
     ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"])
     io_binding = ort_session.io_binding()
     if new_k is not None:
@@ -747,6 +748,11 @@ def gqa_prompt_func(
         ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu()
         ort_output = numpy.array(ort_output)
         output = torch.tensor(ort_output)
+
+        if sess_options.enable_profiling:
+            profile_file = ort_session.end_profiling()
+            print(f"Profiling data saved to: {profile_file}")
+
         return output, present_k, present_v
     else:
         ort_inputs = {
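
Note: the patch writes per-kernel events (rotary, ComputeAttentionProbs, ComputeVxAttentionScore, ApplyAttention) into the profile file, and the test prints that file's path. As a minimal sketch of how the output can be inspected, assuming the usual ONNX Runtime profile layout (a JSON array of Chrome-trace-style event records with "name" and "dur"-in-microseconds fields), the summarize_profile helper below is illustrative and not part of this patch:

    import json
    from collections import defaultdict

    # Illustrative helper (not part of the patch): sum the recorded duration
    # per event name from an ONNX Runtime profile file such as the
    # "profile.json" written by the ProfilerWrapper above. Assumes the usual
    # ORT/Chrome-trace layout: a JSON array of event dicts whose "dur" field
    # is in microseconds.
    def summarize_profile(path):
        with open(path) as f:
            events = json.load(f)
        totals = defaultdict(int)
        for event in events:
            if "name" in event and "dur" in event:
                totals[event["name"]] += event["dur"]
        # Print events sorted by total time, largest first.
        for name, dur in sorted(totals.items(), key=lambda kv: -kv[1]):
            print(f"{dur:>12} us  {name}")

    summarize_profile("profile.json")

Sorting by total duration makes it easy to compare how much of ApplyAttention is spent in ComputeAttentionProbs versus ComputeVxAttentionScore across runs.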