Support muP in Attention (#14348)

### Description
<!-- Describe your changes. -->



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

Co-authored-by: Ubuntu <wy@v100-2.0cdb2e52twzevn1i4fi45bylyg.jx.internal.cloudapp.net>
This commit is contained in:
Ye Wang 2023-01-19 20:36:55 -08:00 committed by GitHub
parent 1dd07d147d
commit 668586e8f8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 110 additions and 32 deletions

View file

@ -131,6 +131,8 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>Corresponding past and present are same tensor, its size is (2, batch_size, num_heads, max_sequence_length, head_size)</dd>
<dt><tt>qkv_hidden_sizes</tt> : list of ints</dt>
<dd>Hidden dimension of Q, K, V: hidden_size, hidden_size and v_hidden_size</dd>
<dt><tt>scale</tt> : float</dt>
<dd>Custom scale will be used if specified. Default value is 1/sqrt(head_size)</dd>
<dt><tt>unidirectional</tt> : int</dt>
<dd>Whether every token can only attend to previous tokens. Default value is 0.</dd>
</dl>
@ -970,7 +972,7 @@ This version of the operator has been available since version 1 of the 'com.micr
<dl>
<dt><tt>mask_filter_value</tt> : float</dt>
<dd>The value to be filled in the attention mask. Default value is negative infinity</dd>
<dd>The value to be filled in the attention mask. Default value is -10000.0f</dd>
<dt><tt>num_heads</tt> : int (required)</dt>
<dd>Number of attention heads</dd>
</dl>
@ -2125,7 +2127,7 @@ This version of the operator has been available since version 1 of the 'com.micr
<dl>
<dt><tt>mask_filter_value</tt> : float</dt>
<dd>The value to be filled in the attention mask. Default value is negative infinity</dd>
<dd>The value to be filled in the attention mask. Default value is -10000.0f</dd>
<dt><tt>num_heads</tt> : int (required)</dt>
<dd>Number of attention heads</dd>
</dl>
@ -2410,6 +2412,8 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>Number of attention heads</dd>
<dt><tt>past_present_share_buffer</tt> : int</dt>
<dd>Corresponding past and present are same tensor, its shape is (2, batch_size, num_heads, max_sequence_length, head_size)</dd>
<dt><tt>scale</tt> : float</dt>
<dd>Custom scale will be used if specified. Default value is 1/sqrt(head_size)</dd>
<dt><tt>unidirectional</tt> : int</dt>
<dd>Whether every token can only attend to previous tokens. Default value is 0.</dd>
</dl>

View file

@ -249,6 +249,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
output_parameters->is_unidirectional = is_unidirectional_;
output_parameters->past_present_share_buffer = (past_present_share_buffer_ != 0 && past != nullptr);
output_parameters->mask_filter_value = mask_filter_value_;
output_parameters->scale = scale_;
output_parameters->mask_type = mask_type;
}

View file

