diff --git a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu index 82c7485eee..4514fe384d 100644 --- a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu +++ b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu @@ -159,7 +159,7 @@ class FusedMHARunnerFP16v2::mhaImpl { params.o_ptr = output; params.cu_seqlens = static_cast(const_cast(cu_seqlens)); - if (use_flash_attention && flash_attention_kernel != nullptr) { + if (use_flash_attention && flash_attention_kernel != nullptr && !has_causal_mask) { flash_attention_kernel->run(params, stream); } else { xmmaKernel->run(params, stream, use_flash_attention, has_causal_mask);