mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
Implement FlashAttention for CPU (#20805)
### Description Implement [FlashAttention](https://arxiv.org/pdf/2205.14135) and [FlashAttention-2](https://arxiv.org/pdf/2307.08691) for MultiHeadAttention on CPU. ### Motivation and Context Accelerate the execution of MultiHeadAttention. Current performance: 10ms vs 16ms (com.microsoft.MultiHeadAttention) on my Linux machine and 10ms vs 38ms (com.microsoft.MultiHeadAttention) on my Windows machine. May need further optimizations. --------- Co-authored-by: Tianlei Wu <tlwu@microsoft.com> Co-authored-by: Qingnan Duan <qiduan@microsoft.com>
This commit is contained in:
parent
33e7c7f6ec
commit
80b56feb41
10 changed files with 363 additions and 5 deletions
|
|
@ -39,6 +39,7 @@ onnxruntime_add_static_library(onnxruntime_mlas
|
|||
${MLAS_SRC_DIR}/sqnbitgemm.h
|
||||
${MLAS_SRC_DIR}/sqnbitgemm.cpp
|
||||
${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h
|
||||
${MLAS_SRC_DIR}/flashattn.cpp
|
||||
)
|
||||
|
||||
target_sources(onnxruntime_mlas PRIVATE
|
||||
|
|
|
|||
|
|
@ -10,8 +10,12 @@
|
|||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "core/graph/onnx_protobuf.h"
|
||||
#include "core/common/safeint.h"
|
||||
#include "core/platform/env_var_utils.h"
|
||||
#include "core/platform/threadpool.h"
|
||||
#include "core/mlas/inc/mlas.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <type_traits>
|
||||
#include <unsupported/Eigen/SpecialFunctions>
|
||||
#include <vector>
|
||||
|
||||
|
|
@ -39,6 +43,11 @@ MultiHeadAttention<T>::MultiHeadAttention(const OpKernelInfo& info) : OpKernel(i
|
|||
|
||||
mask_filter_value_ = info.GetAttrOrDefault<float>("mask_filter_value", -10000.0f);
|
||||
is_unidirectional_ = info.GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;
|
||||
|
||||
const auto& env = Env::Default();
|
||||
l2_cache_size_ = env.GetL2CacheSize();
|
||||
|
||||
disable_flash_ = ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFlashAttention, false);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
|
@ -60,7 +69,6 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
|
|||
}
|
||||
|
||||
AttentionParameters parameters = {};
|
||||
constexpr float scale = 1.0f;
|
||||
bool past_present_share_buffer = false;
|
||||
ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs<Tensor>(query,
|
||||
key,
|
||||
|
|
@ -74,7 +82,7 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
|
|||
¶meters,
|
||||
num_heads_,
|
||||
mask_filter_value_,
|
||||
scale,
|
||||
scale_,
|
||||
is_unidirectional_,
|
||||
past_present_share_buffer,
|
||||
false));
|
||||
|
|
@ -99,8 +107,14 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
|
|||
const int v_bias_offset = 2 * qk_hidden_size;
|
||||
|
||||
// If optional outputs aren't needed, present_k and present_v will be null
|
||||
std::vector<int64_t> present_k_shape({static_cast<int64_t>(batch_size), static_cast<int64_t>(num_heads_), static_cast<int64_t>(total_kv_sequence_length), static_cast<int64_t>(qk_head_size)});
|
||||
std::vector<int64_t> present_v_shape({static_cast<int64_t>(batch_size), static_cast<int64_t>(num_heads_), static_cast<int64_t>(total_kv_sequence_length), static_cast<int64_t>(v_head_size)});
|
||||
std::vector<int64_t> present_k_shape({static_cast<int64_t>(batch_size),
|
||||
static_cast<int64_t>(num_heads_),
|
||||
static_cast<int64_t>(total_kv_sequence_length),
|
||||
static_cast<int64_t>(qk_head_size)});
|
||||
std::vector<int64_t> present_v_shape({static_cast<int64_t>(batch_size),
|
||||
static_cast<int64_t>(num_heads_),
|
||||
static_cast<int64_t>(total_kv_sequence_length),
|
||||
static_cast<int64_t>(v_head_size)});
|
||||
Tensor* present_k = context->Output(1, present_k_shape);
|
||||
Tensor* present_v = context->Output(2, present_v_shape);
|
||||
|
||||
|
|
@ -138,6 +152,70 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
|
|||
ORT_RETURN_IF_ERROR(MaybeTransposeToBNSHAndAddBias<T>(
|
||||
context, allocator, batch_size, num_heads_, kv_sequence_length, v_head_size, value, bias, v_bias_offset, V));
|
||||
|
||||
if (std::is_same_v<T, float> &&
|
||||
!disable_flash_ &&
|
||||
!is_unidirectional_ &&
|
||||
key_padding_mask == nullptr &&
|
||||
extra_add_qk == nullptr &&
|
||||
past_key == nullptr &&
|
||||
past_value == nullptr &&
|
||||
present_k == nullptr &&
|
||||
present_v == nullptr &&
|
||||
l2_cache_size_ > 0) {
|
||||
MlasFlashAttentionThreadedArgs args;
|
||||
args.batch_size = batch_size;
|
||||
args.num_heads = num_heads_;
|
||||
args.q_sequence_length = q_sequence_length;
|
||||
args.kv_sequence_length = kv_sequence_length;
|
||||
args.qk_head_size = qk_head_size;
|
||||
args.v_head_size = v_head_size;
|
||||
args.scale = (scale_ == 0.0f) ? 1.0f / sqrt(static_cast<float>(qk_head_size)) : scale_;
|
||||
/*
|
||||
q_block_size, kv_block_size correspond to Br, Bc in the FlashAttention paper.
|
||||
Let M = l2_cache_size / sizeof(float)
|
||||
In the FlashAttention kernel, there are 5 big matrices that we need to keep in L2 cache:
|
||||
slice of Q -- [Br, qk_head_size]
|
||||
slice of K -- [Bc, qk_head_size]
|
||||
slice of V -- [Bc, v_head_size]
|
||||
result of QK -- [Br, Bc]
|
||||
temporary output (same shape as QKV) -- [Br, v_head_size]
|
||||
The total size of these matrices is (Br + Bc) * (qk_head_size + v_head_size) + Br * Bc
|
||||
By taking Bc = M / (4 * (qk_head_size + v_head_size)), and Br = min(Bc, qk_head_size + v_head_size), we have
|
||||
(Br + Bc) * (qk_head_size + v_head_size) + Br * Bc
|
||||
<= 2 * Bc * (qk_head_size + v_head_size) + Br * Bc
|
||||
<= 2 * Bc * (qk_head_size + v_head_size) + M/4
|
||||
<= 2 * M/4 + M/4 = M * (3/4)
|
||||
|
||||
We leave 1/4 of the L2 cache for
|
||||
1. storing small tensors l and m
|
||||
2. instruction (code)
|
||||
*/
|
||||
args.kv_block_size = l2_cache_size_ / (static_cast<int>(sizeof(float)) * 4 * (qk_head_size + v_head_size));
|
||||
args.kv_block_size = std::max(args.kv_block_size, 1); // avoid kv_block_size = 0
|
||||
args.q_block_size = std::min(args.kv_block_size, qk_head_size + v_head_size);
|
||||
args.kv_block_size = std::min(args.kv_block_size, kv_sequence_length); // No point to have kv_block_size > kv_sequence_length
|
||||
args.q_block_size = std::min(args.q_block_size, q_sequence_length); // No point to have q_block_size > q_sequence_length
|
||||
|
||||
auto* tp = context->GetOperatorThreadPool();
|
||||
args.thread_count = concurrency::ThreadPool::DegreeOfParallelism(tp);
|
||||
args.buffer_size_per_thread = (static_cast<size_t>(args.q_block_size) * 2 +
|
||||
static_cast<size_t>(args.q_block_size) * static_cast<size_t>(args.kv_block_size) +
|
||||
static_cast<size_t>(args.q_block_size) * static_cast<size_t>(args.v_head_size)) *
|
||||
sizeof(float);
|
||||
size_t buffer_bytes = args.buffer_size_per_thread * args.thread_count;
|
||||
IAllocatorUniquePtr<void> buffer = IAllocator::MakeUniquePtr<void>(allocator, buffer_bytes);
|
||||
|
||||
args.buffer = reinterpret_cast<float*>(buffer.get());
|
||||
|
||||
args.query = Q.Get<Tensor>().Data<float>();
|
||||
args.key = K.Get<Tensor>().Data<float>();
|
||||
args.value = V.Get<Tensor>().Data<float>();
|
||||
args.output = output->MutableData<float>();
|
||||
|
||||
MlasFlashAttention(&args, tp);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Compute the attention score and apply the score to V
|
||||
return ApplyAttention(Q.GetMutable<Tensor>()->MutableData<T>(),
|
||||
K.GetMutable<Tensor>()->MutableData<T>(),
|
||||
|
|
|
|||
|
|
@ -19,6 +19,8 @@ class MultiHeadAttention final : public OpKernel, public AttentionCPUBase {
|
|||
int num_heads_; // number of attention heads
|
||||
float mask_filter_value_;
|
||||
bool is_unidirectional_;
|
||||
bool disable_flash_;
|
||||
int l2_cache_size_;
|
||||
};
|
||||
|
||||
} // namespace contrib
|
||||
|
|
|
|||
|
|
@ -1825,3 +1825,35 @@ MlasNhwcAvgPool(
|
|||
);
|
||||
|
||||
#endif
|
||||
|
||||
struct MlasFlashAttentionThreadedArgs {
|
||||
int batch_size;
|
||||
int num_heads;
|
||||
int q_sequence_length;
|
||||
int kv_sequence_length;
|
||||
int qk_head_size;
|
||||
int v_head_size;
|
||||
int q_block_size;
|
||||
int kv_block_size;
|
||||
float scale;
|
||||
int thread_count;
|
||||
float* buffer;
|
||||
size_t buffer_size_per_thread;
|
||||
const float* query;
|
||||
const float* key;
|
||||
const float* value;
|
||||
float* output;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Per-thread worker function for fp32 Flash Attention
|
||||
* @param thread_id Thread index
|
||||
* @param args Arguments
|
||||
* @return
|
||||
*/
|
||||
void
|
||||
MLASCALL
|
||||
MlasFlashAttention(
|
||||
MlasFlashAttentionThreadedArgs* args,
|
||||
MLAS_THREADPOOL* ThreadPool
|
||||
);
|
||||
|
|
|
|||
167
onnxruntime/core/mlas/lib/flashattn.cpp
Normal file
167
onnxruntime/core/mlas/lib/flashattn.cpp
Normal file
|
|
@ -0,0 +1,167 @@
|
|||
#include <numeric>
|
||||
|
||||
#include "mlasi.h"
|
||||
|
||||
void
|
||||
MlasFlashAttentionThreaded(
|
||||
void* argptr,
|
||||
std::ptrdiff_t thread_id
|
||||
)
|
||||
{
|
||||
const MlasFlashAttentionThreadedArgs* args = reinterpret_cast<MlasFlashAttentionThreadedArgs*>(argptr);
|
||||
ptrdiff_t q_block_size = static_cast<ptrdiff_t>(args->q_block_size);
|
||||
ptrdiff_t kv_block_size = static_cast<ptrdiff_t>(args->kv_block_size);
|
||||
ptrdiff_t batch_size = static_cast<ptrdiff_t>(args->batch_size);
|
||||
ptrdiff_t num_heads = static_cast<ptrdiff_t>(args->num_heads);
|
||||
ptrdiff_t q_sequence_length = static_cast<ptrdiff_t>(args->q_sequence_length);
|
||||
ptrdiff_t kv_sequence_length = static_cast<ptrdiff_t>(args->kv_sequence_length);
|
||||
ptrdiff_t qk_head_size = static_cast<ptrdiff_t>(args->qk_head_size);
|
||||
ptrdiff_t v_head_size = static_cast<ptrdiff_t>(args->v_head_size);
|
||||
float* buffer = args->buffer;
|
||||
ptrdiff_t buffer_size_per_thread = static_cast<ptrdiff_t>(args->buffer_size_per_thread);
|
||||
ptrdiff_t thread_count = static_cast<ptrdiff_t>(args->thread_count);
|
||||
const float* query = args->query;
|
||||
const float* key = args->key;
|
||||
const float* value = args->value;
|
||||
float* output = args->output;
|
||||
|
||||
#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64)
|
||||
auto&& mlas_platform = GetMlasPlatform();
|
||||
#endif
|
||||
|
||||
ptrdiff_t q_chunk_count = (q_sequence_length + (q_block_size - 1)) / q_block_size;
|
||||
|
||||
ptrdiff_t task_start = 0;
|
||||
ptrdiff_t task_end = 0;
|
||||
ptrdiff_t total_task_count = batch_size * num_heads * q_chunk_count;
|
||||
ptrdiff_t quotient = total_task_count / thread_count;
|
||||
ptrdiff_t remainder = total_task_count % thread_count;
|
||||
if (thread_id < remainder) {
|
||||
task_start = (quotient + 1) * thread_id;
|
||||
task_end = task_start + quotient + 1;
|
||||
} else {
|
||||
task_start = quotient * thread_id + remainder;
|
||||
task_end = task_start + quotient;
|
||||
}
|
||||
|
||||
for (ptrdiff_t task_index = task_start; task_index < task_end; ++task_index) {
|
||||
ptrdiff_t batch_idx = task_index;
|
||||
ptrdiff_t q_idx = (batch_idx % q_chunk_count) * q_block_size;
|
||||
batch_idx /= q_chunk_count;
|
||||
ptrdiff_t head_idx = batch_idx % num_heads;
|
||||
batch_idx /= num_heads;
|
||||
|
||||
char* buffer_current_thread = reinterpret_cast<char*>(buffer) + thread_id * buffer_size_per_thread;
|
||||
float* l = reinterpret_cast<float*>(buffer_current_thread);
|
||||
float* m = l + q_block_size;
|
||||
for (ptrdiff_t t = 0; t < q_block_size; ++t) {
|
||||
m[t] = std::numeric_limits<float>::lowest();
|
||||
}
|
||||
float* intermediate = m + q_block_size;
|
||||
float* temp_output = intermediate + q_block_size * kv_block_size;
|
||||
float negmax = 0;
|
||||
|
||||
for (ptrdiff_t ir = 0; ir < kv_sequence_length; ir += kv_block_size) {
|
||||
/*
|
||||
S = Q[batch_idx, head_idx, q_idx:q_idx+q_block_size, :] * (K[batch_idx, head_idx, ir:ir+kv_block_size, :]).T
|
||||
old_m = m
|
||||
m = max(m, rowmax(S))
|
||||
diff = old_m - m
|
||||
S = exp(S - m)
|
||||
l = exp(diff) * l + rowsum(S)
|
||||
O = diag(exp(diff)) * O + S * V[batch_idx, head_idx, ir:ir+kv_block_size, :]
|
||||
*/
|
||||
ptrdiff_t h = batch_idx * num_heads + head_idx;
|
||||
const float* inputQ = query + (h * q_sequence_length + q_idx) * qk_head_size;
|
||||
const float* inputK = key + (h * kv_sequence_length + ir) * qk_head_size;
|
||||
const float* inputV = value + (h * kv_sequence_length + ir) * v_head_size;
|
||||
|
||||
size_t row_size_q_capped = static_cast<size_t>(std::min(q_block_size, q_sequence_length - q_idx));
|
||||
size_t row_size_kv_capped = static_cast<size_t>(std::min(kv_block_size, kv_sequence_length - ir));
|
||||
|
||||
MlasSgemmOperation(CBLAS_TRANSPOSE::CblasNoTrans,
|
||||
CBLAS_TRANSPOSE::CblasTrans,
|
||||
row_size_q_capped,
|
||||
row_size_kv_capped,
|
||||
static_cast<size_t>(qk_head_size),
|
||||
args->scale,
|
||||
inputQ,
|
||||
static_cast<size_t>(qk_head_size),
|
||||
inputK,
|
||||
static_cast<size_t>(qk_head_size),
|
||||
0.0f,
|
||||
intermediate,
|
||||
row_size_kv_capped);
|
||||
|
||||
for (ptrdiff_t irow = 0; irow < static_cast<ptrdiff_t>(row_size_q_capped); ++irow) {
|
||||
float* p = intermediate + irow * row_size_kv_capped;
|
||||
|
||||
#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64)
|
||||
float rowmax = mlas_platform.ReduceMaximumF32Kernel(p, row_size_kv_capped);
|
||||
#else
|
||||
float rowmax = MlasReduceMaximumF32Kernel(p, row_size_kv_capped);
|
||||
#endif
|
||||
float m_diff = m[irow];
|
||||
m[irow] = std::max(m[irow], rowmax); // new m
|
||||
negmax = -m[irow];
|
||||
m_diff -= m[irow]; // old - new (less than 0)
|
||||
|
||||
#if defined(MLAS_TARGET_AMD64)
|
||||
float rowsum = mlas_platform.ComputeSumExpF32Kernel(p, p, row_size_kv_capped, &negmax);
|
||||
#else
|
||||
float rowsum = MlasComputeSumExpF32Kernel(p, p, row_size_kv_capped, &negmax);
|
||||
#endif
|
||||
|
||||
// Note: for ir == 0, there is actually no need to calculate exp_diff
|
||||
if (ir != 0) {
|
||||
float exp_diff = std::exp(m_diff);
|
||||
l[irow] = exp_diff * l[irow] + rowsum;
|
||||
|
||||
for (ptrdiff_t icol = 0; icol < v_head_size; ++icol) {
|
||||
temp_output[irow * v_head_size + icol] = exp_diff * temp_output[irow * v_head_size + icol];
|
||||
}
|
||||
} else {
|
||||
l[irow] = rowsum;
|
||||
// When ir == 0, there is no need to scale the old result because it is zero.
|
||||
}
|
||||
}
|
||||
MlasSgemmOperation(CBLAS_TRANSPOSE::CblasNoTrans,
|
||||
CBLAS_TRANSPOSE::CblasNoTrans,
|
||||
row_size_q_capped,
|
||||
static_cast<size_t>(v_head_size),
|
||||
row_size_kv_capped,
|
||||
1.0f,
|
||||
intermediate,
|
||||
row_size_kv_capped,
|
||||
inputV,
|
||||
static_cast<size_t>(v_head_size),
|
||||
ir == 0 ? 0.0f : 1.0f,
|
||||
temp_output,
|
||||
static_cast<size_t>(v_head_size));
|
||||
}
|
||||
|
||||
float* output_row = output + ((batch_idx * q_sequence_length + q_idx) * num_heads + head_idx) * v_head_size;
|
||||
ptrdiff_t row_size_q_valid = std::min(q_block_size, q_sequence_length - q_idx);
|
||||
// TODO: leverage advanced instruction sets
|
||||
for (ptrdiff_t irow = 0; irow < row_size_q_valid; ++irow) {
|
||||
for (ptrdiff_t icol = 0; icol < v_head_size; ++icol) {
|
||||
output_row[icol] = temp_output[irow * v_head_size + icol] / l[irow];
|
||||
}
|
||||
output_row += num_heads * v_head_size;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
MLASCALL
|
||||
MlasFlashAttention(
|
||||
MlasFlashAttentionThreadedArgs* args,
|
||||
MLAS_THREADPOOL* ThreadPool
|
||||
)
|
||||
{
|
||||
MlasExecuteThreaded(
|
||||
MlasFlashAttentionThreaded,
|
||||
static_cast<void *>(args),
|
||||
static_cast<std::ptrdiff_t>(args->thread_count),
|
||||
ThreadPool);
|
||||
}
|
||||
|
|
@ -147,6 +147,8 @@ class Env {
|
|||
|
||||
virtual std::vector<LogicalProcessors> GetDefaultThreadAffinities() const = 0;
|
||||
|
||||
virtual int GetL2CacheSize() const = 0;
|
||||
|
||||
/// \brief Returns the number of micro-seconds since the Unix epoch.
|
||||
virtual uint64_t NowMicros() const {
|
||||
return env_time_->NowMicros();
|
||||
|
|
|
|||
|
|
@ -43,6 +43,10 @@ limitations under the License.
|
|||
#define ORT_USE_CPUINFO
|
||||
#endif
|
||||
|
||||
#if defined(__APPLE__) || defined(__FreeBSD__) || defined(__NetBSD__)
|
||||
#include <sys/sysctl.h>
|
||||
#endif
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include <gsl/gsl>
|
||||
#include "core/common/logging/logging.h"
|
||||
|
|
@ -302,6 +306,22 @@ class PosixEnv : public Env {
|
|||
return ret;
|
||||
}
|
||||
|
||||
int GetL2CacheSize() const override {
|
||||
#ifdef _SC_LEVEL2_CACHE_SIZE
|
||||
return static_cast<int>(sysconf(_SC_LEVEL2_CACHE_SIZE));
|
||||
#else
|
||||
int value = 0; // unknown
|
||||
#if (defined(__APPLE__) || defined(__FreeBSD__) || defined(__NetBSD__)) && defined(HW_L2CACHESIZE)
|
||||
int mib[2] = {CTL_HW, HW_L2CACHESIZE};
|
||||
size_t len = sizeof(value);
|
||||
if (sysctl(mib, 2, &value, &len, NULL, 0) < 0) {
|
||||
return -1; // error
|
||||
}
|
||||
#endif
|
||||
return value;
|
||||
#endif
|
||||
}
|
||||
|
||||
void SleepForMicroseconds(int64_t micros) const override {
|
||||
while (micros > 0) {
|
||||
timespec sleep_time;
|
||||
|
|
|
|||
|
|
@ -303,6 +303,10 @@ std::vector<LogicalProcessors> WindowsEnv::GetDefaultThreadAffinities() const {
|
|||
return cores_.empty() ? std::vector<LogicalProcessors>(DefaultNumCores(), LogicalProcessors{}) : cores_;
|
||||
}
|
||||
|
||||
int WindowsEnv::GetL2CacheSize() const {
|
||||
return l2_cache_size_;
|
||||
}
|
||||
|
||||
WindowsEnv& WindowsEnv::Instance() {
|
||||
static WindowsEnv default_env;
|
||||
return default_env;
|
||||
|
|
@ -851,6 +855,7 @@ ProcessorInfo WindowsEnv::GetProcessorAffinityMask(int global_processor_id) cons
|
|||
}
|
||||
|
||||
WindowsEnv::WindowsEnv() {
|
||||
l2_cache_size_ = 0;
|
||||
InitializeCpuInfo();
|
||||
}
|
||||
|
||||
|
|
@ -924,9 +929,57 @@ void WindowsEnv::InitializeCpuInfo() {
|
|||
}
|
||||
iter += size;
|
||||
}
|
||||
|
||||
DWORD newLength = 0;
|
||||
GetLogicalProcessorInformationEx(RelationCache, nullptr, &newLength);
|
||||
last_error = GetLastError();
|
||||
if (last_error != ERROR_INSUFFICIENT_BUFFER) {
|
||||
const auto error_code = GetLastError();
|
||||
if (logging::LoggingManager::HasDefaultLogger()) {
|
||||
LOGS_DEFAULT(ERROR) << "Failed to calculate byte size for saving cpu info on windows"
|
||||
<< ", error code: " << error_code
|
||||
<< ", error msg: " << std::system_category().message(error_code);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (newLength > returnLength) {
|
||||
// Re-allocate
|
||||
allocation = std::make_unique<char[]>(newLength);
|
||||
processorInfos = reinterpret_cast<SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX*>(allocation.get());
|
||||
}
|
||||
|
||||
if (!GetLogicalProcessorInformationEx(RelationCache, processorInfos, &newLength)) {
|
||||
const auto error_code = GetLastError();
|
||||
if (logging::LoggingManager::HasDefaultLogger()) {
|
||||
LOGS_DEFAULT(ERROR) << "Failed to fetch cpu info on windows"
|
||||
<< ", error code: " << error_code
|
||||
<< ", error msg: " << std::system_category().message(error_code);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
iter = reinterpret_cast<const BYTE*>(processorInfos);
|
||||
end = iter + newLength;
|
||||
|
||||
while (iter < end) {
|
||||
auto processor_info = reinterpret_cast<const SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX*>(iter);
|
||||
auto size = processor_info->Size;
|
||||
|
||||
if (processor_info->Relationship == RelationCache &&
|
||||
processor_info->Cache.Level == 2) {
|
||||
// L2 cache
|
||||
l2_cache_size_ = static_cast<int>(processor_info->Cache.CacheSize);
|
||||
break;
|
||||
}
|
||||
|
||||
iter += size;
|
||||
}
|
||||
|
||||
if (logging::LoggingManager::HasDefaultLogger()) {
|
||||
LOGS_DEFAULT(VERBOSE) << "Found total " << cores_.size() << " core(s) from windows system:";
|
||||
LOGS_DEFAULT(VERBOSE) << log_stream.str();
|
||||
LOGS_DEFAULT(VERBOSE) << "\nDetected L2 cache size: " << l2_cache_size_ << " bytes";
|
||||
}
|
||||
}
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -55,6 +55,7 @@ class WindowsEnv : public Env {
|
|||
static int DefaultNumCores();
|
||||
int GetNumPhysicalCpuCores() const override;
|
||||
std::vector<LogicalProcessors> GetDefaultThreadAffinities() const override;
|
||||
int GetL2CacheSize() const override;
|
||||
static WindowsEnv& Instance();
|
||||
PIDType GetSelfPid() const override;
|
||||
Status GetFileLength(_In_z_ const ORTCHAR_T* file_path, size_t& length) const override;
|
||||
|
|
@ -113,6 +114,8 @@ class WindowsEnv : public Env {
|
|||
* }
|
||||
*/
|
||||
std::vector<LogicalProcessors> cores_;
|
||||
|
||||
int l2_cache_size_;
|
||||
/*
|
||||
* "global_processor_info_map_" is a map of:
|
||||
* global_processor_id <--> (group_id, local_processor_id)
|
||||
|
|
|
|||
|
|
@ -441,7 +441,7 @@ def run_tflops_test(use_gpu: bool = True, enable_cuda_graph: bool = False, repea
|
|||
sequence_length=sequence_length,
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
causal=True,
|
||||
causal=causal,
|
||||
use_kv_cache=use_kv_cache,
|
||||
past_sequence_length=past_sequence_length,
|
||||
max_cache_sequence_length=None,
|
||||
|
|
|
|||
Loading…
Reference in a new issue