mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-16 21:00:14 +00:00
[CUDA] Fix SparseAttention Kernel (#20716)
### Description Currently, there is a single bool flag to indicate whether the kernel has been loaded. However, there are v1 and v2 kernels, so one flag allows only one version of the kernel to be loaded. We use the v1 kernel for prompt processing and the v2 kernel for token generation, so the single flag causes an issue when both prompt processing and token generation are needed. This bug was found in an integration test; the unit tests only exercise one kernel at a time, so the issue was not caught earlier. A possible workaround without this fix is to set the environment variable `ORT_DISABLE_SPARSE_ATTENTION_V1=1`. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
d7f7c3b343
commit
2e7de54565
2 changed files with 21 additions and 27 deletions
|
|
@ -55,8 +55,6 @@ SparseAttention<T>::SparseAttention(const OpKernelInfo& info)
|
|||
|
||||
scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
|
||||
|
||||
kernel_loaded_ = false;
|
||||
|
||||
disable_v1_kernel_ = ParseEnvironmentVariableWithDefault<bool>(sparse_attention::kDisableSparseAttentionV1, false);
|
||||
}
|
||||
|
||||
|
|
@ -150,24 +148,21 @@ Status SparseAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
CUDA_RETURN_IF_ERROR(cudaEventRecord(isCopyDone, cuda_stream));
|
||||
}
|
||||
|
||||
if (!kernel_loaded_) {
|
||||
if constexpr (std::is_same<T, MLFloat16>::value) {
|
||||
// std::call_once is used in load_sparse_attention_fp16 so no need to use mutex here.
|
||||
// After kernel is loaded, it will stay in memory until the process exits. We do not unload explicitly.
|
||||
// TODO(tianleiwu): use TSharedCubinKernelFactory to manage kernel loading/unloading.
|
||||
if (use_v2_kernel) {
|
||||
sparse_attention_v2::load_sparse_attention_fp16(sm);
|
||||
} else {
|
||||
sparse_attention_v1::load_sparse_attention_fp16(sm);
|
||||
}
|
||||
if constexpr (std::is_same<T, MLFloat16>::value) {
|
||||
// std::call_once is used in load_sparse_attention_fp16 so no need to use mutex here.
|
||||
// After kernel is loaded, it will stay in memory until the process exits. We do not unload explicitly.
|
||||
// TODO(tianleiwu): use TSharedCubinKernelFactory to manage kernel loading/unloading.
|
||||
if (use_v2_kernel) {
|
||||
sparse_attention_v2::load_sparse_attention_fp16(sm);
|
||||
} else {
|
||||
if (use_v2_kernel) {
|
||||
sparse_attention_v2::load_sparse_attention_bf16(sm);
|
||||
} else {
|
||||
sparse_attention_v1::load_sparse_attention_bf16(sm);
|
||||
}
|
||||
sparse_attention_v1::load_sparse_attention_fp16(sm);
|
||||
}
|
||||
} else {
|
||||
if (use_v2_kernel) {
|
||||
sparse_attention_v2::load_sparse_attention_bf16(sm);
|
||||
} else {
|
||||
sparse_attention_v1::load_sparse_attention_bf16(sm);
|
||||
}
|
||||
kernel_loaded_ = true;
|
||||
}
|
||||
|
||||
// Compute output shape and get output tensors.
|
||||
|
|
|
|||
|
|
@ -18,15 +18,14 @@ class SparseAttention final : public CudaKernel {
|
|||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
|
||||
protected:
|
||||
int num_heads_; // number of attention heads for q
|
||||
int kv_num_heads_; // number of attention heads for k and v
|
||||
float scale_; // Scaling factor applied prior to softmax.
|
||||
bool is_causal_; // unidirectional attention or not
|
||||
int sparse_block_size_; // block size for sparsity
|
||||
bool do_rotary_; // Has rotary positional embedding
|
||||
bool rotary_interleaved_; // Interleaved rotary positional embedding
|
||||
bool disable_v1_kernel_; // Whether disable v1 kernel and use v2 kernel for prompt.
|
||||
mutable bool kernel_loaded_; // Kernel has been loaded
|
||||
int num_heads_; // number of attention heads for q
|
||||
int kv_num_heads_; // number of attention heads for k and v
|
||||
float scale_; // Scaling factor applied prior to softmax.
|
||||
bool is_causal_; // unidirectional attention or not
|
||||
int sparse_block_size_; // block size for sparsity
|
||||
bool do_rotary_; // Has rotary positional embedding
|
||||
bool rotary_interleaved_; // Interleaved rotary positional embedding
|
||||
bool disable_v1_kernel_; // Whether disable v1 kernel and use v2 kernel for prompt.
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
|
|
|
|||
Loading…
Reference in a new issue