@ -38,6 +38,7 @@ class AttentionBase {
is_unidirectional_ = info.GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;
mask_filter_value_ = info.GetAttrOrDefault<float>("mask_filter_value", -10000.0f);
scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
if (!info.GetAttrs<int64_t>("qkv_hidden_sizes", qkv_hidden_sizes_).IsOK()) {
qkv_hidden_sizes_.clear();
@ -70,6 +71,7 @@ class AttentionBase {
bool require_same_hidden_size_; // whether the implementation supports different hidden sizes of Q/K/V.
bool past_present_share_buffer_; // whether or not the past (if used) and present tensor share the same buffer
float mask_filter_value_; // the value to be used for filtered out positions
float scale_; // the scale to be used for softmax
};
} // namespace contrib

View file

@ -37,19 +37,20 @@ enum AttentionKernelType{
struct AttentionParameters {
int batch_size;
int sequence_length;
int kv_sequence_length; // input sequence length of K or V
int past_sequence_length; // sequence length in past state of K or V
int total_sequence_length; // total sequence length of K or V
int max_sequence_length; // max sequence length from 4D mask
int input_hidden_size; // first dimension of weights for input projection
int hidden_size; // hidden size of Q or K
int head_size; // hidden size per head of Q or K
int v_hidden_size; // hidden size of V
int v_head_size; // hidden size per head of V
int kv_sequence_length; // input sequence length of K or V
int past_sequence_length; // sequence length in past state of K or V
int total_sequence_length; // total sequence length of K or V
int max_sequence_length; // max sequence length from 4D mask
int input_hidden_size; // first dimension of weights for input projection
int hidden_size; // hidden size of Q or K
int head_size; // hidden size per head of Q or K
int v_hidden_size; // hidden size of V
int v_head_size; // hidden size per head of V
int num_heads;
bool is_unidirectional;
bool past_present_share_buffer;
float mask_filter_value;
float scale;
AttentionMaskType mask_type;
};

View file

@ -130,7 +130,7 @@ class AttentionCPUBase : public AttentionBase {
}
const int loop_len = batch_size * num_heads_;
const float alpha = 1.0f / sqrt(static_cast<float>(head_size));
const float alpha = scale_ == 0.0f ? 1.0f / sqrt(static_cast<float>(head_size)) : scale_;
// The cost of Gemm
const double cost = static_cast<double>(head_size) * sequence_length * total_sequence_length;

View file

@ -107,6 +107,7 @@ Status CheckInputs(const T* query,
output_parameters->past_present_share_buffer = false;
output_parameters->mask_filter_value = mask_filter_value;
output_parameters->mask_type = mask_type;
output_parameters->scale = 0.0f;
}
if (max_threads_per_block > 0 && num_heads > max_threads_per_block) {

View file

@ -107,7 +107,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
// Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node.
if (nullptr == fused_fp16_runner_.get()) {
fused_fp16_runner_.reset(new FusedMHARunnerFP16v2(num_heads_, parameters.head_size, sm, is_unidirectional_,
enable_flash_attention_));
enable_flash_attention_, parameters.scale));
}
// Here we assume all causal kernels can be loaded into shared memory. TODO: add a function to check.
@ -128,7 +128,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
// Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node.
if (nullptr == fused_fp16_runner_.get()) {
fused_fp16_runner_.reset(new FusedMHARunnerFP16v2(num_heads_, parameters.head_size, sm, is_unidirectional_,
enable_flash_attention_));
enable_flash_attention_, parameters.scale));
}
// In case some kernel not loaded due to shared memory limit, we need to double check here.

View file

@ -540,8 +540,9 @@ Status QkvToContext(
float zero = 0.f;
// For raw attention mask, the scalar 1/sqrt(H) is moved to combine with softmax computation.
const float rsqrt_head_size = 1.f / sqrt(static_cast<float>(qk_head_size));
float alpha = use_raw_attention_mask ? one : rsqrt_head_size;
const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast<float>(qk_head_size))
: parameters.scale;
float alpha = use_raw_attention_mask ? one : scale;
cublasSetStream(cublas, stream);
@ -570,7 +571,7 @@ Status QkvToContext(
ORT_RETURN_IF_ERROR(
ComputeSoftmaxWithRawMask<T>(stream, total_sequence_length, sequence_length, batch_size, num_heads,
mask_index, nullptr, data.extra_add_qk, scratch1, scratch2,
parameters.is_unidirectional, rsqrt_head_size, mask_dimension,
parameters.is_unidirectional, scale, mask_dimension,
parameters.max_sequence_length, use_persistent_softmax,
persistent_softmax_workspace, mask_filter_value));
} else if (nullptr != mask_index) { // 1d mask index

View file

@ -114,7 +114,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
if (nullptr == fused_fp16_runner_.get()) {
constexpr bool is_unidirectional = false;
fused_fp16_runner_.reset(new FusedMHARunnerFP16v2(
num_heads_, parameters.head_size, sm, is_unidirectional, enable_flash_attention_));
num_heads_, parameters.head_size, sm, is_unidirectional, enable_flash_attention_, parameters.scale));
}
// In case some kernel not loaded due to shared memory limit, we need to double check here.

View file

