Add profiling instrumentation to the GroupQueryAttention CPU kernel (rotary embedding, ComputeAttentionProbs, ComputeVxAttentionScore) and enable session profiling in the GQA prompt test

Signed-off-by: Liqun Fu <liqun.fu@microsoft.com>
This commit is contained in:
Liqun Fu 2025-01-16 18:54:44 -08:00
parent 5735e1bce0
commit 4523b3d06c
3 changed files with 80 additions and 11 deletions

View file

@ -87,18 +87,36 @@ class GQAAttentionBase {
bool past_present_share_buffer = past_key_data == present_key_data && past_value_data == present_value_data;
const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K;
ComputeAttentionProbs<T>(static_cast<float*>(attention_probs), Q, k, seqlens_k->Data<int32_t>(), batch_size,
sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data,
present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator);
{
std::chrono::high_resolution_clock::time_point time_point;
if (profiler_->IsEnabled()) {
time_point = profiler_->Start();
}
ComputeAttentionProbs<T>(static_cast<float*>(attention_probs), Q, k, seqlens_k->Data<int32_t>(), batch_size,
sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data,
present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator);
if (profiler_->IsEnabled()) {
std::string eventName = context->GetNodeName() + "_" + "ComputeAttentionProbs";
profiler_->EndTimeAndRecordEvent(onnxruntime::profiling::KERNEL_EVENT, eventName, time_point);
}
}
// Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v)
const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V;
ComputeVxAttentionScore(output->MutableData<T>(), static_cast<float*>(attention_probs), v,
seqlens_k->Data<int32_t>(),
batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size,
hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv,
is_prompt, tp, allocator);
{
std::chrono::high_resolution_clock::time_point time_point;
if (profiler_->IsEnabled()) {
time_point = profiler_->Start();
}
ComputeVxAttentionScore(output->MutableData<T>(), static_cast<float*>(attention_probs), v,
seqlens_k->Data<int32_t>(),
batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size,
hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv,
is_prompt, tp, allocator);
if (profiler_->IsEnabled()) {
std::string eventName = context->GetNodeName() + "_" + "ComputeVxAttentionScore";
profiler_->EndTimeAndRecordEvent(onnxruntime::profiling::KERNEL_EVENT, eventName, time_point);
}
}
return Status::OK();
}
@ -123,6 +141,7 @@ class GQAAttentionBase {
const bool is_prompt, // whether it is prompt
ThreadPool* tp, // thread pool
AllocatorPtr allocator) const { // allocator for temporary buffer
const ptrdiff_t packed_batch_stride =
packed_qkv ? SafeInt<ptrdiff_t>(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size
: SafeInt<ptrdiff_t>(0);

View file

@ -16,6 +16,32 @@
#include <unsupported/Eigen/SpecialFunctions>
#include <vector>
// https://github.com/microsoft/onnxruntime/blob/b9493adbe88c4681fcae71774ec3685d1390bd46/onnxruntime/core/mlas/lib/sqnbitgemm.cpp
#include <chrono>
#include "core/common/profiler.h"
class ProfilerWrapper {
public:
ProfilerWrapper() {
profiler_ = std::make_unique<onnxruntime::profiling::Profiler>();
profiler_->StartProfiling<char>("profile.json");
}
~ProfilerWrapper() {
if (profiler_) {
profiler_->EndProfiling();
}
}
onnxruntime::profiling::Profiler* operator->() {
return profiler_.get();
}
private:
std::unique_ptr<onnxruntime::profiling::Profiler> profiler_;
};
// File-local profiler instance: profiling starts during static
// initialization and the profile file is flushed at process exit.
// NOTE(review): if this declaration lives in a header, every translation
// unit that includes it gets its own instance, all writing to the same
// "profile.json" — confirm it is confined to a single .cc file.
static ProfilerWrapper profiler_;
using onnxruntime::concurrency::ThreadPool;
namespace onnxruntime {
@ -112,6 +138,11 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
T* q_rotary = Q.GetMutable<Tensor>()->MutableData<T>();
T* k_rotary = packed_qkv ? nullptr : K.GetMutable<Tensor>()->MutableData<T>();
if (do_rotary_) {
std::chrono::high_resolution_clock::time_point time_point;
if (profiler_->IsEnabled()) {
time_point = profiler_->Start();
}
// Initialize rotary parameters
rotary_embedding_helper::RotaryParameters rotary_params = {};
rotary_params.batch_size = batch_size;
@ -189,13 +220,26 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
v_input,
v_rotary));
}
if (profiler_->IsEnabled()) {
std::string eventName = this->Node().Name() + "_" + "rotary";
profiler_->EndTimeAndRecordEvent(onnxruntime::profiling::KERNEL_EVENT, eventName, time_point);
}
}
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
std::chrono::high_resolution_clock::time_point time_point;
if (profiler_->IsEnabled()) {
time_point = profiler_->Start();
}
// Compute the attention score and apply the score to V
return ApplyAttention(q_rotary, packed_qkv ? nullptr : k_rotary, packed_qkv ? nullptr : V.Get<Tensor>().Data<T>(),
auto ret = ApplyAttention(q_rotary, packed_qkv ? nullptr : k_rotary, packed_qkv ? nullptr : V.Get<Tensor>().Data<T>(),
past_key, past_value, output, present_k, present_v,
seqlens_k, parameters, allocator, context);
if (profiler_->IsEnabled()) {
std::string eventName = this->Node().Name() + "_" + "ApplyAttention";
profiler_->EndTimeAndRecordEvent(onnxruntime::profiling::KERNEL_EVENT, eventName, time_point);
}
return ret;
}
} // namespace contrib
} // namespace onnxruntime

View file

@ -714,6 +714,7 @@ def gqa_prompt_func(
"total_sequence_length": torch.tensor([config.q_sequence_length], dtype=torch.int32).detach().cpu().numpy(),
}
sess_options = SessionOptions()
sess_options.enable_profiling = True
ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"])
io_binding = ort_session.io_binding()
if new_k is not None:
@ -747,6 +748,11 @@ def gqa_prompt_func(
ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu()
ort_output = numpy.array(ort_output)
output = torch.tensor(ort_output)
if sess_options.enable_profiling:
profile_file = ort_session.end_profiling()
print(f"Profiling data saved to: {profile_file}")
return output, present_k, present_v
else:
ort_inputs = {