From 2e7de54565c24e8d8d2f039bac9ee838fe97b05d Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 17 May 2024 22:42:19 -0700 Subject: [PATCH] [CUDA] Fix SparseAttention Kernel (#20716) ### Description Currently, there is one bool flag to indicate whether kernel is loaded. However, there are v1 and v2 kernels, so the flag will allow only one version of kernel loaded. We use v1 kernel for prompt and v2 kernel for token generation, and the flag will cause issue when we want both prompt and token generation. This bug is found in integration test. The unit test only test one kernel at a time so the issue was not found before. Another possible walkaround without this fix is to set an environment variable `ORT_DISABLE_SPARSE_ATTENTION_V1=1` ### Motivation and Context --- .../cuda/sparse/sparse_attention.cc | 31 ++++++++----------- .../cuda/sparse/sparse_attention.h | 17 +++++----- 2 files changed, 21 insertions(+), 27 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc index 506a6683de..7d3f6eb929 100644 --- a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc +++ b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc @@ -55,8 +55,6 @@ SparseAttention::SparseAttention(const OpKernelInfo& info) scale_ = info.GetAttrOrDefault("scale", 0.0f); - kernel_loaded_ = false; - disable_v1_kernel_ = ParseEnvironmentVariableWithDefault(sparse_attention::kDisableSparseAttentionV1, false); } @@ -150,24 +148,21 @@ Status SparseAttention::ComputeInternal(OpKernelContext* context) const { CUDA_RETURN_IF_ERROR(cudaEventRecord(isCopyDone, cuda_stream)); } - if (!kernel_loaded_) { - if constexpr (std::is_same::value) { - // std::call_once is used in load_sparse_attention_fp16 so no need to use mutex here. - // After kernel is loaded, it will stay in memory until the process exits. We do not unload explicitly. - // TODO(tianleiwu): use TSharedCubinKernelFactory to manage kernel loading/unloading. - if (use_v2_kernel) { - sparse_attention_v2::load_sparse_attention_fp16(sm); - } else { - sparse_attention_v1::load_sparse_attention_fp16(sm); - } + if constexpr (std::is_same::value) { + // std::call_once is used in load_sparse_attention_fp16 so no need to use mutex here. + // After kernel is loaded, it will stay in memory until the process exits. We do not unload explicitly. + // TODO(tianleiwu): use TSharedCubinKernelFactory to manage kernel loading/unloading. + if (use_v2_kernel) { + sparse_attention_v2::load_sparse_attention_fp16(sm); } else { - if (use_v2_kernel) { - sparse_attention_v2::load_sparse_attention_bf16(sm); - } else { - sparse_attention_v1::load_sparse_attention_bf16(sm); - } + sparse_attention_v1::load_sparse_attention_fp16(sm); + } + } else { + if (use_v2_kernel) { + sparse_attention_v2::load_sparse_attention_bf16(sm); + } else { + sparse_attention_v1::load_sparse_attention_bf16(sm); } - kernel_loaded_ = true; } // Compute output shape and get output tensors. diff --git a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.h b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.h index 159ff6728f..1df3affe17 100644 --- a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.h +++ b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.h @@ -18,15 +18,14 @@ class SparseAttention final : public CudaKernel { Status ComputeInternal(OpKernelContext* context) const override; protected: - int num_heads_; // number of attention heads for q - int kv_num_heads_; // number of attention heads for k and v - float scale_; // Scaling factor applied prior to softmax. - bool is_causal_; // unidirectional attention or not - int sparse_block_size_; // block size for sparsity - bool do_rotary_; // Has rotary positional embedding - bool rotary_interleaved_; // Interleaved rotary positional embedding - bool disable_v1_kernel_; // Whether disable v1 kernel and use v2 kernel for prompt. - mutable bool kernel_loaded_; // Kernel has been loaded + int num_heads_; // number of attention heads for q + int kv_num_heads_; // number of attention heads for k and v + float scale_; // Scaling factor applied prior to softmax. + bool is_causal_; // unidirectional attention or not + int sparse_block_size_; // block size for sparsity + bool do_rotary_; // Has rotary positional embedding + bool rotary_interleaved_; // Interleaved rotary positional embedding + bool disable_v1_kernel_; // Whether disable v1 kernel and use v2 kernel for prompt. }; } // namespace cuda