diff --git a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc index 506a6683de..7d3f6eb929 100644 --- a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc +++ b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc @@ -55,8 +55,6 @@ SparseAttention::SparseAttention(const OpKernelInfo& info) scale_ = info.GetAttrOrDefault("scale", 0.0f); - kernel_loaded_ = false; - disable_v1_kernel_ = ParseEnvironmentVariableWithDefault(sparse_attention::kDisableSparseAttentionV1, false); } @@ -150,24 +148,21 @@ Status SparseAttention::ComputeInternal(OpKernelContext* context) const { CUDA_RETURN_IF_ERROR(cudaEventRecord(isCopyDone, cuda_stream)); } - if (!kernel_loaded_) { - if constexpr (std::is_same::value) { - // std::call_once is used in load_sparse_attention_fp16 so no need to use mutex here. - // After kernel is loaded, it will stay in memory until the process exits. We do not unload explicitly. - // TODO(tianleiwu): use TSharedCubinKernelFactory to manage kernel loading/unloading. - if (use_v2_kernel) { - sparse_attention_v2::load_sparse_attention_fp16(sm); - } else { - sparse_attention_v1::load_sparse_attention_fp16(sm); - } + if constexpr (std::is_same::value) { + // std::call_once is used in load_sparse_attention_fp16 so no need to use mutex here. + // After kernel is loaded, it will stay in memory until the process exits. We do not unload explicitly. + // TODO(tianleiwu): use TSharedCubinKernelFactory to manage kernel loading/unloading. + if (use_v2_kernel) { + sparse_attention_v2::load_sparse_attention_fp16(sm); } else { - if (use_v2_kernel) { - sparse_attention_v2::load_sparse_attention_bf16(sm); - } else { - sparse_attention_v1::load_sparse_attention_bf16(sm); - } + sparse_attention_v1::load_sparse_attention_fp16(sm); + } + } else { + if (use_v2_kernel) { + sparse_attention_v2::load_sparse_attention_bf16(sm); + } else { + sparse_attention_v1::load_sparse_attention_bf16(sm); } - kernel_loaded_ = true; } // Compute output shape and get output tensors. diff --git a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.h b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.h index 159ff6728f..1df3affe17 100644 --- a/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.h +++ b/onnxruntime/contrib_ops/cuda/sparse/sparse_attention.h @@ -18,15 +18,14 @@ class SparseAttention final : public CudaKernel { Status ComputeInternal(OpKernelContext* context) const override; protected: - int num_heads_; // number of attention heads for q - int kv_num_heads_; // number of attention heads for k and v - float scale_; // Scaling factor applied prior to softmax. - bool is_causal_; // unidirectional attention or not - int sparse_block_size_; // block size for sparsity - bool do_rotary_; // Has rotary positional embedding - bool rotary_interleaved_; // Interleaved rotary positional embedding - bool disable_v1_kernel_; // Whether disable v1 kernel and use v2 kernel for prompt. - mutable bool kernel_loaded_; // Kernel has been loaded + int num_heads_; // number of attention heads for q + int kv_num_heads_; // number of attention heads for k and v + float scale_; // Scaling factor applied prior to softmax. + bool is_causal_; // unidirectional attention or not + int sparse_block_size_; // block size for sparsity + bool do_rotary_; // Has rotary positional embedding + bool rotary_interleaved_; // Interleaved rotary positional embedding + bool disable_v1_kernel_; // Whether disable v1 kernel and use v2 kernel for prompt. }; } // namespace cuda