Fix causual flash attention related kernel run (#14299)

This commit is contained in:
Zhang Lei 2023-01-13 21:40:22 -08:00 committed by GitHub
parent 8824f812e0
commit bd39c8f35e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -159,7 +159,7 @@ class FusedMHARunnerFP16v2::mhaImpl {
params.o_ptr = output;
params.cu_seqlens = static_cast<int*>(const_cast<void*>(cu_seqlens));
if (use_flash_attention && flash_attention_kernel != nullptr) {
if (use_flash_attention && flash_attention_kernel != nullptr && !has_causal_mask) {
flash_attention_kernel->run(params, stream);
} else {
xmmaKernel->run(params, stream, use_flash_attention, has_causal_mask);