diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 7f327c80cf..40f70e0b6b 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -131,6 +131,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Corresponding past and present are same tensor, its size is (2, batch_size, num_heads, max_sequence_length, head_size)
qkv_hidden_sizes : list of ints
Hidden dimension of Q, K, V: hidden_size, hidden_size and v_hidden_size
+scale : float
+Custom scale will be used if specified. Default value is 1/sqrt(head_size)
unidirectional : int
Whether every token can only attend to previous tokens. Default value is 0.
@@ -970,7 +972,7 @@ This version of the operator has been available since version 1 of the 'com.micr
- mask_filter_value : float
-- The value to be filled in the attention mask. Default value is negative infinity
+- The value to be filled in the attention mask. Default value is -10000.0f
- num_heads : int (required)
- Number of attention heads
@@ -2125,7 +2127,7 @@ This version of the operator has been available since version 1 of the 'com.micr
- mask_filter_value : float
-- The value to be filled in the attention mask. Default value is negative infinity
+- The value to be filled in the attention mask. Default value is -10000.0f
- num_heads : int (required)
- Number of attention heads
@@ -2410,6 +2412,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Number of attention heads
past_present_share_buffer : int
Corresponding past and present are same tensor, its shape is (2, batch_size, num_heads, max_sequence_length, head_size)
+scale : float
+Custom scale will be used if specified. Default value is 1/sqrt(head_size)
unidirectional : int
Whether every token can only attend to previous tokens. Default value is 0.
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc
index 4b4f9d4029..affe7cab1d 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc
@@ -249,6 +249,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
output_parameters->is_unidirectional = is_unidirectional_;
output_parameters->past_present_share_buffer = (past_present_share_buffer_ != 0 && past != nullptr);
output_parameters->mask_filter_value = mask_filter_value_;
+ output_parameters->scale = scale_;
output_parameters->mask_type = mask_type;
}
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_base.h
index 689db55965..2c49f196d5 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_base.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.h
@@ -38,6 +38,7 @@ class AttentionBase {
is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1;
mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f);
+ scale_ = info.GetAttrOrDefault("scale", 0.0f);
if (!info.GetAttrs("qkv_hidden_sizes", qkv_hidden_sizes_).IsOK()) {
qkv_hidden_sizes_.clear();
@@ -70,6 +71,7 @@ class AttentionBase {
bool require_same_hidden_size_; // whether the implementation supports different hidden sizes of Q/K/V.
bool past_present_share_buffer_; // whether or not the past (if used) and present tensor share the same buffer
float mask_filter_value_; // the value to be used for filtered out positions
+ float scale_; // the scale to be used for softmax
};
} // namespace contrib
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h
index d69644ae0c..849937d8cf 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -37,19 +37,20 @@ enum AttentionKernelType{
struct AttentionParameters {
int batch_size;
int sequence_length;
- int kv_sequence_length; // input sequence length of K or V
- int past_sequence_length; // sequence length in past state of K or V
- int total_sequence_length; // total sequence length of K or V
- int max_sequence_length; // max sequence length from 4D mask
- int input_hidden_size; // first dimension of weights for input projection
- int hidden_size; // hidden size of Q or K
- int head_size; // hidden size per head of Q or K
- int v_hidden_size; // hidden size of V
- int v_head_size; // hidden size per head of V
+ int kv_sequence_length; // input sequence length of K or V
+ int past_sequence_length; // sequence length in past state of K or V
+ int total_sequence_length; // total sequence length of K or V
+ int max_sequence_length; // max sequence length from 4D mask
+ int input_hidden_size; // first dimension of weights for input projection
+ int hidden_size; // hidden size of Q or K
+ int head_size; // hidden size per head of Q or K
+ int v_hidden_size; // hidden size of V
+ int v_head_size; // hidden size per head of V
int num_heads;
bool is_unidirectional;
bool past_present_share_buffer;
float mask_filter_value;
+ float scale;
AttentionMaskType mask_type;
};
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
index 7dc3ca9873..0185fa9ea0 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
@@ -130,7 +130,7 @@ class AttentionCPUBase : public AttentionBase {
}
const int loop_len = batch_size * num_heads_;
- const float alpha = 1.0f / sqrt(static_cast(head_size));
+ const float alpha = scale_ == 0.0f ? 1.0f / sqrt(static_cast(head_size)) : scale_;
// The cost of Gemm
const double cost = static_cast(head_size) * sequence_length * total_sequence_length;
diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
index ade1c527f4..ce109a8372 100644
--- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
+++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
@@ -107,6 +107,7 @@ Status CheckInputs(const T* query,
output_parameters->past_present_share_buffer = false;
output_parameters->mask_filter_value = mask_filter_value;
output_parameters->mask_type = mask_type;
+ output_parameters->scale = 0.0f;
}
if (max_threads_per_block > 0 && num_heads > max_threads_per_block) {
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc
index d9445ff09e..3201ad1bcf 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc
@@ -107,7 +107,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const {
// Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node.
if (nullptr == fused_fp16_runner_.get()) {
fused_fp16_runner_.reset(new FusedMHARunnerFP16v2(num_heads_, parameters.head_size, sm, is_unidirectional_,
- enable_flash_attention_));
+ enable_flash_attention_, parameters.scale));
}
// Here we assume all causal kernels can be loaded into shared memory. TODO: add a function to check.
@@ -128,7 +128,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const {
// Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node.
if (nullptr == fused_fp16_runner_.get()) {
fused_fp16_runner_.reset(new FusedMHARunnerFP16v2(num_heads_, parameters.head_size, sm, is_unidirectional_,
- enable_flash_attention_));
+ enable_flash_attention_, parameters.scale));
}
// In case some kernel not loaded due to shared memory limit, we need to double check here.
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
index 6ede1c3600..091e8de13a 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -540,8 +540,9 @@ Status QkvToContext(
float zero = 0.f;
// For raw attention mask, the scalar 1/sqrt(H) is moved to combine with softmax computation.
- const float rsqrt_head_size = 1.f / sqrt(static_cast(qk_head_size));
- float alpha = use_raw_attention_mask ? one : rsqrt_head_size;
+ const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(qk_head_size))
+ : parameters.scale;
+ float alpha = use_raw_attention_mask ? one : scale;
cublasSetStream(cublas, stream);
@@ -570,7 +571,7 @@ Status QkvToContext(
ORT_RETURN_IF_ERROR(
ComputeSoftmaxWithRawMask(stream, total_sequence_length, sequence_length, batch_size, num_heads,
mask_index, nullptr, data.extra_add_qk, scratch1, scratch2,
- parameters.is_unidirectional, rsqrt_head_size, mask_dimension,
+ parameters.is_unidirectional, scale, mask_dimension,
parameters.max_sequence_length, use_persistent_softmax,
persistent_softmax_workspace, mask_filter_value));
} else if (nullptr != mask_index) { // 1d mask index
diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
index 18658a7eef..c9baf66d70 100644
--- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
@@ -114,7 +114,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const {
if (nullptr == fused_fp16_runner_.get()) {
constexpr bool is_unidirectional = false;
fused_fp16_runner_.reset(new FusedMHARunnerFP16v2(
- num_heads_, parameters.head_size, sm, is_unidirectional, enable_flash_attention_));
+ num_heads_, parameters.head_size, sm, is_unidirectional, enable_flash_attention_, parameters.scale));
}
// In case some kernel not loaded due to shared memory limit, we need to double check here.
diff --git a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu
index 4514fe384d..1b4114fe69 100644
--- a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu
@@ -100,7 +100,7 @@ class FusedMHARunnerFP16v2::mhaImpl {
// The number of xmmas in the M dimension. We use one uint32_t per XMMA in the M dimension.
xmmas_m = (S + 16 * warps_m - 1) / (16 * warps_m);
- const float scale_bmm1 = interface->mRsqrtHeadSize;
+ const float scale_bmm1 = interface->mScale;
const float scale_softmax = 1.f; // Seems to be only required for int8
const float scale_bmm2 = 1.f;
@@ -121,7 +121,7 @@ class FusedMHARunnerFP16v2::mhaImpl {
}
void setup_causal_masked_fmha(const int S, const int B) {
- const float scale_bmm1 = interface->mRsqrtHeadSize;
+ const float scale_bmm1 = interface->mScale;
const float scale_softmax = 1.f; // Seems to be only required for int8
const float scale_bmm2 = 1.f;
@@ -219,8 +219,16 @@ class FusedMHARunnerFP16v2::mhaImpl {
bool has_causal_mask = false;
};
-FusedMHARunnerFP16v2::FusedMHARunnerFP16v2(const int numHeads, const int headSize, const int sm, bool causal_mask, bool enable_flash_attention)
- : MHARunner(numHeads, headSize, 2, causal_mask), mSm(sm), mEnableFlashAttention(enable_flash_attention), pimpl(new mhaImpl(this)) {
+FusedMHARunnerFP16v2::FusedMHARunnerFP16v2(const int numHeads,
+ const int headSize,
+ const int sm,
+ bool causal_mask,
+ bool enable_flash_attention,
+ const float scale)
+ : MHARunner(numHeads, headSize, 2, causal_mask, scale),
+ mSm(sm),
+ mEnableFlashAttention(enable_flash_attention),
+ pimpl(new mhaImpl(this)) {
}
void FusedMHARunnerFP16v2::setup(const int S, const int B) {
diff --git a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h
index 37946239d0..f11ceccd97 100644
--- a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h
+++ b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h
@@ -28,7 +28,7 @@ constexpr int kMinSequenceLengthFlashAttention = 385;
// Multi-Head Attention runner
class MHARunner {
public:
- MHARunner(const int numHeads, const int headSize, const int wordSize, bool causal_mask)
+ MHARunner(const int numHeads, const int headSize, const int wordSize, bool causal_mask, const float scale)
: mS(0),
mB(0),
mOmatSize(0),
@@ -40,7 +40,8 @@ class MHARunner {
mStrideQKV(0),
mLdOut(0),
mStrideOut(0),
- mRsqrtHeadSize(1.f / sqrtf(static_cast(headSize))),
+ mScale(scale == 0.0f ? 1.f / sqrtf(static_cast(headSize))
+ : scale),
mHasCausalMask(causal_mask) {
}
@@ -83,7 +84,7 @@ class MHARunner {
int mLdOut;
int mStrideOut;
- float mRsqrtHeadSize;
+ float mScale;
bool mHasCausalMask;
};
@@ -93,7 +94,8 @@ class FusedMHARunnerFP16v2 : public MHARunner {
const int headSize,
const int sm,
bool causal_mask,
- bool enable_flash_attention);
+ bool enable_flash_attention,
+ const float scale);
~FusedMHARunnerFP16v2() = default; // for pimpl
virtual void setup(const int S, const int B) override;
diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
index b489e04494..b4ad4d64e7 100644
--- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -203,6 +203,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"The value to be filled in the attention mask. Default value is -10000.0f",
AttributeProto::FLOAT,
OPTIONAL_VALUE)
+ .Attr("scale",
+ "Custom scale will be used if specified. Default value is 1/sqrt(head_size)",
+ AttributeProto::FLOAT,
+ OPTIONAL_VALUE)
.Input(0,
"input",
"Input tensor with shape (batch_size, sequence_length, input_hidden_size)",
@@ -275,7 +279,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
OpSchema()
.SetDoc(MultiHeadAttention_ver1_doc)
.Attr("num_heads", "Number of attention heads", AttributeProto::INT)
- .Attr("mask_filter_value", "The value to be filled in the attention mask. Default value is negative infinity",
+ .Attr("mask_filter_value", "The value to be filled in the attention mask. Default value is -10000.0f",
AttributeProto::FLOAT, OPTIONAL_VALUE)
.Input(0,
"query",
@@ -349,7 +353,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
OpSchema()
.SetDoc(Decoder_Attention_doc)
.Attr("num_heads", "Number of attention heads", AttributeProto::INT)
- .Attr("mask_filter_value", "The value to be filled in the attention mask. Default value is negative infinity",
+ .Attr("mask_filter_value", "The value to be filled in the attention mask. Default value is -10000.0f",
AttributeProto::FLOAT, OPTIONAL_VALUE)
.Input(0, "query", "3D input tensor with shape (sequence_length, batch_size, hidden_size), hidden_size = num_heads * head_size", "T")
.Input(1, "key", "3D input tensor with shape (total_sequence_length, batch_size, hidden_size)", "T")
diff --git a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc
index 59a9bc70bf..c45b5a79e5 100644
--- a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc
@@ -952,6 +952,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
.Attr("past_present_share_buffer", "Corresponding past and present are same tensor, its shape is "
"(2, batch_size, num_heads, max_sequence_length, head_size)",
AttributeProto::INT, OPTIONAL_VALUE)
+ .Attr("scale",
+ "Custom scale will be used if specified. Default value is 1/sqrt(head_size)",
+ AttributeProto::FLOAT,
+ OPTIONAL_VALUE)
.Input(0, "input", "3D input tensor with shape (batch_size, sequence_length, input_hidden_size)", "T1")
.Input(1, "weight",
"2D input tensor with shape (input_hidden_size, 3 * hidden_size), hidden_size = num_heads * head_size",
diff --git a/onnxruntime/test/contrib_ops/attention_op_test.cc b/onnxruntime/test/contrib_ops/attention_op_test.cc
index e910f251cd..743c05050c 100644
--- a/onnxruntime/test/contrib_ops/attention_op_test.cc
+++ b/onnxruntime/test/contrib_ops/attention_op_test.cc
@@ -61,7 +61,8 @@ static void RunAttentionTest(
std::vector qkv_sizes = {},
const std::vector& extra_add_data = {},
int kv_sequence_length = 0,
- bool past_present_share_buffer = false) {
+ bool past_present_share_buffer = false,
+ bool use_scale = false) {
input_hidden_size = (input_hidden_size == 0 ? hidden_size : input_hidden_size); // By default, no pruning.
kv_sequence_length = (kv_sequence_length == 0 ? sequence_length : kv_sequence_length);
past_present_share_buffer = past_present_share_buffer && use_past_state;
@@ -78,6 +79,9 @@ static void RunAttentionTest(
tester.AddAttribute("unidirectional", static_cast(is_unidirectional ? 1 : 0));
tester.AddAttribute("past_present_share_buffer", static_cast(past_present_share_buffer ? 1 : 0));
tester.AddAttribute("mask_filter_value", static_cast(-10000.0f));
+ if (use_scale && !enable_rocm) {
+ tester.AddAttribute("scale", static_cast(1.f / sqrt(head_size)));
+ }
int32_t qkv_hidden_size_sum;
int32_t v_hidden_size;
@@ -262,19 +266,20 @@ static void RunAttentionTest(
const std::vector qkv_sizes = {},
const std::vector& extra_add_data = {},
int kv_sequence_length = 0,
- bool past_present_share_buffer = false) {
+ bool past_present_share_buffer = false,
+ bool use_scale = false) {
RunAttentionTest(input_data, weights_data, false, bias_data, mask_index_data, output_data,
batch_size, sequence_length, hidden_size, number_of_heads,
use_float16, is_unidirectional, use_past_state, past_sequence_length,
past_data, present_data, mask_type, input_hidden_size, max_sequence_length,
disable_cpu, disable_cuda, disable_rocm, qkv_sizes, extra_add_data,
- kv_sequence_length, past_present_share_buffer);
+ kv_sequence_length, past_present_share_buffer, use_scale);
RunAttentionTest(input_data, weights_data, true, bias_data, mask_index_data, output_data,
batch_size, sequence_length, hidden_size, number_of_heads,
use_float16, is_unidirectional, use_past_state, past_sequence_length,
past_data, present_data, mask_type, input_hidden_size, max_sequence_length,
disable_cpu, disable_cuda, disable_rocm, qkv_sizes, extra_add_data,
- kv_sequence_length, past_present_share_buffer);
+ kv_sequence_length, past_present_share_buffer, use_scale);
}
TEST(AttentionTest, AttentionBatch1) {
@@ -1663,6 +1668,51 @@ TEST(AttentionTest, AttentionUnidirectionalAttentionMask) {
AttentionMaskType::MASK_2D_KEY_PADDING);
}
+TEST(AttentionTest, AttentionWithNormFactor) {
+ int batch_size = 2;
+ int sequence_length = 2;
+ int hidden_size = 4;
+ int number_of_heads = 2;
+
+ std::vector input_data = {
+ 0.5f, 0.2f, 0.3f, -0.6f,
+ 0.8f, -0.5f, 0.0f, 1.f,
+ 0.8f, -0.5f, 0.0f, 1.f,
+ 0.5f, 0.2f, 0.3f, -0.6f};
+
+ std::vector weight_data = {
+ 0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f,
+ 0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f,
+ 0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f,
+ 0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f};
+
+ std::vector bias_data = {
+ -0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f};
+
+ // Test mask start position > 0.
+ std::vector mask_index_data = {0, 1, 1, 1};
+
+ std::vector output_data = {
+ 3.0146f, 0.1142f, 3.9834f, 5.3394f,
+ 8.69f, -0.13f, 4.25f, 5.65f,
+ 8.69f, -0.13f, 4.25f, 5.65f,
+ 3.96967912f, 0.07314367f, 4.25f, 5.65f};
+
+ bool use_float16 = false;
+ bool is_unidirectional = true;
+ bool use_past_state = false;
+ int past_sequence_length = 0;
+ const std::vector* past_data = nullptr;
+ const std::vector* present_data = nullptr;
+ RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
+ batch_size, sequence_length, hidden_size, number_of_heads,
+ use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data,
+ AttentionMaskType::MASK_2D_KEY_PADDING, 0 /*input_hidden_size*/, 0 /*max_sequence_length*/,
+ false /*disable_cpu*/, false /*disable_cuda*/, true /*disable_rocm*/, {} /*qkv_sizes*/,
+ {} /*extra_add_data*/, 0 /*kv_sequence_length*/, false /*past_present_share_buffer*/,
+ true /*use_scale*/);
+}
+
TEST(AttentionTest, AttentionMask1DEndNoWord) {
int batch_size = 2;
int sequence_length = 2;