diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index d0137a2049..bc5ce6e323 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -25,7 +25,7 @@ enum AttentionQkvFormat { Q_KV_BSNH_BSN2H, // for TRT fused cross attention, kv are packed }; -enum AttentionKernelType{ +enum AttentionKernelType { AttentionKernel_Unfused, AttentionKernel_TrtFusedAttention, AttentionKernel_TrtFlashAttention, @@ -38,15 +38,15 @@ enum AttentionKernelType{ struct AttentionParameters { int batch_size; int sequence_length; - int kv_sequence_length; // input sequence length of K or V - int past_sequence_length; // sequence length in past state of K or V - int total_sequence_length; // total sequence length of K or V - int max_sequence_length; // max sequence length from 4D mask - int input_hidden_size; // first dimension of weights for input projection - int hidden_size; // hidden size of Q or K - int head_size; // hidden size per head of Q or K - int v_hidden_size; // hidden size of V - int v_head_size; // hidden size per head of V + int kv_sequence_length; // input sequence length of K or V + int past_sequence_length; // sequence length in past state of K or V + int total_sequence_length; // total sequence length of K or V + int max_sequence_length; // max sequence length from 4D mask + int input_hidden_size; // first dimension of weights for input projection + int hidden_size; // hidden size of Q or K + int head_size; // hidden size per head of Q or K + int v_hidden_size; // hidden size of V + int v_head_size; // hidden size per head of V int num_heads; bool is_unidirectional; bool past_present_share_buffer; @@ -56,13 +56,17 @@ struct AttentionParameters { }; namespace attention { -// Environment variable to enable or disable fused self/causal attention kernel. Default is 0 (enabled). -constexpr const char* kDisableFusedAttention = "ORT_DISABLE_FUSED_ATTENTION"; +// Environment variable to enable or disable TRT fused self attention kernel. Default is 0 (enabled). +constexpr const char* kDisableFusedSelfAttention = "ORT_DISABLE_FUSED_ATTENTION"; // Environment variable to enable or disable fused cross attention kernel. Default is 0 (enabled). constexpr const char* kDisableFusedCrossAttention = "ORT_DISABLE_FUSED_CROSS_ATTENTION"; -// Environment variable to enable or disable TRT flash attention. Default is 0 (enabled). +// Environment variable to enable or disable TRT fused causal attention kernels. Default is 0 (disabled). +// Note that those causal attention kernels use fp16 accumulation. There is potential accuracy drop using those kernels. +constexpr const char* kEnableFusedCausalAttention = "ORT_ENABLE_FUSED_CAUSAL_ATTENTION"; + +// Environment variable to enable or disable TRT flash attention. This applies to both self and causal attention. Default is 0 (enabled). constexpr const char* kDisableTrtFlashAttention = "ORT_DISABLE_TRT_FLASH_ATTENTION"; // Environment variable to enable or disable cutlass memory efficient attention. Default is 0 (enabled). diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index 04cac1962f..f0669f6dc3 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -39,12 +39,15 @@ REGISTER_KERNEL_TYPED(MLFloat16) template Attention::Attention(const OpKernelInfo& info) : CudaKernel(info), AttentionBase(info, false) { - disable_fused_runner_ = sizeof(T) != 2 || - ParseEnvironmentVariableWithDefault(attention::kDisableFusedAttention, false); + disable_fused_self_attention_ = sizeof(T) != 2 || + ParseEnvironmentVariableWithDefault(attention::kDisableFusedSelfAttention, false); enable_trt_flash_attention_ = sizeof(T) == 2 && !ParseEnvironmentVariableWithDefault(attention::kDisableTrtFlashAttention, false); + enable_fused_causal_attention_ = sizeof(T) == 2 && + ParseEnvironmentVariableWithDefault(attention::kEnableFusedCausalAttention, false); + #if USE_FLASH_ATTENTION disable_memory_efficient_attention_ = ParseEnvironmentVariableWithDefault(attention::kDisableMemoryEfficientAttention, false); #else @@ -97,14 +100,13 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { int sm = device_prop.major * 10 + device_prop.minor; bool is_mask_1d_seq_len = parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN; - if (is_unidirectional_) { // GPT + if (is_unidirectional_ && enable_fused_causal_attention_) { // GPT // GPT fused kernels requires left side padding. mask can be: // none (no padding), 1D sequence lengths or 2d mask. // Fused kernels don't support different sequence lengths of q and kv, so only apply to the first token // where past state is empty. bool is_mask_2d_key_padding = parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING; - bool use_causal_fused_runner = !disable_fused_runner_ && - (nullptr == mask_index || is_mask_1d_seq_len || is_mask_2d_key_padding) && + bool use_causal_fused_runner = (nullptr == mask_index || is_mask_1d_seq_len || is_mask_2d_key_padding) && nullptr == relative_position_bias && parameters.past_sequence_length == 0 && parameters.hidden_size == parameters.v_hidden_size && @@ -121,7 +123,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { fused_runner = fused_fp16_runner_.get(); } } else { // BERT - bool use_fused_runner = !disable_fused_runner_ && + bool use_fused_runner = !disable_fused_self_attention_ && (nullptr == mask_index || is_mask_1d_seq_len) && nullptr == past && nullptr == present && diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.h b/onnxruntime/contrib_ops/cuda/bert/attention.h index 13b2019b21..ba7c56c04f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention.h @@ -21,8 +21,9 @@ class Attention final : public CudaKernel, public AttentionBase { Status ComputeInternal(OpKernelContext* context) const override; protected: - bool disable_fused_runner_; + bool disable_fused_self_attention_; bool enable_trt_flash_attention_; + bool enable_fused_causal_attention_; bool disable_memory_efficient_attention_; mutable std::unique_ptr fused_fp16_runner_; }; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index 41f19f460e..c8ff075d24 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -620,12 +620,14 @@ Status QkvToContext( if (use_fused_kernel || use_fused_causal) { int* sequence_offset = reinterpret_cast(scratch1); if (parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING) { + DUMP_TENSOR_D("mask", reinterpret_cast(data.mask_index), batch_size, sequence_length); LaunchTrtSequenceOffset2d(sequence_offset, data.mask_index, batch_size, sequence_length, stream); } else { sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_q_cache, data.mask_index, batch_size, sequence_length, stream, sequence_offset); } + DUMP_TENSOR_D("sequence_offset", sequence_offset, 1, (data.mask_index != nullptr ? 2 : 1) * batch_size + 1); CUDA_RETURN_IF_ERROR(cudaGetLastError()); FusedMHARunnerFP16v2* fused_fp16_runner = reinterpret_cast(fused_runner); diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index 321c2a1df0..d87f122045 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -42,8 +42,8 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f); - disable_fused_runner_ = sizeof(T) != 2 || - ParseEnvironmentVariableWithDefault(attention::kDisableFusedAttention, false); + disable_fused_self_attention_ = sizeof(T) != 2 || + ParseEnvironmentVariableWithDefault(attention::kDisableFusedSelfAttention, false); enable_trt_flash_attention_ = sizeof(T) == 2 && !ParseEnvironmentVariableWithDefault(attention::kDisableTrtFlashAttention, false); @@ -124,7 +124,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { } } - bool use_fused_runner = !disable_fused_runner_ && + bool use_fused_runner = !disable_fused_self_attention_ && fused_cross_attention_kernel == nullptr && nullptr == relative_position_bias && (value != nullptr || key == nullptr) && diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h index 928dbd1c4a..b9cf271db8 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h @@ -24,7 +24,7 @@ class MultiHeadAttention final : public CudaKernel { protected: int num_heads_; // number of attention heads float mask_filter_value_; - bool disable_fused_runner_; + bool disable_fused_self_attention_; bool enable_trt_flash_attention_; bool disable_fused_cross_attention_; bool disable_memory_efficient_attention_; diff --git a/onnxruntime/python/tools/transformers/benchmark_helper.py b/onnxruntime/python/tools/transformers/benchmark_helper.py index 2b5c53b867..bf8dd931c6 100644 --- a/onnxruntime/python/tools/transformers/benchmark_helper.py +++ b/onnxruntime/python/tools/transformers/benchmark_helper.py @@ -543,6 +543,7 @@ def get_ort_environment_variables(): # Environment variables might impact ORT performance on transformer models. Note that they are for testing only. env_names = [ "ORT_DISABLE_FUSED_ATTENTION", + "ORT_ENABLE_FUSED_CAUSAL_ATTENTION", "ORT_DISABLE_FUSED_CROSS_ATTENTION", "ORT_DISABLE_TRT_FLASH_ATTENTION", "ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION", diff --git a/onnxruntime/python/tools/transformers/models/gpt2/gpt2_helper.py b/onnxruntime/python/tools/transformers/models/gpt2/gpt2_helper.py index f5de69e8f0..7eec8575f7 100644 --- a/onnxruntime/python/tools/transformers/models/gpt2/gpt2_helper.py +++ b/onnxruntime/python/tools/transformers/models/gpt2/gpt2_helper.py @@ -24,7 +24,6 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) from benchmark_helper import Precision from float16 import float_to_float16_max_diff -from fusion_options import AttentionMaskFormat from io_binding_helper import IOBindingHelper from onnx_model import OnnxModel from torch_onnx_export_helper import torch_onnx_export @@ -188,6 +187,7 @@ class Gpt2Helper: input_ids_dtype: torch.dtype = torch.int32, position_ids_dtype: torch.dtype = torch.int32, attention_mask_dtype: torch.dtype = torch.int32, + left_side_padding: bool = True, ) -> Gpt2Inputs: """Create random inputs for GPT2 model. Returns torch tensors of input_ids, position_ids, attention_mask and a list of past state tensors. @@ -218,9 +218,14 @@ class Gpt2Helper: dtype=attention_mask_dtype, device=device, ) + if total_sequence_length >= 2: - padding_position = random.randint(0, total_sequence_length - 1) # test input with padding. - attention_mask[:, padding_position] = 0 + for i in range(batch_size): + padding_length = random.randint(0, total_sequence_length - 1) + if left_side_padding: + attention_mask[i, :padding_length] = 0 + else: # right side padding + attention_mask[i, total_sequence_length - padding_length :] = 0 # Deduce position_ids from attention mask position_ids = None @@ -517,11 +522,6 @@ class Gpt2Helper: optimization_options = FusionOptions("gpt2") - if is_float16 and stage == 1: - # For init_decoder, enable mask index to use fused causal cuda kernel. - # Potentially, we can add other optimization like unpad for effective transformer - optimization_options.attention_mask_format = AttentionMaskFormat.MaskIndexEnd - # TODO(hasesh): Investigate parity issue for GPT-2 fp16 when SkipLayerNormalization # is enabled if is_float16: @@ -841,6 +841,7 @@ class Gpt2Helper: input_ids_dtype=input_ids_dtype, position_ids_dtype=position_ids_dtype, attention_mask_dtype=attention_mask_dtype, + left_side_padding=True, ) outputs = Gpt2Helper.pytorch_inference(model, dummy_inputs) if use_io_binding: @@ -868,6 +869,7 @@ class Gpt2Helper: max_abs_diff_list.append(max_abs_diff) if is_all_close: passed_test_cases += 1 + if is_top1_matched: top1_matched_cases += 1 top1_matched_cases_per_run[run_id] += 1 diff --git a/onnxruntime/test/contrib_ops/attention_op_test.cc b/onnxruntime/test/contrib_ops/attention_op_test.cc index 0ea85dfdab..daeec7a64c 100644 --- a/onnxruntime/test/contrib_ops/attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/attention_op_test.cc @@ -930,7 +930,8 @@ TEST(AttentionTest, Causal_EmptyPastState) { ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"}, - {onnxruntime::contrib::attention::kDisableFusedAttention, "1"}}}; + {onnxruntime::contrib::attention::kEnableFusedCausalAttention, "0"}, + {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"}}}; RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, use_float16, is_unidirectional, use_past_state, past_sequence_length, &past_data, &present_data); @@ -941,7 +942,8 @@ TEST(AttentionTest, Causal_EmptyPastState) { ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"}, - {onnxruntime::contrib::attention::kDisableFusedAttention, "0"}}}; + {onnxruntime::contrib::attention::kEnableFusedCausalAttention, "1"}, + {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "0"}}}; RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, use_float16, is_unidirectional, use_past_state, past_sequence_length, &past_data, &present_data); @@ -952,7 +954,8 @@ TEST(AttentionTest, Causal_EmptyPastState) { ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "0"}, - {onnxruntime::contrib::attention::kDisableFusedAttention, "0"}}}; + {onnxruntime::contrib::attention::kEnableFusedCausalAttention, "1"}, + {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "0"}}}; RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, use_float16, is_unidirectional, use_past_state, past_sequence_length, &past_data, &present_data); diff --git a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc index 83b9f62865..646f898ed0 100644 --- a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc @@ -181,7 +181,7 @@ static void RunMultiHeadAttentionKernel( ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "0"}, - {onnxruntime::contrib::attention::kDisableFusedAttention, "0"}, + {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "0"}, {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"}, {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "0"}}}; RunMultiHeadAttentionTest( @@ -195,7 +195,7 @@ static void RunMultiHeadAttentionKernel( ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"}, - {onnxruntime::contrib::attention::kDisableFusedAttention, "1"}, + {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"}, {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}, {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}}}; RunMultiHeadAttentionTest( @@ -209,7 +209,7 @@ static void RunMultiHeadAttentionKernel( ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"}, - {onnxruntime::contrib::attention::kDisableFusedAttention, "1"}, + {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"}, {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"}, {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}}}; RunMultiHeadAttentionTest( @@ -224,7 +224,7 @@ static void RunMultiHeadAttentionKernel( ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"}, - {onnxruntime::contrib::attention::kDisableFusedAttention, "1"}, + {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"}, {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}, {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "0"}}}; RunMultiHeadAttentionTest( @@ -239,7 +239,7 @@ static void RunMultiHeadAttentionKernel( ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "0"}, - {onnxruntime::contrib::attention::kDisableFusedAttention, "0"}, + {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "0"}, {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}, {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}}}; RunMultiHeadAttentionTest(