mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-19 21:32:23 +00:00
Support muP in Attention (#14348)
### Description <!-- Describe your changes. --> ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Co-authored-by: Ubuntu <wy@v100-2.0cdb2e52twzevn1i4fi45bylyg.jx.internal.cloudapp.net>
This commit is contained in:
parent
1dd07d147d
commit
668586e8f8
14 changed files with 110 additions and 32 deletions
|
|
@ -131,6 +131,8 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
<dd>Corresponding past and present are same tensor, its size is (2, batch_size, num_heads, max_sequence_length, head_size)</dd>
|
||||
<dt><tt>qkv_hidden_sizes</tt> : list of ints</dt>
|
||||
<dd>Hidden dimension of Q, K, V: hidden_size, hidden_size and v_hidden_size</dd>
|
||||
<dt><tt>scale</tt> : float</dt>
|
||||
<dd>Custom scale will be used if specified. Default value is 1/sqrt(head_size)</dd>
|
||||
<dt><tt>unidirectional</tt> : int</dt>
|
||||
<dd>Whether every token can only attend to previous tokens. Default value is 0.</dd>
|
||||
</dl>
|
||||
|
|
@ -970,7 +972,7 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
|
||||
<dl>
|
||||
<dt><tt>mask_filter_value</tt> : float</dt>
|
||||
<dd>The value to be filled in the attention mask. Default value is negative infinity</dd>
|
||||
<dd>The value to be filled in the attention mask. Default value is -10000.0f</dd>
|
||||
<dt><tt>num_heads</tt> : int (required)</dt>
|
||||
<dd>Number of attention heads</dd>
|
||||
</dl>
|
||||
|
|
@ -2125,7 +2127,7 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
|
||||
<dl>
|
||||
<dt><tt>mask_filter_value</tt> : float</dt>
|
||||
<dd>The value to be filled in the attention mask. Default value is negative infinity</dd>
|
||||
<dd>The value to be filled in the attention mask. Default value is -10000.0f</dd>
|
||||
<dt><tt>num_heads</tt> : int (required)</dt>
|
||||
<dd>Number of attention heads</dd>
|
||||
</dl>
|
||||
|
|
@ -2410,6 +2412,8 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
<dd>Number of attention heads</dd>
|
||||
<dt><tt>past_present_share_buffer</tt> : int</dt>
|
||||
<dd>Corresponding past and present are same tensor, its shape is (2, batch_size, num_heads, max_sequence_length, head_size)</dd>
|
||||
<dt><tt>scale</tt> : float</dt>
|
||||
<dd>Custom scale will be used if specified. Default value is 1/sqrt(head_size)</dd>
|
||||
<dt><tt>unidirectional</tt> : int</dt>
|
||||
<dd>Whether every token can only attend to previous tokens. Default value is 0.</dd>
|
||||
</dl>
|
||||
|
|
|
|||
|
|
@ -249,6 +249,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
|
|||
output_parameters->is_unidirectional = is_unidirectional_;
|
||||
output_parameters->past_present_share_buffer = (past_present_share_buffer_ != 0 && past != nullptr);
|
||||
output_parameters->mask_filter_value = mask_filter_value_;
|
||||
output_parameters->scale = scale_;
|
||||
output_parameters->mask_type = mask_type;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ class AttentionBase {
|
|||
|
||||
is_unidirectional_ = info.GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;
|
||||
mask_filter_value_ = info.GetAttrOrDefault<float>("mask_filter_value", -10000.0f);
|
||||
scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
|
||||
|
||||
if (!info.GetAttrs<int64_t>("qkv_hidden_sizes", qkv_hidden_sizes_).IsOK()) {
|
||||
qkv_hidden_sizes_.clear();
|
||||
|
|
@ -70,6 +71,7 @@ class AttentionBase {
|
|||
bool require_same_hidden_size_; // whether the implementation supports different hidden sizes of Q/K/V.
|
||||
bool past_present_share_buffer_; // whether or not the past (if used) and present tensor share the same buffer
|
||||
float mask_filter_value_; // the value to be used for filtered out positions
|
||||
float scale_; // the scale to be used for softmax
|
||||
};
|
||||
|
||||
} // namespace contrib
|
||||
|
|
|
|||
|
|
@ -37,19 +37,20 @@ enum AttentionKernelType{
|
|||
struct AttentionParameters {
|
||||
int batch_size;
|
||||
int sequence_length;
|
||||
int kv_sequence_length; // input sequence length of K or V
|
||||
int past_sequence_length; // sequence length in past state of K or V
|
||||
int total_sequence_length; // total sequence length of K or V
|
||||
int max_sequence_length; // max sequence length from 4D mask
|
||||
int input_hidden_size; // first dimension of weights for input projection
|
||||
int hidden_size; // hidden size of Q or K
|
||||
int head_size; // hidden size per head of Q or K
|
||||
int v_hidden_size; // hidden size of V
|
||||
int v_head_size; // hidden size per head of V
|
||||
int kv_sequence_length; // input sequence length of K or V
|
||||
int past_sequence_length; // sequence length in past state of K or V
|
||||
int total_sequence_length; // total sequence length of K or V
|
||||
int max_sequence_length; // max sequence length from 4D mask
|
||||
int input_hidden_size; // first dimension of weights for input projection
|
||||
int hidden_size; // hidden size of Q or K
|
||||
int head_size; // hidden size per head of Q or K
|
||||
int v_hidden_size; // hidden size of V
|
||||
int v_head_size; // hidden size per head of V
|
||||
int num_heads;
|
||||
bool is_unidirectional;
|
||||
bool past_present_share_buffer;
|
||||
float mask_filter_value;
|
||||
float scale;
|
||||
AttentionMaskType mask_type;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -130,7 +130,7 @@ class AttentionCPUBase : public AttentionBase {
|
|||
}
|
||||
|
||||
const int loop_len = batch_size * num_heads_;
|
||||
const float alpha = 1.0f / sqrt(static_cast<float>(head_size));
|
||||
const float alpha = scale_ == 0.0f ? 1.0f / sqrt(static_cast<float>(head_size)) : scale_;
|
||||
|
||||
// The cost of Gemm
|
||||
const double cost = static_cast<double>(head_size) * sequence_length * total_sequence_length;
|
||||
|
|
|
|||
|
|
@ -107,6 +107,7 @@ Status CheckInputs(const T* query,
|
|||
output_parameters->past_present_share_buffer = false;
|
||||
output_parameters->mask_filter_value = mask_filter_value;
|
||||
output_parameters->mask_type = mask_type;
|
||||
output_parameters->scale = 0.0f;
|
||||
}
|
||||
|
||||
if (max_threads_per_block > 0 && num_heads > max_threads_per_block) {
|
||||
|
|
|
|||
|
|
@ -107,7 +107,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
// Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node.
|
||||
if (nullptr == fused_fp16_runner_.get()) {
|
||||
fused_fp16_runner_.reset(new FusedMHARunnerFP16v2(num_heads_, parameters.head_size, sm, is_unidirectional_,
|
||||
enable_flash_attention_));
|
||||
enable_flash_attention_, parameters.scale));
|
||||
}
|
||||
|
||||
// Here we assume all causal kernels can be loaded into shared memory. TODO: add a function to check.
|
||||
|
|
@ -128,7 +128,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
// Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node.
|
||||
if (nullptr == fused_fp16_runner_.get()) {
|
||||
fused_fp16_runner_.reset(new FusedMHARunnerFP16v2(num_heads_, parameters.head_size, sm, is_unidirectional_,
|
||||
enable_flash_attention_));
|
||||
enable_flash_attention_, parameters.scale));
|
||||
}
|
||||
|
||||
// In case some kernel not loaded due to shared memory limit, we need to double check here.
|
||||
|
|
|
|||
|
|
@ -540,8 +540,9 @@ Status QkvToContext(
|
|||
float zero = 0.f;
|
||||
|
||||
// For raw attention mask, the scalar 1/sqrt(H) is moved to combine with softmax computation.
|
||||
const float rsqrt_head_size = 1.f / sqrt(static_cast<float>(qk_head_size));
|
||||
float alpha = use_raw_attention_mask ? one : rsqrt_head_size;
|
||||
const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast<float>(qk_head_size))
|
||||
: parameters.scale;
|
||||
float alpha = use_raw_attention_mask ? one : scale;
|
||||
|
||||
cublasSetStream(cublas, stream);
|
||||
|
||||
|
|
@ -570,7 +571,7 @@ Status QkvToContext(
|
|||
ORT_RETURN_IF_ERROR(
|
||||
ComputeSoftmaxWithRawMask<T>(stream, total_sequence_length, sequence_length, batch_size, num_heads,
|
||||
mask_index, nullptr, data.extra_add_qk, scratch1, scratch2,
|
||||
parameters.is_unidirectional, rsqrt_head_size, mask_dimension,
|
||||
parameters.is_unidirectional, scale, mask_dimension,
|
||||
parameters.max_sequence_length, use_persistent_softmax,
|
||||
persistent_softmax_workspace, mask_filter_value));
|
||||
} else if (nullptr != mask_index) { // 1d mask index
|
||||
|
|
|
|||
|
|
@ -114,7 +114,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
if (nullptr == fused_fp16_runner_.get()) {
|
||||
constexpr bool is_unidirectional = false;
|
||||
fused_fp16_runner_.reset(new FusedMHARunnerFP16v2(
|
||||
num_heads_, parameters.head_size, sm, is_unidirectional, enable_flash_attention_));
|
||||
num_heads_, parameters.head_size, sm, is_unidirectional, enable_flash_attention_, parameters.scale));
|
||||
}
|
||||
|
||||
// In case some kernel not loaded due to shared memory limit, we need to double check here.
|
||||
|
|
|
|||
|
|
@ -100,7 +100,7 @@ class FusedMHARunnerFP16v2::mhaImpl {
|
|||
// The number of xmmas in the M dimension. We use one uint32_t per XMMA in the M dimension.
|
||||
xmmas_m = (S + 16 * warps_m - 1) / (16 * warps_m);
|
||||
|
||||
const float scale_bmm1 = interface->mRsqrtHeadSize;
|
||||
const float scale_bmm1 = interface->mScale;
|
||||
const float scale_softmax = 1.f; // Seems to be only required for int8
|
||||
const float scale_bmm2 = 1.f;
|
||||
|
||||
|
|
@ -121,7 +121,7 @@ class FusedMHARunnerFP16v2::mhaImpl {
|
|||
}
|
||||
|
||||
void setup_causal_masked_fmha(const int S, const int B) {
|
||||
const float scale_bmm1 = interface->mRsqrtHeadSize;
|
||||
const float scale_bmm1 = interface->mScale;
|
||||
const float scale_softmax = 1.f; // Seems to be only required for int8
|
||||
const float scale_bmm2 = 1.f;
|
||||
|
||||
|
|
@ -219,8 +219,16 @@ class FusedMHARunnerFP16v2::mhaImpl {
|
|||
bool has_causal_mask = false;
|
||||
};
|
||||
|
||||
FusedMHARunnerFP16v2::FusedMHARunnerFP16v2(const int numHeads, const int headSize, const int sm, bool causal_mask, bool enable_flash_attention)
|
||||
: MHARunner(numHeads, headSize, 2, causal_mask), mSm(sm), mEnableFlashAttention(enable_flash_attention), pimpl(new mhaImpl(this)) {
|
||||
FusedMHARunnerFP16v2::FusedMHARunnerFP16v2(const int numHeads,
|
||||
const int headSize,
|
||||
const int sm,
|
||||
bool causal_mask,
|
||||
bool enable_flash_attention,
|
||||
const float scale)
|
||||
: MHARunner(numHeads, headSize, 2, causal_mask, scale),
|
||||
mSm(sm),
|
||||
mEnableFlashAttention(enable_flash_attention),
|
||||
pimpl(new mhaImpl(this)) {
|
||||
}
|
||||
|
||||
void FusedMHARunnerFP16v2::setup(const int S, const int B) {
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ constexpr int kMinSequenceLengthFlashAttention = 385;
|
|||
// Multi-Head Attention runner
|
||||
class MHARunner {
|
||||
public:
|
||||
MHARunner(const int numHeads, const int headSize, const int wordSize, bool causal_mask)
|
||||
MHARunner(const int numHeads, const int headSize, const int wordSize, bool causal_mask, const float scale)
|
||||
: mS(0),
|
||||
mB(0),
|
||||
mOmatSize(0),
|
||||
|
|
@ -40,7 +40,8 @@ class MHARunner {
|
|||
mStrideQKV(0),
|
||||
mLdOut(0),
|
||||
mStrideOut(0),
|
||||
mRsqrtHeadSize(1.f / sqrtf(static_cast<float>(headSize))),
|
||||
mScale(scale == 0.0f ? 1.f / sqrtf(static_cast<float>(headSize))
|
||||
: scale),
|
||||
mHasCausalMask(causal_mask) {
|
||||
}
|
||||
|
||||
|
|
@ -83,7 +84,7 @@ class MHARunner {
|
|||
int mLdOut;
|
||||
int mStrideOut;
|
||||
|
||||
float mRsqrtHeadSize;
|
||||
float mScale;
|
||||
bool mHasCausalMask;
|
||||
};
|
||||
|
||||
|
|
@ -93,7 +94,8 @@ class FusedMHARunnerFP16v2 : public MHARunner {
|
|||
const int headSize,
|
||||
const int sm,
|
||||
bool causal_mask,
|
||||
bool enable_flash_attention);
|
||||
bool enable_flash_attention,
|
||||
const float scale);
|
||||
~FusedMHARunnerFP16v2() = default; // for pimpl
|
||||
|
||||
virtual void setup(const int S, const int B) override;
|
||||
|
|
|
|||
|
|
@ -203,6 +203,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
|
|||
"The value to be filled in the attention mask. Default value is -10000.0f",
|
||||
AttributeProto::FLOAT,
|
||||
OPTIONAL_VALUE)
|
||||
.Attr("scale",
|
||||
"Custom scale will be used if specified. Default value is 1/sqrt(head_size)",
|
||||
AttributeProto::FLOAT,
|
||||
OPTIONAL_VALUE)
|
||||
.Input(0,
|
||||
"input",
|
||||
"Input tensor with shape (batch_size, sequence_length, input_hidden_size)",
|
||||
|
|
@ -275,7 +279,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
|
|||
OpSchema()
|
||||
.SetDoc(MultiHeadAttention_ver1_doc)
|
||||
.Attr("num_heads", "Number of attention heads", AttributeProto::INT)
|
||||
.Attr("mask_filter_value", "The value to be filled in the attention mask. Default value is negative infinity",
|
||||
.Attr("mask_filter_value", "The value to be filled in the attention mask. Default value is -10000.0f",
|
||||
AttributeProto::FLOAT, OPTIONAL_VALUE)
|
||||
.Input(0,
|
||||
"query",
|
||||
|
|
@ -349,7 +353,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
|
|||
OpSchema()
|
||||
.SetDoc(Decoder_Attention_doc)
|
||||
.Attr("num_heads", "Number of attention heads", AttributeProto::INT)
|
||||
.Attr("mask_filter_value", "The value to be filled in the attention mask. Default value is negative infinity",
|
||||
.Attr("mask_filter_value", "The value to be filled in the attention mask. Default value is -10000.0f",
|
||||
AttributeProto::FLOAT, OPTIONAL_VALUE)
|
||||
.Input(0, "query", "3D input tensor with shape (sequence_length, batch_size, hidden_size), hidden_size = num_heads * head_size", "T")
|
||||
.Input(1, "key", "3D input tensor with shape (total_sequence_length, batch_size, hidden_size)", "T")
|
||||
|
|
|
|||
|
|
@ -952,6 +952,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
|
|||
.Attr("past_present_share_buffer", "Corresponding past and present are same tensor, its shape is "
|
||||
"(2, batch_size, num_heads, max_sequence_length, head_size)",
|
||||
AttributeProto::INT, OPTIONAL_VALUE)
|
||||
.Attr("scale",
|
||||
"Custom scale will be used if specified. Default value is 1/sqrt(head_size)",
|
||||
AttributeProto::FLOAT,
|
||||
OPTIONAL_VALUE)
|
||||
.Input(0, "input", "3D input tensor with shape (batch_size, sequence_length, input_hidden_size)", "T1")
|
||||
.Input(1, "weight",
|
||||
"2D input tensor with shape (input_hidden_size, 3 * hidden_size), hidden_size = num_heads * head_size",
|
||||
|
|
|
|||
|
|
@ -61,7 +61,8 @@ static void RunAttentionTest(
|
|||
std::vector<int32_t> qkv_sizes = {},
|
||||
const std::vector<float>& extra_add_data = {},
|
||||
int kv_sequence_length = 0,
|
||||
bool past_present_share_buffer = false) {
|
||||
bool past_present_share_buffer = false,
|
||||
bool use_scale = false) {
|
||||
input_hidden_size = (input_hidden_size == 0 ? hidden_size : input_hidden_size); // By default, no pruning.
|
||||
kv_sequence_length = (kv_sequence_length == 0 ? sequence_length : kv_sequence_length);
|
||||
past_present_share_buffer = past_present_share_buffer && use_past_state;
|
||||
|
|
@ -78,6 +79,9 @@ static void RunAttentionTest(
|
|||
tester.AddAttribute<int64_t>("unidirectional", static_cast<int64_t>(is_unidirectional ? 1 : 0));
|
||||
tester.AddAttribute<int64_t>("past_present_share_buffer", static_cast<int64_t>(past_present_share_buffer ? 1 : 0));
|
||||
tester.AddAttribute<float>("mask_filter_value", static_cast<float>(-10000.0f));
|
||||
if (use_scale && !enable_rocm) {
|
||||
tester.AddAttribute<float>("scale", static_cast<float>(1.f / sqrt(head_size)));
|
||||
}
|
||||
|
||||
int32_t qkv_hidden_size_sum;
|
||||
int32_t v_hidden_size;
|
||||
|
|
@ -262,19 +266,20 @@ static void RunAttentionTest(
|
|||
const std::vector<int32_t> qkv_sizes = {},
|
||||
const std::vector<float>& extra_add_data = {},
|
||||
int kv_sequence_length = 0,
|
||||
bool past_present_share_buffer = false) {
|
||||
bool past_present_share_buffer = false,
|
||||
bool use_scale = false) {
|
||||
RunAttentionTest(input_data, weights_data, false, bias_data, mask_index_data, output_data,
|
||||
batch_size, sequence_length, hidden_size, number_of_heads,
|
||||
use_float16, is_unidirectional, use_past_state, past_sequence_length,
|
||||
past_data, present_data, mask_type, input_hidden_size, max_sequence_length,
|
||||
disable_cpu, disable_cuda, disable_rocm, qkv_sizes, extra_add_data,
|
||||
kv_sequence_length, past_present_share_buffer);
|
||||
kv_sequence_length, past_present_share_buffer, use_scale);
|
||||
RunAttentionTest(input_data, weights_data, true, bias_data, mask_index_data, output_data,
|
||||
batch_size, sequence_length, hidden_size, number_of_heads,
|
||||
use_float16, is_unidirectional, use_past_state, past_sequence_length,
|
||||
past_data, present_data, mask_type, input_hidden_size, max_sequence_length,
|
||||
disable_cpu, disable_cuda, disable_rocm, qkv_sizes, extra_add_data,
|
||||
kv_sequence_length, past_present_share_buffer);
|
||||
kv_sequence_length, past_present_share_buffer, use_scale);
|
||||
}
|
||||
|
||||
TEST(AttentionTest, AttentionBatch1) {
|
||||
|
|
@ -1663,6 +1668,51 @@ TEST(AttentionTest, AttentionUnidirectionalAttentionMask) {
|
|||
AttentionMaskType::MASK_2D_KEY_PADDING);
|
||||
}
|
||||
|
||||
TEST(AttentionTest, AttentionWithNormFactor) {
|
||||
int batch_size = 2;
|
||||
int sequence_length = 2;
|
||||
int hidden_size = 4;
|
||||
int number_of_heads = 2;
|
||||
|
||||
std::vector<float> input_data = {
|
||||
0.5f, 0.2f, 0.3f, -0.6f,
|
||||
0.8f, -0.5f, 0.0f, 1.f,
|
||||
0.8f, -0.5f, 0.0f, 1.f,
|
||||
0.5f, 0.2f, 0.3f, -0.6f};
|
||||
|
||||
std::vector<float> weight_data = {
|
||||
0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f,
|
||||
0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f,
|
||||
0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f,
|
||||
0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f};
|
||||
|
||||
std::vector<float> bias_data = {
|
||||
-0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f};
|
||||
|
||||
// Test mask start position > 0.
|
||||
std::vector<int32_t> mask_index_data = {0, 1, 1, 1};
|
||||
|
||||
std::vector<float> output_data = {
|
||||
3.0146f, 0.1142f, 3.9834f, 5.3394f,
|
||||
8.69f, -0.13f, 4.25f, 5.65f,
|
||||
8.69f, -0.13f, 4.25f, 5.65f,
|
||||
3.96967912f, 0.07314367f, 4.25f, 5.65f};
|
||||
|
||||
bool use_float16 = false;
|
||||
bool is_unidirectional = true;
|
||||
bool use_past_state = false;
|
||||
int past_sequence_length = 0;
|
||||
const std::vector<float>* past_data = nullptr;
|
||||
const std::vector<float>* present_data = nullptr;
|
||||
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
|
||||
batch_size, sequence_length, hidden_size, number_of_heads,
|
||||
use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data,
|
||||
AttentionMaskType::MASK_2D_KEY_PADDING, 0 /*input_hidden_size*/, 0 /*max_sequence_length*/,
|
||||
false /*disable_cpu*/, false /*disable_cuda*/, true /*disable_rocm*/, {} /*qkv_sizes*/,
|
||||
{} /*extra_add_data*/, 0 /*kv_sequence_length*/, false /*past_present_share_buffer*/,
|
||||
true /*use_scale*/);
|
||||
}
|
||||
|
||||
TEST(AttentionTest, AttentionMask1DEndNoWord) {
|
||||
int batch_size = 2;
|
||||
int sequence_length = 2;
|
||||
|
|
|
|||
Loading…
Reference in a new issue