@ -100,7 +100,7 @@ class FusedMHARunnerFP16v2::mhaImpl {
// The number of xmmas in the M dimension. We use one uint32_t per XMMA in the M dimension.
xmmas_m = (S + 16 * warps_m - 1) / (16 * warps_m);
const float scale_bmm1 = interface->mRsqrtHeadSize;
const float scale_bmm1 = interface->mScale;
const float scale_softmax = 1.f; // Seems to be only required for int8
const float scale_bmm2 = 1.f;
@ -121,7 +121,7 @@ class FusedMHARunnerFP16v2::mhaImpl {
}
void setup_causal_masked_fmha(const int S, const int B) {
const float scale_bmm1 = interface->mRsqrtHeadSize;
const float scale_bmm1 = interface->mScale;
const float scale_softmax = 1.f; // Seems to be only required for int8
const float scale_bmm2 = 1.f;
@ -219,8 +219,16 @@ class FusedMHARunnerFP16v2::mhaImpl {
bool has_causal_mask = false;
};
FusedMHARunnerFP16v2::FusedMHARunnerFP16v2(const int numHeads, const int headSize, const int sm, bool causal_mask, bool enable_flash_attention)
: MHARunner(numHeads, headSize, 2, causal_mask), mSm(sm), mEnableFlashAttention(enable_flash_attention), pimpl(new mhaImpl(this)) {
FusedMHARunnerFP16v2::FusedMHARunnerFP16v2(const int numHeads,
const int headSize,
const int sm,
bool causal_mask,
bool enable_flash_attention,
const float scale)
: MHARunner(numHeads, headSize, 2, causal_mask, scale),
mSm(sm),
mEnableFlashAttention(enable_flash_attention),
pimpl(new mhaImpl(this)) {
}
void FusedMHARunnerFP16v2::setup(const int S, const int B) {

View file

@ -28,7 +28,7 @@ constexpr int kMinSequenceLengthFlashAttention = 385;
// Multi-Head Attention runner
class MHARunner {
public:
MHARunner(const int numHeads, const int headSize, const int wordSize, bool causal_mask)
MHARunner(const int numHeads, const int headSize, const int wordSize, bool causal_mask, const float scale)
: mS(0),
mB(0),
mOmatSize(0),
@ -40,7 +40,8 @@ class MHARunner {
mStrideQKV(0),
mLdOut(0),
mStrideOut(0),
mRsqrtHeadSize(1.f / sqrtf(static_cast<float>(headSize))),
mScale(scale == 0.0f ? 1.f / sqrtf(static_cast<float>(headSize))
: scale),
mHasCausalMask(causal_mask) {
}
@ -83,7 +84,7 @@ class MHARunner {
int mLdOut;
int mStrideOut;
float mRsqrtHeadSize;
float mScale;
bool mHasCausalMask;
};
@ -93,7 +94,8 @@ class FusedMHARunnerFP16v2 : public MHARunner {
const int headSize,
const int sm,
bool causal_mask,
bool enable_flash_attention);
bool enable_flash_attention,
const float scale);
~FusedMHARunnerFP16v2() = default; // for pimpl
virtual void setup(const int S, const int B) override;

View file

@ -203,6 +203,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"The value to be filled in the attention mask. Default value is -10000.0f",
AttributeProto::FLOAT,
OPTIONAL_VALUE)
.Attr("scale",
"Custom scale will be used if specified. Default value is 1/sqrt(head_size)",
AttributeProto::FLOAT,
OPTIONAL_VALUE)
.Input(0,
"input",
"Input tensor with shape (batch_size, sequence_length, input_hidden_size)",
@ -275,7 +279,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
OpSchema()
.SetDoc(MultiHeadAttention_ver1_doc)
.Attr("num_heads", "Number of attention heads", AttributeProto::INT)
.Attr("mask_filter_value", "The value to be filled in the attention mask. Default value is negative infinity",
.Attr("mask_filter_value", "The value to be filled in the attention mask. Default value is -10000.0f",
AttributeProto::FLOAT, OPTIONAL_VALUE)
.Input(0,
"query",
@ -349,7 +353,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
OpSchema()
.SetDoc(Decoder_Attention_doc)
.Attr("num_heads", "Number of attention heads", AttributeProto::INT)
.Attr("mask_filter_value", "The value to be filled in the attention mask. Default value is negative infinity",
.Attr("mask_filter_value", "The value to be filled in the attention mask. Default value is -10000.0f",
AttributeProto::FLOAT, OPTIONAL_VALUE)
.Input(0, "query", "3D input tensor with shape (sequence_length, batch_size, hidden_size), hidden_size = num_heads * head_size", "T")
.Input(1, "key", "3D input tensor with shape (total_sequence_length, batch_size, hidden_size)", "T")

View file

@ -952,6 +952,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
.Attr("past_present_share_buffer", "Corresponding past and present are same tensor, its shape is "
"(2, batch_size, num_heads, max_sequence_length, head_size)",
AttributeProto::INT, OPTIONAL_VALUE)
.Attr("scale",
"Custom scale will be used if specified. Default value is 1/sqrt(head_size)",
AttributeProto::FLOAT,
OPTIONAL_VALUE)
.Input(0, "input", "3D input tensor with shape (batch_size, sequence_length, input_hidden_size)", "T1")
.Input(1, "weight",
"2D input tensor with shape (input_hidden_size, 3 * hidden_size), hidden_size = num_heads * head_size",

View file

@ -61,7 +61,8 @@ static void RunAttentionTest(
std::vector<int32_t> qkv_sizes = {},
const std::vector<float>& extra_add_data = {},
int kv_sequence_length = 0,
bool past_present_share_buffer = false) {
bool past_present_share_buffer = false,
bool use_scale = false) {
input_hidden_size = (input_hidden_size == 0 ? hidden_size : input_hidden_size); // By default, no pruning.
kv_sequence_length = (kv_sequence_length == 0 ? sequence_length : kv_sequence_length);
past_present_share_buffer = past_present_share_buffer && use_past_state;
@ -78,6 +79,9 @@ static void RunAttentionTest(
tester.AddAttribute<int64_t>("unidirectional", static_cast<int64_t>(is_unidirectional ? 1 : 0));
tester.AddAttribute<int64_t>("past_present_share_buffer", static_cast<int64_t>(past_present_share_buffer ? 1 : 0));
tester.AddAttribute<float>("mask_filter_value", static_cast<float>(-10000.0f));
if (use_scale && !enable_rocm) {
tester.AddAttribute<float>("scale", static_cast<float>(1.f / sqrt(head_size)));
}
int32_t qkv_hidden_size_sum;
int32_t v_hidden_size;
@ -262,19 +266,20 @@ static void RunAttentionTest(
const std::vector<int32_t> qkv_sizes = {},
const std::vector<float>& extra_add_data = {},
int kv_sequence_length = 0,
bool past_present_share_buffer = false) {
bool past_present_share_buffer = false,
bool use_scale = false) {
RunAttentionTest(input_data, weights_data, false, bias_data, mask_index_data, output_data,
batch_size, sequence_length, hidden_size, number_of_heads,
use_float16, is_unidirectional, use_past_state, past_sequence_length,
past_data, present_data, mask_type, input_hidden_size, max_sequence_length,
disable_cpu, disable_cuda, disable_rocm, qkv_sizes, extra_add_data,
kv_sequence_length, past_present_share_buffer);
kv_sequence_length, past_present_share_buffer, use_scale);
RunAttentionTest(input_data, weights_data, true, bias_data, mask_index_data, output_data,
batch_size, sequence_length, hidden_size, number_of_heads,
use_float16, is_unidirectional, use_past_state, past_sequence_length,
past_data, present_data, mask_type, input_hidden_size, max_sequence_length,
disable_cpu, disable_cuda, disable_rocm, qkv_sizes, extra_add_data,
kv_sequence_length, past_present_share_buffer);
kv_sequence_length, past_present_share_buffer, use_scale);
}
TEST(AttentionTest, AttentionBatch1) {
@ -1663,6 +1668,51 @@ TEST(AttentionTest, AttentionUnidirectionalAttentionMask) {
AttentionMaskType::MASK_2D_KEY_PADDING);
}
TEST(AttentionTest, AttentionWithNormFactor) {
int batch_size = 2;
int sequence_length = 2;
int hidden_size = 4;
int number_of_heads = 2;
std::vector<float> input_data = {
0.5f, 0.2f, 0.3f, -0.6f,
0.8f, -0.5f, 0.0f, 1.f,
0.8f, -0.5f, 0.0f, 1.f,
0.5f, 0.2f, 0.3f, -0.6f};
std::vector<float> weight_data = {
0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f,
0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f,
0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f,
0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f};
std::vector<float> bias_data = {
-0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f};
// Test mask start position > 0.
std::vector<int32_t> mask_index_data = {0, 1, 1, 1};
std::vector<float> output_data = {
3.0146f, 0.1142f, 3.9834f, 5.3394f,
8.69f, -0.13f, 4.25f, 5.65f,
8.69f, -0.13f, 4.25f, 5.65f,
3.96967912f, 0.07314367f, 4.25f, 5.65f};
bool use_float16 = false;
bool is_unidirectional = true;
bool use_past_state = false;
int past_sequence_length = 0;
const std::vector<float>* past_data = nullptr;
const std::vector<float>* present_data = nullptr;
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
batch_size, sequence_length, hidden_size, number_of_heads,
use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data,
AttentionMaskType::MASK_2D_KEY_PADDING, 0 /*input_hidden_size*/, 0 /*max_sequence_length*/,
false /*disable_cpu*/, false /*disable_cuda*/, true /*disable_rocm*/, {} /*qkv_sizes*/,
{} /*extra_add_data*/, 0 /*kv_sequence_length*/, false /*past_present_share_buffer*/,
true /*use_scale*/);
}
TEST(AttentionTest, AttentionMask1DEndNoWord) {
int batch_size = 2;
int sequence_length = 2;