mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-22 22:01:08 +00:00
Fix causal flash attention related kernel run (#14299)
This commit is contained in:
parent
8824f812e0
commit
bd39c8f35e
1 changed file with 1 addition and 1 deletion
|
|
@ -159,7 +159,7 @@ class FusedMHARunnerFP16v2::mhaImpl {
|
|||
params.o_ptr = output;
|
||||
params.cu_seqlens = static_cast<int*>(const_cast<void*>(cu_seqlens));
|
||||
|
||||
if (use_flash_attention && flash_attention_kernel != nullptr) {
|
||||
if (use_flash_attention && flash_attention_kernel != nullptr && !has_causal_mask) {
|
||||
flash_attention_kernel->run(params, stream);
|
||||
} else {
|
||||
xmmaKernel->run(params, stream, use_flash_attention, has_causal_mask);
|
||||
|
|
|
|||
Loading…
Reference in a new issue