diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 54bddcbdcf..df6553e383 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -39,6 +39,7 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/sqnbitgemm.h ${MLAS_SRC_DIR}/sqnbitgemm.cpp ${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h + ${MLAS_SRC_DIR}/flashattn.cpp ) target_sources(onnxruntime_mlas PRIVATE diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc index b39167f449..9677c30f22 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc @@ -10,8 +10,12 @@ #include "core/framework/tensorprotoutils.h" #include "core/graph/onnx_protobuf.h" #include "core/common/safeint.h" +#include "core/platform/env_var_utils.h" #include "core/platform/threadpool.h" +#include "core/mlas/inc/mlas.h" +#include +#include #include #include @@ -39,6 +43,11 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) : OpKernel(i mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f); is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; + + const auto& env = Env::Default(); + l2_cache_size_ = env.GetL2CacheSize(); + + disable_flash_ = ParseEnvironmentVariableWithDefault(attention::kDisableFlashAttention, false); } template @@ -60,7 +69,6 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const { } AttentionParameters parameters = {}; - constexpr float scale = 1.0f; bool past_present_share_buffer = false; ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs(query, key, @@ -74,7 +82,7 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const { ¶meters, num_heads_, mask_filter_value_, - scale, + scale_, is_unidirectional_, past_present_share_buffer, false)); @@ -99,8 +107,14 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const { const int v_bias_offset = 2 * qk_hidden_size; // If optional outputs aren't needed, present_k and present_v will be null - std::vector present_k_shape({static_cast(batch_size), static_cast(num_heads_), static_cast(total_kv_sequence_length), static_cast(qk_head_size)}); - std::vector present_v_shape({static_cast(batch_size), static_cast(num_heads_), static_cast(total_kv_sequence_length), static_cast(v_head_size)}); + std::vector present_k_shape({static_cast(batch_size), + static_cast(num_heads_), + static_cast(total_kv_sequence_length), + static_cast(qk_head_size)}); + std::vector present_v_shape({static_cast(batch_size), + static_cast(num_heads_), + static_cast(total_kv_sequence_length), + static_cast(v_head_size)}); Tensor* present_k = context->Output(1, present_k_shape); Tensor* present_v = context->Output(2, present_v_shape); @@ -138,6 +152,70 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const { ORT_RETURN_IF_ERROR(MaybeTransposeToBNSHAndAddBias( context, allocator, batch_size, num_heads_, kv_sequence_length, v_head_size, value, bias, v_bias_offset, V)); + if (std::is_same_v && + !disable_flash_ && + !is_unidirectional_ && + key_padding_mask == nullptr && + extra_add_qk == nullptr && + past_key == nullptr && + past_value == nullptr && + present_k == nullptr && + present_v == nullptr && + l2_cache_size_ > 0) { + MlasFlashAttentionThreadedArgs args; + args.batch_size = batch_size; + args.num_heads = num_heads_; + args.q_sequence_length = q_sequence_length; + args.kv_sequence_length = kv_sequence_length; + args.qk_head_size = qk_head_size; + args.v_head_size = v_head_size; + args.scale = (scale_ == 0.0f) ? 1.0f / sqrt(static_cast(qk_head_size)) : scale_; + /* + q_block_size, kv_block_size correspond to Br, Bc in the FlashAttention paper. + Let M = l2_cache_size / sizeof(float) + In the FlashAttention kernel, there are 5 big matrices that we need to keep in L2 cache: + slice of Q -- [Br, qk_head_size] + slice of K -- [Bc, qk_head_size] + slice of V -- [Bc, v_head_size] + result of QK -- [Br, Bc] + temporary output (same shape as QKV) -- [Br, v_head_size] + The total size of these matrices is (Br + Bc) * (qk_head_size + v_head_size) + Br * Bc + By taking Bc = M / (4 * (qk_head_size + v_head_size)), and Br = min(Bc, qk_head_size + v_head_size), we have + (Br + Bc) * (qk_head_size + v_head_size) + Br * Bc + <= 2 * Bc * (qk_head_size + v_head_size) + Br * Bc + <= 2 * Bc * (qk_head_size + v_head_size) + M/4 + <= 2 * M/4 + M/4 = M * (3/4) + + We leave 1/4 of the L2 cache for + 1. storing small tensors l and m + 2. instruction (code) + */ + args.kv_block_size = l2_cache_size_ / (static_cast(sizeof(float)) * 4 * (qk_head_size + v_head_size)); + args.kv_block_size = std::max(args.kv_block_size, 1); // avoid kv_block_size = 0 + args.q_block_size = std::min(args.kv_block_size, qk_head_size + v_head_size); + args.kv_block_size = std::min(args.kv_block_size, kv_sequence_length); // No point to have kv_block_size > kv_sequence_length + args.q_block_size = std::min(args.q_block_size, q_sequence_length); // No point to have q_block_size > q_sequence_length + + auto* tp = context->GetOperatorThreadPool(); + args.thread_count = concurrency::ThreadPool::DegreeOfParallelism(tp); + args.buffer_size_per_thread = (static_cast(args.q_block_size) * 2 + + static_cast(args.q_block_size) * static_cast(args.kv_block_size) + + static_cast(args.q_block_size) * static_cast(args.v_head_size)) * + sizeof(float); + size_t buffer_bytes = args.buffer_size_per_thread * args.thread_count; + IAllocatorUniquePtr buffer = IAllocator::MakeUniquePtr(allocator, buffer_bytes); + + args.buffer = reinterpret_cast(buffer.get()); + + args.query = Q.Get().Data(); + args.key = K.Get().Data(); + args.value = V.Get().Data(); + args.output = output->MutableData(); + + MlasFlashAttention(&args, tp); + return Status::OK(); + } + // Compute the attention score and apply the score to V return ApplyAttention(Q.GetMutable()->MutableData(), K.GetMutable()->MutableData(), diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h index fb7da78a5c..8a9bef1b2b 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h @@ -19,6 +19,8 @@ class MultiHeadAttention final : public OpKernel, public AttentionCPUBase { int num_heads_; // number of attention heads float mask_filter_value_; bool is_unidirectional_; + bool disable_flash_; + int l2_cache_size_; }; } // namespace contrib diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index cdfd283899..675f7c7a13 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1825,3 +1825,35 @@ MlasNhwcAvgPool( ); #endif + +struct MlasFlashAttentionThreadedArgs { + int batch_size; + int num_heads; + int q_sequence_length; + int kv_sequence_length; + int qk_head_size; + int v_head_size; + int q_block_size; + int kv_block_size; + float scale; + int thread_count; + float* buffer; + size_t buffer_size_per_thread; + const float* query; + const float* key; + const float* value; + float* output; +}; + +/** + * @brief Per-thread worker function for fp32 Flash Attention + * @param thread_id Thread index + * @param args Arguments + * @return +*/ +void +MLASCALL +MlasFlashAttention( + MlasFlashAttentionThreadedArgs* args, + MLAS_THREADPOOL* ThreadPool +); diff --git a/onnxruntime/core/mlas/lib/flashattn.cpp b/onnxruntime/core/mlas/lib/flashattn.cpp new file mode 100644 index 0000000000..fe5402ed14 --- /dev/null +++ b/onnxruntime/core/mlas/lib/flashattn.cpp @@ -0,0 +1,167 @@ +#include + +#include "mlasi.h" + +void +MlasFlashAttentionThreaded( + void* argptr, + std::ptrdiff_t thread_id +) +{ + const MlasFlashAttentionThreadedArgs* args = reinterpret_cast(argptr); + ptrdiff_t q_block_size = static_cast(args->q_block_size); + ptrdiff_t kv_block_size = static_cast(args->kv_block_size); + ptrdiff_t batch_size = static_cast(args->batch_size); + ptrdiff_t num_heads = static_cast(args->num_heads); + ptrdiff_t q_sequence_length = static_cast(args->q_sequence_length); + ptrdiff_t kv_sequence_length = static_cast(args->kv_sequence_length); + ptrdiff_t qk_head_size = static_cast(args->qk_head_size); + ptrdiff_t v_head_size = static_cast(args->v_head_size); + float* buffer = args->buffer; + ptrdiff_t buffer_size_per_thread = static_cast(args->buffer_size_per_thread); + ptrdiff_t thread_count = static_cast(args->thread_count); + const float* query = args->query; + const float* key = args->key; + const float* value = args->value; + float* output = args->output; + +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) + auto&& mlas_platform = GetMlasPlatform(); +#endif + + ptrdiff_t q_chunk_count = (q_sequence_length + (q_block_size - 1)) / q_block_size; + + ptrdiff_t task_start = 0; + ptrdiff_t task_end = 0; + ptrdiff_t total_task_count = batch_size * num_heads * q_chunk_count; + ptrdiff_t quotient = total_task_count / thread_count; + ptrdiff_t remainder = total_task_count % thread_count; + if (thread_id < remainder) { + task_start = (quotient + 1) * thread_id; + task_end = task_start + quotient + 1; + } else { + task_start = quotient * thread_id + remainder; + task_end = task_start + quotient; + } + + for (ptrdiff_t task_index = task_start; task_index < task_end; ++task_index) { + ptrdiff_t batch_idx = task_index; + ptrdiff_t q_idx = (batch_idx % q_chunk_count) * q_block_size; + batch_idx /= q_chunk_count; + ptrdiff_t head_idx = batch_idx % num_heads; + batch_idx /= num_heads; + + char* buffer_current_thread = reinterpret_cast(buffer) + thread_id * buffer_size_per_thread; + float* l = reinterpret_cast(buffer_current_thread); + float* m = l + q_block_size; + for (ptrdiff_t t = 0; t < q_block_size; ++t) { + m[t] = std::numeric_limits::lowest(); + } + float* intermediate = m + q_block_size; + float* temp_output = intermediate + q_block_size * kv_block_size; + float negmax = 0; + + for (ptrdiff_t ir = 0; ir < kv_sequence_length; ir += kv_block_size) { + /* + S = Q[batch_idx, head_idx, q_idx:q_idx+q_block_size, :] * (K[batch_idx, head_idx, ir:ir+kv_block_size, :]).T + old_m = m + m = max(m, rowmax(S)) + diff = old_m - m + S = exp(S - m) + l = exp(diff) * l + rowsum(S) + O = diag(exp(diff)) * O + S * V[batch_idx, head_idx, ir:ir+kv_block_size, :] + */ + ptrdiff_t h = batch_idx * num_heads + head_idx; + const float* inputQ = query + (h * q_sequence_length + q_idx) * qk_head_size; + const float* inputK = key + (h * kv_sequence_length + ir) * qk_head_size; + const float* inputV = value + (h * kv_sequence_length + ir) * v_head_size; + + size_t row_size_q_capped = static_cast(std::min(q_block_size, q_sequence_length - q_idx)); + size_t row_size_kv_capped = static_cast(std::min(kv_block_size, kv_sequence_length - ir)); + + MlasSgemmOperation(CBLAS_TRANSPOSE::CblasNoTrans, + CBLAS_TRANSPOSE::CblasTrans, + row_size_q_capped, + row_size_kv_capped, + static_cast(qk_head_size), + args->scale, + inputQ, + static_cast(qk_head_size), + inputK, + static_cast(qk_head_size), + 0.0f, + intermediate, + row_size_kv_capped); + + for (ptrdiff_t irow = 0; irow < static_cast(row_size_q_capped); ++irow) { + float* p = intermediate + irow * row_size_kv_capped; + +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) + float rowmax = mlas_platform.ReduceMaximumF32Kernel(p, row_size_kv_capped); +#else + float rowmax = MlasReduceMaximumF32Kernel(p, row_size_kv_capped); +#endif + float m_diff = m[irow]; + m[irow] = std::max(m[irow], rowmax); // new m + negmax = -m[irow]; + m_diff -= m[irow]; // old - new (less than 0) + +#if defined(MLAS_TARGET_AMD64) + float rowsum = mlas_platform.ComputeSumExpF32Kernel(p, p, row_size_kv_capped, &negmax); +#else + float rowsum = MlasComputeSumExpF32Kernel(p, p, row_size_kv_capped, &negmax); +#endif + + // Note: for ir == 0, there is actually no need to calculate exp_diff + if (ir != 0) { + float exp_diff = std::exp(m_diff); + l[irow] = exp_diff * l[irow] + rowsum; + + for (ptrdiff_t icol = 0; icol < v_head_size; ++icol) { + temp_output[irow * v_head_size + icol] = exp_diff * temp_output[irow * v_head_size + icol]; + } + } else { + l[irow] = rowsum; + // When ir == 0, there is no need to scale the old result because it is zero. + } + } + MlasSgemmOperation(CBLAS_TRANSPOSE::CblasNoTrans, + CBLAS_TRANSPOSE::CblasNoTrans, + row_size_q_capped, + static_cast(v_head_size), + row_size_kv_capped, + 1.0f, + intermediate, + row_size_kv_capped, + inputV, + static_cast(v_head_size), + ir == 0 ? 0.0f : 1.0f, + temp_output, + static_cast(v_head_size)); + } + + float* output_row = output + ((batch_idx * q_sequence_length + q_idx) * num_heads + head_idx) * v_head_size; + ptrdiff_t row_size_q_valid = std::min(q_block_size, q_sequence_length - q_idx); + // TODO: leverage advanced instruction sets + for (ptrdiff_t irow = 0; irow < row_size_q_valid; ++irow) { + for (ptrdiff_t icol = 0; icol < v_head_size; ++icol) { + output_row[icol] = temp_output[irow * v_head_size + icol] / l[irow]; + } + output_row += num_heads * v_head_size; + } + } +} + +void +MLASCALL +MlasFlashAttention( + MlasFlashAttentionThreadedArgs* args, + MLAS_THREADPOOL* ThreadPool +) +{ + MlasExecuteThreaded( + MlasFlashAttentionThreaded, + static_cast(args), + static_cast(args->thread_count), + ThreadPool); +} diff --git a/onnxruntime/core/platform/env.h b/onnxruntime/core/platform/env.h index f4dff2c491..c42b31e64d 100644 --- a/onnxruntime/core/platform/env.h +++ b/onnxruntime/core/platform/env.h @@ -147,6 +147,8 @@ class Env { virtual std::vector GetDefaultThreadAffinities() const = 0; + virtual int GetL2CacheSize() const = 0; + /// \brief Returns the number of micro-seconds since the Unix epoch. virtual uint64_t NowMicros() const { return env_time_->NowMicros(); diff --git a/onnxruntime/core/platform/posix/env.cc b/onnxruntime/core/platform/posix/env.cc index 16d135c3ac..ec06320438 100644 --- a/onnxruntime/core/platform/posix/env.cc +++ b/onnxruntime/core/platform/posix/env.cc @@ -43,6 +43,10 @@ limitations under the License. #define ORT_USE_CPUINFO #endif +#if defined(__APPLE__) || defined(__FreeBSD__) || defined(__NetBSD__) +#include +#endif + #include "core/common/common.h" #include #include "core/common/logging/logging.h" @@ -302,6 +306,22 @@ class PosixEnv : public Env { return ret; } + int GetL2CacheSize() const override { +#ifdef _SC_LEVEL2_CACHE_SIZE + return static_cast(sysconf(_SC_LEVEL2_CACHE_SIZE)); +#else + int value = 0; // unknown +#if (defined(__APPLE__) || defined(__FreeBSD__) || defined(__NetBSD__)) && defined(HW_L2CACHESIZE) + int mib[2] = {CTL_HW, HW_L2CACHESIZE}; + size_t len = sizeof(value); + if (sysctl(mib, 2, &value, &len, NULL, 0) < 0) { + return -1; // error + } +#endif + return value; +#endif + } + void SleepForMicroseconds(int64_t micros) const override { while (micros > 0) { timespec sleep_time; diff --git a/onnxruntime/core/platform/windows/env.cc b/onnxruntime/core/platform/windows/env.cc index 712b69593a..73319cd9c9 100644 --- a/onnxruntime/core/platform/windows/env.cc +++ b/onnxruntime/core/platform/windows/env.cc @@ -303,6 +303,10 @@ std::vector WindowsEnv::GetDefaultThreadAffinities() const { return cores_.empty() ? std::vector(DefaultNumCores(), LogicalProcessors{}) : cores_; } +int WindowsEnv::GetL2CacheSize() const { + return l2_cache_size_; +} + WindowsEnv& WindowsEnv::Instance() { static WindowsEnv default_env; return default_env; @@ -851,6 +855,7 @@ ProcessorInfo WindowsEnv::GetProcessorAffinityMask(int global_processor_id) cons } WindowsEnv::WindowsEnv() { + l2_cache_size_ = 0; InitializeCpuInfo(); } @@ -924,9 +929,57 @@ void WindowsEnv::InitializeCpuInfo() { } iter += size; } + + DWORD newLength = 0; + GetLogicalProcessorInformationEx(RelationCache, nullptr, &newLength); + last_error = GetLastError(); + if (last_error != ERROR_INSUFFICIENT_BUFFER) { + const auto error_code = GetLastError(); + if (logging::LoggingManager::HasDefaultLogger()) { + LOGS_DEFAULT(ERROR) << "Failed to calculate byte size for saving cpu info on windows" + << ", error code: " << error_code + << ", error msg: " << std::system_category().message(error_code); + } + return; + } + + if (newLength > returnLength) { + // Re-allocate + allocation = std::make_unique(newLength); + processorInfos = reinterpret_cast(allocation.get()); + } + + if (!GetLogicalProcessorInformationEx(RelationCache, processorInfos, &newLength)) { + const auto error_code = GetLastError(); + if (logging::LoggingManager::HasDefaultLogger()) { + LOGS_DEFAULT(ERROR) << "Failed to fetch cpu info on windows" + << ", error code: " << error_code + << ", error msg: " << std::system_category().message(error_code); + } + return; + } + + iter = reinterpret_cast(processorInfos); + end = iter + newLength; + + while (iter < end) { + auto processor_info = reinterpret_cast(iter); + auto size = processor_info->Size; + + if (processor_info->Relationship == RelationCache && + processor_info->Cache.Level == 2) { + // L2 cache + l2_cache_size_ = static_cast(processor_info->Cache.CacheSize); + break; + } + + iter += size; + } + if (logging::LoggingManager::HasDefaultLogger()) { LOGS_DEFAULT(VERBOSE) << "Found total " << cores_.size() << " core(s) from windows system:"; LOGS_DEFAULT(VERBOSE) << log_stream.str(); + LOGS_DEFAULT(VERBOSE) << "\nDetected L2 cache size: " << l2_cache_size_ << " bytes"; } } } // namespace onnxruntime diff --git a/onnxruntime/core/platform/windows/env.h b/onnxruntime/core/platform/windows/env.h index 79739db9e5..395aface1d 100644 --- a/onnxruntime/core/platform/windows/env.h +++ b/onnxruntime/core/platform/windows/env.h @@ -55,6 +55,7 @@ class WindowsEnv : public Env { static int DefaultNumCores(); int GetNumPhysicalCpuCores() const override; std::vector GetDefaultThreadAffinities() const override; + int GetL2CacheSize() const override; static WindowsEnv& Instance(); PIDType GetSelfPid() const override; Status GetFileLength(_In_z_ const ORTCHAR_T* file_path, size_t& length) const override; @@ -113,6 +114,8 @@ class WindowsEnv : public Env { * } */ std::vector cores_; + + int l2_cache_size_; /* * "global_processor_info_map_" is a map of: * global_processor_id <--> (group_id, local_processor_id) diff --git a/onnxruntime/test/python/transformers/benchmark_mha.py b/onnxruntime/test/python/transformers/benchmark_mha.py index 2257817584..797461bae2 100644 --- a/onnxruntime/test/python/transformers/benchmark_mha.py +++ b/onnxruntime/test/python/transformers/benchmark_mha.py @@ -441,7 +441,7 @@ def run_tflops_test(use_gpu: bool = True, enable_cuda_graph: bool = False, repea sequence_length=sequence_length, num_heads=num_heads, head_size=head_size, - causal=True, + causal=causal, use_kv_cache=use_kv_cache, past_sequence_length=past_sequence_length, max_cache_sequence_length=None,