[CUDA] Fix SparseAttention Kernel (#20716)

### Description

Currently, there is a single bool flag to indicate whether the kernel is
loaded. However, there are v1 and v2 kernels, so the flag allows only one
version of the kernel to be loaded. We use the v1 kernel for prompt
processing and the v2 kernel for token generation, so the single flag
causes an issue when both prompt processing and token generation are needed.

This bug was found in an integration test. The unit tests only exercise one
kernel at a time, so the issue was not caught earlier.

Another possible workaround, without this fix, is to set the environment
variable `ORT_DISABLE_SPARSE_ATTENTION_V1=1`.
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
Tianlei Wu 2024-05-17 22:42:19 -07:00 committed by GitHub
parent d7f7c3b343
commit 2e7de54565
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 21 additions and 27 deletions

View file

@ -55,8 +55,6 @@ SparseAttention<T>::SparseAttention(const OpKernelInfo& info)
scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
kernel_loaded_ = false;
disable_v1_kernel_ = ParseEnvironmentVariableWithDefault<bool>(sparse_attention::kDisableSparseAttentionV1, false);
}
@ -150,24 +148,21 @@ Status SparseAttention<T>::ComputeInternal(OpKernelContext* context) const {
CUDA_RETURN_IF_ERROR(cudaEventRecord(isCopyDone, cuda_stream));
}
if (!kernel_loaded_) {
if constexpr (std::is_same<T, MLFloat16>::value) {
// std::call_once is used in load_sparse_attention_fp16 so no need to use mutex here.
// After kernel is loaded, it will stay in memory until the process exits. We do not unload explicitly.
// TODO(tianleiwu): use TSharedCubinKernelFactory to manage kernel loading/unloading.
if (use_v2_kernel) {
sparse_attention_v2::load_sparse_attention_fp16(sm);
} else {
sparse_attention_v1::load_sparse_attention_fp16(sm);
}
if constexpr (std::is_same<T, MLFloat16>::value) {
// std::call_once is used in load_sparse_attention_fp16 so no need to use mutex here.
// After kernel is loaded, it will stay in memory until the process exits. We do not unload explicitly.
// TODO(tianleiwu): use TSharedCubinKernelFactory to manage kernel loading/unloading.
if (use_v2_kernel) {
sparse_attention_v2::load_sparse_attention_fp16(sm);
} else {
if (use_v2_kernel) {
sparse_attention_v2::load_sparse_attention_bf16(sm);
} else {
sparse_attention_v1::load_sparse_attention_bf16(sm);
}
sparse_attention_v1::load_sparse_attention_fp16(sm);
}
} else {
if (use_v2_kernel) {
sparse_attention_v2::load_sparse_attention_bf16(sm);
} else {
sparse_attention_v1::load_sparse_attention_bf16(sm);
}
kernel_loaded_ = true;
}
// Compute output shape and get output tensors.

View file

@ -18,15 +18,14 @@ class SparseAttention final : public CudaKernel {
Status ComputeInternal(OpKernelContext* context) const override;
protected:
int num_heads_; // number of attention heads for q
int kv_num_heads_; // number of attention heads for k and v
float scale_; // Scaling factor applied prior to softmax.
bool is_causal_; // unidirectional attention or not
int sparse_block_size_; // block size for sparsity
bool do_rotary_; // Has rotary positional embedding
bool rotary_interleaved_; // Interleaved rotary positional embedding
bool disable_v1_kernel_; // Whether disable v1 kernel and use v2 kernel for prompt.
mutable bool kernel_loaded_; // Kernel has been loaded
int num_heads_; // number of attention heads for q
int kv_num_heads_; // number of attention heads for k and v
float scale_; // Scaling factor applied prior to softmax.
bool is_causal_; // unidirectional attention or not
int sparse_block_size_; // block size for sparsity
bool do_rotary_; // Has rotary positional embedding
bool rotary_interleaved_; // Interleaved rotary positional embedding
bool disable_v1_kernel_; // Whether disable v1 kernel and use v2 kernel for prompt.
};
} // namespace cuda