diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 7f327c80cf..40f70e0b6b 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -131,6 +131,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Corresponding past and present are same tensor, its size is (2, batch_size, num_heads, max_sequence_length, head_size)
qkv_hidden_sizes : list of ints
Hidden dimension of Q, K, V: hidden_size, hidden_size and v_hidden_size
+
scale : float
+
Custom scale will be used if specified. Default value is 1/sqrt(head_size)
unidirectional : int
Whether every token can only attend to previous tokens. Default value is 0.
@@ -970,7 +972,7 @@ This version of the operator has been available since version 1 of the 'com.micr
mask_filter_value : float
-
The value to be filled in the attention mask. Default value is negative infinity
+
The value to be filled in the attention mask. Default value is -10000.0f
num_heads : int (required)
Number of attention heads
@@ -2125,7 +2127,7 @@ This version of the operator has been available since version 1 of the 'com.micr
mask_filter_value : float
-
The value to be filled in the attention mask. Default value is negative infinity
+
The value to be filled in the attention mask. Default value is -10000.0f
num_heads : int (required)
Number of attention heads
@@ -2410,6 +2412,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Number of attention heads
past_present_share_buffer : int
Corresponding past and present are same tensor, its shape is (2, batch_size, num_heads, max_sequence_length, head_size)
+
scale : float
+
Custom scale will be used if specified. Default value is 1/sqrt(head_size)
unidirectional : int
Whether every token can only attend to previous tokens. Default value is 0.
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc index 4b4f9d4029..affe7cab1d 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc @@ -249,6 +249,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, output_parameters->is_unidirectional = is_unidirectional_; output_parameters->past_present_share_buffer = (past_present_share_buffer_ != 0 && past != nullptr); output_parameters->mask_filter_value = mask_filter_value_; + output_parameters->scale = scale_; output_parameters->mask_type = mask_type; } diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_base.h index 689db55965..2c49f196d5 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.h @@ -38,6 +38,7 @@ class AttentionBase { is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f); + scale_ = info.GetAttrOrDefault("scale", 0.0f); if (!info.GetAttrs("qkv_hidden_sizes", qkv_hidden_sizes_).IsOK()) { qkv_hidden_sizes_.clear(); @@ -70,6 +71,7 @@ class AttentionBase { bool require_same_hidden_size_; // whether the implementation supports different hidden sizes of Q/K/V. bool past_present_share_buffer_; // whether or not the past (if used) and present tensor share the same buffer float mask_filter_value_; // the value to be used for filtered out positions + float scale_; // the scale to be used for softmax }; } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index d69644ae0c..849937d8cf 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -37,19 +37,20 @@ enum AttentionKernelType{ struct AttentionParameters { int batch_size; int sequence_length; - int kv_sequence_length; // input sequence length of K or V - int past_sequence_length; // sequence length in past state of K or V - int total_sequence_length; // total sequence length of K or V - int max_sequence_length; // max sequence length from 4D mask - int input_hidden_size; // first dimension of weights for input projection - int hidden_size; // hidden size of Q or K - int head_size; // hidden size per head of Q or K - int v_hidden_size; // hidden size of V - int v_head_size; // hidden size per head of V + int kv_sequence_length; // input sequence length of K or V + int past_sequence_length; // sequence length in past state of K or V + int total_sequence_length; // total sequence length of K or V + int max_sequence_length; // max sequence length from 4D mask + int input_hidden_size; // first dimension of weights for input projection + int hidden_size; // hidden size of Q or K + int head_size; // hidden size per head of Q or K + int v_hidden_size; // hidden size of V + int v_head_size; // hidden size per head of V int num_heads; bool is_unidirectional; bool past_present_share_buffer; float mask_filter_value; + float scale; AttentionMaskType mask_type; }; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h index 7dc3ca9873..0185fa9ea0 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h @@ -130,7 +130,7 @@ class AttentionCPUBase : public AttentionBase { } const int loop_len = batch_size * num_heads_; - const float alpha = 1.0f / sqrt(static_cast(head_size)); + const float alpha = scale_ == 0.0f ? 1.0f / sqrt(static_cast(head_size)) : scale_; // The cost of Gemm const double cost = static_cast(head_size) * sequence_length * total_sequence_length; diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h index ade1c527f4..ce109a8372 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h @@ -107,6 +107,7 @@ Status CheckInputs(const T* query, output_parameters->past_present_share_buffer = false; output_parameters->mask_filter_value = mask_filter_value; output_parameters->mask_type = mask_type; + output_parameters->scale = 0.0f; } if (max_threads_per_block > 0 && num_heads > max_threads_per_block) { diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index d9445ff09e..3201ad1bcf 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -107,7 +107,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { // Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node. if (nullptr == fused_fp16_runner_.get()) { fused_fp16_runner_.reset(new FusedMHARunnerFP16v2(num_heads_, parameters.head_size, sm, is_unidirectional_, - enable_flash_attention_)); + enable_flash_attention_, parameters.scale)); } // Here we assume all causal kernels can be loaded into shared memory. TODO: add a function to check. @@ -128,7 +128,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { // Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node. if (nullptr == fused_fp16_runner_.get()) { fused_fp16_runner_.reset(new FusedMHARunnerFP16v2(num_heads_, parameters.head_size, sm, is_unidirectional_, - enable_flash_attention_)); + enable_flash_attention_, parameters.scale)); } // In case some kernel not loaded due to shared memory limit, we need to double check here. diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index 6ede1c3600..091e8de13a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -540,8 +540,9 @@ Status QkvToContext( float zero = 0.f; // For raw attention mask, the scalar 1/sqrt(H) is moved to combine with softmax computation. - const float rsqrt_head_size = 1.f / sqrt(static_cast(qk_head_size)); - float alpha = use_raw_attention_mask ? one : rsqrt_head_size; + const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(qk_head_size)) + : parameters.scale; + float alpha = use_raw_attention_mask ? one : scale; cublasSetStream(cublas, stream); @@ -570,7 +571,7 @@ Status QkvToContext( ORT_RETURN_IF_ERROR( ComputeSoftmaxWithRawMask(stream, total_sequence_length, sequence_length, batch_size, num_heads, mask_index, nullptr, data.extra_add_qk, scratch1, scratch2, - parameters.is_unidirectional, rsqrt_head_size, mask_dimension, + parameters.is_unidirectional, scale, mask_dimension, parameters.max_sequence_length, use_persistent_softmax, persistent_softmax_workspace, mask_filter_value)); } else if (nullptr != mask_index) { // 1d mask index diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index 18658a7eef..c9baf66d70 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -114,7 +114,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { if (nullptr == fused_fp16_runner_.get()) { constexpr bool is_unidirectional = false; fused_fp16_runner_.reset(new FusedMHARunnerFP16v2( - num_heads_, parameters.head_size, sm, is_unidirectional, enable_flash_attention_)); + num_heads_, parameters.head_size, sm, is_unidirectional, enable_flash_attention_, parameters.scale)); } // In case some kernel not loaded due to shared memory limit, we need to double check here. diff --git a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu index 4514fe384d..1b4114fe69 100644 --- a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu +++ b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu @@ -100,7 +100,7 @@ class FusedMHARunnerFP16v2::mhaImpl { // The number of xmmas in the M dimension. We use one uint32_t per XMMA in the M dimension. xmmas_m = (S + 16 * warps_m - 1) / (16 * warps_m); - const float scale_bmm1 = interface->mRsqrtHeadSize; + const float scale_bmm1 = interface->mScale; const float scale_softmax = 1.f; // Seems to be only required for int8 const float scale_bmm2 = 1.f; @@ -121,7 +121,7 @@ class FusedMHARunnerFP16v2::mhaImpl { } void setup_causal_masked_fmha(const int S, const int B) { - const float scale_bmm1 = interface->mRsqrtHeadSize; + const float scale_bmm1 = interface->mScale; const float scale_softmax = 1.f; // Seems to be only required for int8 const float scale_bmm2 = 1.f; @@ -219,8 +219,16 @@ class FusedMHARunnerFP16v2::mhaImpl { bool has_causal_mask = false; }; -FusedMHARunnerFP16v2::FusedMHARunnerFP16v2(const int numHeads, const int headSize, const int sm, bool causal_mask, bool enable_flash_attention) - : MHARunner(numHeads, headSize, 2, causal_mask), mSm(sm), mEnableFlashAttention(enable_flash_attention), pimpl(new mhaImpl(this)) { +FusedMHARunnerFP16v2::FusedMHARunnerFP16v2(const int numHeads, + const int headSize, + const int sm, + bool causal_mask, + bool enable_flash_attention, + const float scale) + : MHARunner(numHeads, headSize, 2, causal_mask, scale), + mSm(sm), + mEnableFlashAttention(enable_flash_attention), + pimpl(new mhaImpl(this)) { } void FusedMHARunnerFP16v2::setup(const int S, const int B) { diff --git a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h index 37946239d0..f11ceccd97 100644 --- a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h +++ b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h @@ -28,7 +28,7 @@ constexpr int kMinSequenceLengthFlashAttention = 385; // Multi-Head Attention runner class MHARunner { public: - MHARunner(const int numHeads, const int headSize, const int wordSize, bool causal_mask) + MHARunner(const int numHeads, const int headSize, const int wordSize, bool causal_mask, const float scale) : mS(0), mB(0), mOmatSize(0), @@ -40,7 +40,8 @@ class MHARunner { mStrideQKV(0), mLdOut(0), mStrideOut(0), - mRsqrtHeadSize(1.f / sqrtf(static_cast(headSize))), + mScale(scale == 0.0f ? 1.f / sqrtf(static_cast(headSize)) + : scale), mHasCausalMask(causal_mask) { } @@ -83,7 +84,7 @@ class MHARunner { int mLdOut; int mStrideOut; - float mRsqrtHeadSize; + float mScale; bool mHasCausalMask; }; @@ -93,7 +94,8 @@ class FusedMHARunnerFP16v2 : public MHARunner { const int headSize, const int sm, bool causal_mask, - bool enable_flash_attention); + bool enable_flash_attention, + const float scale); ~FusedMHARunnerFP16v2() = default; // for pimpl virtual void setup(const int S, const int B) override; diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index b489e04494..b4ad4d64e7 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -203,6 +203,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "The value to be filled in the attention mask. Default value is -10000.0f", AttributeProto::FLOAT, OPTIONAL_VALUE) + .Attr("scale", + "Custom scale will be used if specified. Default value is 1/sqrt(head_size)", + AttributeProto::FLOAT, + OPTIONAL_VALUE) .Input(0, "input", "Input tensor with shape (batch_size, sequence_length, input_hidden_size)", @@ -275,7 +279,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OpSchema() .SetDoc(MultiHeadAttention_ver1_doc) .Attr("num_heads", "Number of attention heads", AttributeProto::INT) - .Attr("mask_filter_value", "The value to be filled in the attention mask. Default value is negative infinity", + .Attr("mask_filter_value", "The value to be filled in the attention mask. Default value is -10000.0f", AttributeProto::FLOAT, OPTIONAL_VALUE) .Input(0, "query", @@ -349,7 +353,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OpSchema() .SetDoc(Decoder_Attention_doc) .Attr("num_heads", "Number of attention heads", AttributeProto::INT) - .Attr("mask_filter_value", "The value to be filled in the attention mask. Default value is negative infinity", + .Attr("mask_filter_value", "The value to be filled in the attention mask. Default value is -10000.0f", AttributeProto::FLOAT, OPTIONAL_VALUE) .Input(0, "query", "3D input tensor with shape (sequence_length, batch_size, hidden_size), hidden_size = num_heads * head_size", "T") .Input(1, "key", "3D input tensor with shape (total_sequence_length, batch_size, hidden_size)", "T") diff --git a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc index 59a9bc70bf..c45b5a79e5 100644 --- a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc @@ -952,6 +952,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .Attr("past_present_share_buffer", "Corresponding past and present are same tensor, its shape is " "(2, batch_size, num_heads, max_sequence_length, head_size)", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("scale", + "Custom scale will be used if specified. Default value is 1/sqrt(head_size)", + AttributeProto::FLOAT, + OPTIONAL_VALUE) .Input(0, "input", "3D input tensor with shape (batch_size, sequence_length, input_hidden_size)", "T1") .Input(1, "weight", "2D input tensor with shape (input_hidden_size, 3 * hidden_size), hidden_size = num_heads * head_size", diff --git a/onnxruntime/test/contrib_ops/attention_op_test.cc b/onnxruntime/test/contrib_ops/attention_op_test.cc index e910f251cd..743c05050c 100644 --- a/onnxruntime/test/contrib_ops/attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/attention_op_test.cc @@ -61,7 +61,8 @@ static void RunAttentionTest( std::vector qkv_sizes = {}, const std::vector& extra_add_data = {}, int kv_sequence_length = 0, - bool past_present_share_buffer = false) { + bool past_present_share_buffer = false, + bool use_scale = false) { input_hidden_size = (input_hidden_size == 0 ? hidden_size : input_hidden_size); // By default, no pruning. kv_sequence_length = (kv_sequence_length == 0 ? sequence_length : kv_sequence_length); past_present_share_buffer = past_present_share_buffer && use_past_state; @@ -78,6 +79,9 @@ static void RunAttentionTest( tester.AddAttribute("unidirectional", static_cast(is_unidirectional ? 1 : 0)); tester.AddAttribute("past_present_share_buffer", static_cast(past_present_share_buffer ? 1 : 0)); tester.AddAttribute("mask_filter_value", static_cast(-10000.0f)); + if (use_scale && !enable_rocm) { + tester.AddAttribute("scale", static_cast(1.f / sqrt(head_size))); + } int32_t qkv_hidden_size_sum; int32_t v_hidden_size; @@ -262,19 +266,20 @@ static void RunAttentionTest( const std::vector qkv_sizes = {}, const std::vector& extra_add_data = {}, int kv_sequence_length = 0, - bool past_present_share_buffer = false) { + bool past_present_share_buffer = false, + bool use_scale = false) { RunAttentionTest(input_data, weights_data, false, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, mask_type, input_hidden_size, max_sequence_length, disable_cpu, disable_cuda, disable_rocm, qkv_sizes, extra_add_data, - kv_sequence_length, past_present_share_buffer); + kv_sequence_length, past_present_share_buffer, use_scale); RunAttentionTest(input_data, weights_data, true, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, mask_type, input_hidden_size, max_sequence_length, disable_cpu, disable_cuda, disable_rocm, qkv_sizes, extra_add_data, - kv_sequence_length, past_present_share_buffer); + kv_sequence_length, past_present_share_buffer, use_scale); } TEST(AttentionTest, AttentionBatch1) { @@ -1663,6 +1668,51 @@ TEST(AttentionTest, AttentionUnidirectionalAttentionMask) { AttentionMaskType::MASK_2D_KEY_PADDING); } +TEST(AttentionTest, AttentionWithNormFactor) { + int batch_size = 2; + int sequence_length = 2; + int hidden_size = 4; + int number_of_heads = 2; + + std::vector input_data = { + 0.5f, 0.2f, 0.3f, -0.6f, + 0.8f, -0.5f, 0.0f, 1.f, + 0.8f, -0.5f, 0.0f, 1.f, + 0.5f, 0.2f, 0.3f, -0.6f}; + + std::vector weight_data = { + 0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f, + 0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f, + 0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f, + 0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f}; + + std::vector bias_data = { + -0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f}; + + // Test mask start position > 0. + std::vector mask_index_data = {0, 1, 1, 1}; + + std::vector output_data = { + 3.0146f, 0.1142f, 3.9834f, 5.3394f, + 8.69f, -0.13f, 4.25f, 5.65f, + 8.69f, -0.13f, 4.25f, 5.65f, + 3.96967912f, 0.07314367f, 4.25f, 5.65f}; + + bool use_float16 = false; + bool is_unidirectional = true; + bool use_past_state = false; + int past_sequence_length = 0; + const std::vector* past_data = nullptr; + const std::vector* present_data = nullptr; + RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, + batch_size, sequence_length, hidden_size, number_of_heads, + use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, + AttentionMaskType::MASK_2D_KEY_PADDING, 0 /*input_hidden_size*/, 0 /*max_sequence_length*/, + false /*disable_cpu*/, false /*disable_cuda*/, true /*disable_rocm*/, {} /*qkv_sizes*/, + {} /*extra_add_data*/, 0 /*kv_sequence_length*/, false /*past_present_share_buffer*/, + true /*use_scale*/); +} + TEST(AttentionTest, AttentionMask1DEndNoWord) { int batch_size = 2; int sequence_length = 2;