From ad382120fe4b9b17d7dbb871feb397e8dfd183af Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 26 Aug 2024 13:34:55 -0700 Subject: [PATCH] [CUDA] enable causal in MultiHeadAttention (#21852) ### Description Enable causal in MultiHeadAttention cuda operator. All formats (Q_K_V_BSNH_BSNH_BSNH, Q_K_V_BSNH_BNSH_BNSH, Q_KV_BSNH_BSN2H and QKV_BSN3H) supports causal for now. Internally, casual will be dispatch to flash attention, efficient attention or unfused attention kernel. ### Motivation and Context Currently, MultiHeadAttention has causal enabled in CPU ep, but not in CUDA ep. It could cause issues in onnx conversion, like some model can run in CPU but not in CUDA. Enable causal in CUDA will reduce the difference of support matrix of CPU/CUDA. --- .../cuda/bert/multihead_attention.cc | 12 +++-- .../tools/transformers/io_binding_helper.py | 3 +- .../test/python/transformers/benchmark_mha.py | 5 +-- .../test/python/transformers/test_mha.py | 45 ++++++++++++------- 4 files changed, 37 insertions(+), 28 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index 0960a9efe7..52bfe61608 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -46,8 +46,6 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) scale_ = info.GetAttrOrDefault("scale", 0.0f); is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; - ORT_ENFORCE(!is_unidirectional_, - "MHA support CUDA kernel does not Unidirectional. Consider using Attention or GQA instead."); kernel_options_ = this->GetAttentionKernelOptions(); @@ -208,13 +206,13 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { bool use_fused_cross_attention = kernel_type == AttentionKernelType::AttentionKernel_Default && !disable_fused_cross_attention_ && + !is_unidirectional_ && nullptr == key_padding_mask && nullptr == attention_bias && nullptr == past_key && nullptr == present_key && (parameters.qkv_format == Q_K_V_BSNH || (parameters.qkv_format == Q_KV_BSNH_BSN2H && bias == nullptr)) && parameters.hidden_size == parameters.v_hidden_size && - has_fused_cross_attention_kernel(sm, parameters.head_size, - parameters.kv_sequence_length); + has_fused_cross_attention_kernel(sm, parameters.head_size, parameters.kv_sequence_length); if (use_fused_cross_attention) { if (fused_fp16_cross_attention_kernel_ == nullptr) { std::call_once(fused_cross_init_once_flag_, [&]() { @@ -233,6 +231,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { bool use_fused_runner = kernel_type == AttentionKernelType::AttentionKernel_Default && !disable_fused_self_attention_ && + !is_unidirectional_ && nullptr == attention_bias && (parameters.qkv_format == Q_K_V_BSNH || parameters.qkv_format == QKV_BSN3H) && nullptr == past_key && nullptr == present_key && @@ -240,13 +239,12 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { parameters.hidden_size == parameters.v_hidden_size && parameters.sequence_length == parameters.kv_sequence_length && // self attention only for fused runner FusedMHARunnerFP16v2::IsSupported(sm, parameters.head_size, sequence_length, - enable_trt_flash_attention_, false); + enable_trt_flash_attention_, is_unidirectional_); if (use_fused_runner) { // Here we assume that num_heads and head_size does not change for a MultiHeadAttention node. if (nullptr == fused_fp16_runner_.get()) { - constexpr bool is_unidirectional = false; std::call_once(fused_fp16_runner_created_, [&]() { - fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional, + fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_, enable_trt_flash_attention_, parameters.scale); }); } diff --git a/onnxruntime/python/tools/transformers/io_binding_helper.py b/onnxruntime/python/tools/transformers/io_binding_helper.py index 4f46242a4f..2375104ac9 100644 --- a/onnxruntime/python/tools/transformers/io_binding_helper.py +++ b/onnxruntime/python/tools/transformers/io_binding_helper.py @@ -304,7 +304,7 @@ class CudaSession: tensor.data_ptr(), ) - def infer(self, feed_dict: Dict[str, torch.Tensor], run_options: RunOptions = None, synchronize: bool = False): + def infer(self, feed_dict: Dict[str, torch.Tensor], run_options: RunOptions = None, synchronize: bool = True): """Bind input tensors and run inference""" for name, tensor in feed_dict.items(): assert isinstance(tensor, torch.Tensor) and tensor.is_contiguous() @@ -317,7 +317,6 @@ class CudaSession: else: self.bind_input_and_buffer_sharing(name, tensor) - # Synchronization are not needed in most cases unless different streams are used or inputs/outputs are in CPU. if synchronize: self.io_binding.synchronize_inputs() self.ort_session.run_with_iobinding(self.io_binding, run_options) diff --git a/onnxruntime/test/python/transformers/benchmark_mha.py b/onnxruntime/test/python/transformers/benchmark_mha.py index 2a3541db4c..d8acb66158 100644 --- a/onnxruntime/test/python/transformers/benchmark_mha.py +++ b/onnxruntime/test/python/transformers/benchmark_mha.py @@ -587,8 +587,8 @@ class OrtMultiHeadAttention: self.ort_session = create_session(config, session_options, use_tf32=use_tf32) self.feed_dict = config.random_inputs() - def infer(self): - return self.ort_session.infer(self.feed_dict) + def infer(self, run_options=None, synchronize=True): + return self.ort_session.infer(self.feed_dict, run_options=run_options, synchronize=synchronize) def measure_latency(cuda_session: CudaSession, input_dict): @@ -1356,7 +1356,6 @@ if __name__ == "__main__": args.repeats = 10000 if args.use_gpu else 100 if args.use_gpu: - assert args.torch or not args.causal, "no causal cuda kernel in MHA op" assert torch.cuda.is_available() if not args.torch: assert "CUDAExecutionProvider" in get_available_providers() diff --git a/onnxruntime/test/python/transformers/test_mha.py b/onnxruntime/test/python/transformers/test_mha.py index 92653ffb05..69f0035ef8 100644 --- a/onnxruntime/test/python/transformers/test_mha.py +++ b/onnxruntime/test/python/transformers/test_mha.py @@ -68,6 +68,22 @@ def get_bias_support(format: InputFormats): raise RuntimeError(f"Unknown format: {format}") +def get_causal_support(format: InputFormats): + if format == InputFormats.Q_K_V_BSNH_BSNH_BSNH: + return [True, False] + + if format == InputFormats.Q_K_V_BSNH_BNSH_BNSH: + return [True, False] + + if format == InputFormats.Q_KV_BSNH_BSN2H: + return [True, False] + + if format == InputFormats.QKV_BSN3H: + return [True, False] + + raise RuntimeError(f"Unknown format: {format}") + + def get_atten_bias_support(): atten_bias_options = [ # (has_attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1) @@ -215,7 +231,7 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool): for num_heads in heads: for head_size in head_sizes: for format in formats: - for causal in [True, False]: + for causal in get_causal_support(format): for mask_format in mask_formats: for has_bias in get_bias_support(format): for ( @@ -256,8 +272,8 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool): has_attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1 = atten_bias_options[ i % len(atten_bias_options) ] - for causal in [True, False]: - for format in formats: + for format in formats: + for causal in get_causal_support(format): for has_bias in get_bias_support(format): config = MultiHeadAttentionConfig( batch_size=batch_size, @@ -308,7 +324,7 @@ def kv_cache_test_cases(provider: str, comprehensive: bool): for num_heads in heads: for head_size in head_sizes: for format in formats: - for causal in [True, False]: + for causal in get_causal_support(format): for has_past_input in [True, False]: for mask_format in mask_formats: for has_bias in get_bias_support(format): @@ -353,8 +369,8 @@ def kv_cache_test_cases(provider: str, comprehensive: bool): has_attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1 = atten_bias_options[ i % len(atten_bias_options) ] - for causal in [True, False]: - for format in formats: + for format in formats: + for causal in get_causal_support(format): for has_past_input in [True, False]: for has_bias in get_bias_support(format): sequence_length = 1 if has_past_input else past_sequence_length @@ -397,7 +413,7 @@ def no_kv_cache_multi_thread_test_cases(provider: str, comprehensive: bool): device, dtype, formats = get_provider_support_info(provider, False) for format in formats: - for causal in [True, False]: + for causal in get_causal_support(format): for num_heads in heads: for head_size in head_sizes: configs = [] # list of configurations to run in parallel @@ -437,7 +453,7 @@ def kv_cache_multi_thread_test_cases(provider: str, comprehensive: bool): device, dtype, formats = get_provider_support_info(provider, True) for format in formats: - for causal in [True, False]: + for causal in get_causal_support(format): for num_heads in heads: for head_size in head_sizes: configs = [] @@ -494,12 +510,8 @@ def parity_check_mha( rtol=1e-3, atol=1e-3, ): - # CUDA kernel does not support causal so skip such test cases. - if config.causal and config.provider == "CUDAExecutionProvider": - return - ort_mha = OrtMultiHeadAttention(config, use_tf32=False) - ort_outputs = ort_mha.infer() + ort_outputs = ort_mha.infer(synchronize=True) out = ort_outputs["output"] out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) @@ -602,9 +614,6 @@ def parity_check_mha_multi_threading( ): # Use the first config to create a session, which is shared by all configs to run in parallel. config = test_inputs[0]["config"] - # For now, MHA CUDA kernel does not support causal so skip such test cases. - if config.causal and config.provider == "CUDAExecutionProvider": - return None # Some kernel does not support certain input format. if attention_kernel not in [ @@ -784,6 +793,10 @@ class TestMultiHeadAttention(unittest.TestCase): def run_mha_cuda_multi_threading(self, attention_kernel): for configs in multi_thread_test_cases("CUDAExecutionProvider", comprehensive_mode): + if configs and configs[0].causal and (SdpaKernel.TRT_CAUSAL_ATTENTION & attention_kernel != 0): + # TRT fused causal is disabled by default so skip the test of causal for multi-threading. + continue + test_inputs = [] for config in configs: ort_inputs = config.random_inputs()