From bd39c8f35e5ef061ef55c615f41dcf7d11c39cdd Mon Sep 17 00:00:00 2001 From: Zhang Lei Date: Fri, 13 Jan 2023 21:40:22 -0800 Subject: [PATCH] Fix causal flash attention related kernel run (#14299) --- .../cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu index 82c7485eee..4514fe384d 100644 --- a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu +++ b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu @@ -159,7 +159,7 @@ class FusedMHARunnerFP16v2::mhaImpl { params.o_ptr = output; params.cu_seqlens = static_cast(const_cast(cu_seqlens)); - if (use_flash_attention && flash_attention_kernel != nullptr) { + if (use_flash_attention && flash_attention_kernel != nullptr && !has_causal_mask) { flash_attention_kernel->run(params, stream); } else { xmmaKernel->run(params, stream, use_flash_attention, has_causal_mask);