mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-19 21:32:23 +00:00
Speed Up DecoderMaskedSelfAttentionTest (#19531)
### Description The unit tests take 19 minutes to run (in debug build) because of too many combinations. I reduce the combinations and remain good test coverage. After the change, the test can finish in 51 seconds. Before: [----------] 2 tests from DecoderMaskedSelfAttentionTest [ RUN ] DecoderMaskedSelfAttentionTest.Test_fp32 [ OK ] DecoderMaskedSelfAttentionTest.Test_fp32 (394086 ms) [ RUN ] DecoderMaskedSelfAttentionTest.Test_fp16 [ OK ] DecoderMaskedSelfAttentionTest.Test_fp16 (747035 ms) [----------] 2 tests from DecoderMaskedSelfAttentionTest (1141122 ms total) After: [----------] 2 tests from DecoderMaskedSelfAttentionTest [ RUN ] DecoderMaskedSelfAttentionTest.Test_fp32 [ OK ] DecoderMaskedSelfAttentionTest.Test_fp32 (21057 ms) [ RUN ] DecoderMaskedSelfAttentionTest.Test_fp16 [ OK ] DecoderMaskedSelfAttentionTest.Test_fp16 (30653 ms) [----------] 2 tests from DecoderMaskedSelfAttentionTest (51710 ms total) ### Motivation and Context Reduce test time, and improve build pipeline efficiency.
This commit is contained in:
parent
d0061d6fb1
commit
4bfa69def8
1 changed files with 210 additions and 177 deletions
|
|
@ -640,122 +640,139 @@ TEST(DecoderMaskedSelfAttentionTest, Test_fp32) {
|
|||
return;
|
||||
}
|
||||
|
||||
// Vary batch size
|
||||
for (int batch_size = 1; batch_size <= 5; batch_size += 2) {
|
||||
// Vary kv_lengths
|
||||
for (int past_sequence_length = 1; past_sequence_length <= 3000; past_sequence_length += 150) {
|
||||
int sequence_length = 1;
|
||||
int number_of_heads = 12;
|
||||
// Vary head_size / hidden_size
|
||||
int hidden_sizes[3] = {384, 768, 1536};
|
||||
for (int hidden_size : hidden_sizes) {
|
||||
int head_size = (hidden_size / number_of_heads);
|
||||
int total_sequence_length = sequence_length + past_sequence_length;
|
||||
int max_sequence_length = past_sequence_length + 1; // Always keep > past_sequence_length
|
||||
// Buckets for test data:
|
||||
// batch_size: 1, >=2
|
||||
// past_sequence_length 0~30, 31~2046, >=2047 (so that total_sequence_length: 1~31, 32~2047, >=2048)
|
||||
// head_size: 32, 64, 128
|
||||
struct MyTestCase {
|
||||
int batch_size;
|
||||
int past_sequence_length;
|
||||
int hidden_size;
|
||||
} test_cases[] = {
|
||||
{1, 0, 768},
|
||||
{1, 1, 384},
|
||||
{2, 30, 768},
|
||||
{3, 31, 1536},
|
||||
{4, 512, 384},
|
||||
{1, 1024, 768},
|
||||
{1, 2046, 1536},
|
||||
{2, 2047, 384},
|
||||
{3, 3000, 768},
|
||||
};
|
||||
|
||||
OpTester tester("DecoderMaskedSelfAttention", 1, onnxruntime::kMSDomain);
|
||||
tester.AddAttribute<int64_t>("num_heads", static_cast<int64_t>(number_of_heads));
|
||||
tester.AddAttribute<int64_t>("past_present_share_buffer", static_cast<int64_t>(1));
|
||||
constexpr int sequence_length = 1;
|
||||
constexpr int number_of_heads = 12;
|
||||
|
||||
std::vector<int64_t> input_dims = {batch_size, sequence_length, hidden_size};
|
||||
std::vector<int64_t> weights_dims = {hidden_size, 3 * hidden_size};
|
||||
std::vector<int64_t> bias_dims = {3 * hidden_size};
|
||||
std::vector<int64_t> output_dims = {batch_size, sequence_length, hidden_size};
|
||||
for (MyTestCase test_case : test_cases) {
|
||||
int batch_size = test_case.batch_size;
|
||||
int past_sequence_length = test_case.past_sequence_length;
|
||||
int hidden_size = test_case.hidden_size;
|
||||
|
||||
auto input = CreateRandom<float>(batch_size * sequence_length * hidden_size);
|
||||
tester.AddInput<float>("input", input_dims, input);
|
||||
int head_size = (hidden_size / number_of_heads);
|
||||
int total_sequence_length = sequence_length + past_sequence_length;
|
||||
int max_sequence_length = past_sequence_length + 1; // Always keep > past_sequence_length
|
||||
|
||||
auto weight = CreateRandom<float>(hidden_size * 3 * hidden_size);
|
||||
tester.AddInput<float>("weight", weights_dims, weight);
|
||||
OpTester tester("DecoderMaskedSelfAttention", 1, onnxruntime::kMSDomain);
|
||||
tester.AddAttribute<int64_t>("num_heads", static_cast<int64_t>(number_of_heads));
|
||||
tester.AddAttribute<int64_t>("past_present_share_buffer", static_cast<int64_t>(1));
|
||||
|
||||
auto bias = CreateRandom<float>(3 * hidden_size);
|
||||
tester.AddInput<float>("bias", bias_dims, bias);
|
||||
std::vector<int64_t> input_dims = {batch_size, sequence_length, hidden_size};
|
||||
std::vector<int64_t> weights_dims = {hidden_size, 3 * hidden_size};
|
||||
std::vector<int64_t> bias_dims = {3 * hidden_size};
|
||||
std::vector<int64_t> output_dims = {batch_size, sequence_length, hidden_size};
|
||||
|
||||
// Mask
|
||||
tester.AddOptionalInputEdge<int32_t>();
|
||||
auto input = CreateRandom<float>(batch_size * sequence_length * hidden_size);
|
||||
tester.AddInput<float>("input", input_dims, input);
|
||||
|
||||
// Past
|
||||
std::vector<int64_t> past_dims = {2, batch_size, number_of_heads, max_sequence_length, head_size};
|
||||
int past_present_size = 2 * batch_size * number_of_heads * max_sequence_length * head_size;
|
||||
auto weight = CreateRandom<float>(hidden_size * 3 * hidden_size);
|
||||
tester.AddInput<float>("weight", weights_dims, weight);
|
||||
|
||||
auto kv_cache = CreateRandom<float>(past_present_size);
|
||||
auto bias = CreateRandom<float>(3 * hidden_size);
|
||||
tester.AddInput<float>("bias", bias_dims, bias);
|
||||
|
||||
auto reordered_kv_cache = ReorderKVCache<float>(kv_cache, batch_size,
|
||||
number_of_heads, past_sequence_length, head_size, max_sequence_length);
|
||||
// Mask
|
||||
tester.AddOptionalInputEdge<int32_t>();
|
||||
|
||||
// Validate if reordering went well - by transposing and checking equality
|
||||
int chunk_size = 16 / sizeof(float);
|
||||
int num_chunks = head_size / chunk_size;
|
||||
auto transposed = Transpose<float>(kv_cache.data(), batch_size, number_of_heads, num_chunks, max_sequence_length, chunk_size);
|
||||
CheckEquality<float>(transposed.data(), reordered_kv_cache.data(), batch_size, number_of_heads, num_chunks,
|
||||
max_sequence_length, past_sequence_length, chunk_size);
|
||||
// Past
|
||||
std::vector<int64_t> past_dims = {2, batch_size, number_of_heads, max_sequence_length, head_size};
|
||||
int past_present_size = 2 * batch_size * number_of_heads * max_sequence_length * head_size;
|
||||
|
||||
tester.AddInput<float>("past", past_dims, reordered_kv_cache);
|
||||
auto kv_cache = CreateRandom<float>(past_present_size);
|
||||
|
||||
// Rel
|
||||
tester.AddOptionalInputEdge<float>();
|
||||
auto reordered_kv_cache = ReorderKVCache<float>(kv_cache, batch_size,
|
||||
number_of_heads, past_sequence_length, head_size, max_sequence_length);
|
||||
|
||||
// Past sequence length
|
||||
std::vector<int32_t> arr_past_sequence_len(1, past_sequence_length);
|
||||
tester.AddInput<int32_t>("past_sequence_length", {1}, arr_past_sequence_len);
|
||||
// Validate if reordering went well - by transposing and checking equality
|
||||
int chunk_size = 16 / sizeof(float);
|
||||
int num_chunks = head_size / chunk_size;
|
||||
auto transposed = Transpose<float>(kv_cache.data(), batch_size, number_of_heads, num_chunks, max_sequence_length, chunk_size);
|
||||
CheckEquality<float>(transposed.data(), reordered_kv_cache.data(), batch_size, number_of_heads, num_chunks,
|
||||
max_sequence_length, past_sequence_length, chunk_size);
|
||||
|
||||
// QKV MatMul
|
||||
auto qkv = QKV(input, weight, bias, batch_size, sequence_length, hidden_size);
|
||||
auto* qkv_matrix = qkv.data();
|
||||
tester.AddInput<float>("past", past_dims, reordered_kv_cache);
|
||||
|
||||
auto pair = MergePastKWithPresentKAndTranspose<float>(kv_cache.data(), qkv_matrix + hidden_size, batch_size,
|
||||
number_of_heads, past_sequence_length,
|
||||
max_sequence_length, head_size);
|
||||
// Rel
|
||||
tester.AddOptionalInputEdge<float>();
|
||||
|
||||
auto k_merged = pair.first;
|
||||
auto k_transpose = pair.second;
|
||||
// Past sequence length
|
||||
std::vector<int32_t> arr_past_sequence_len(1, past_sequence_length);
|
||||
tester.AddInput<int32_t>("past_sequence_length", {1}, arr_past_sequence_len);
|
||||
|
||||
auto qk_transpose = QK_Transpose<float>(qkv_matrix, k_transpose.data(), batch_size, number_of_heads,
|
||||
total_sequence_length, head_size);
|
||||
// QKV MatMul
|
||||
auto qkv = QKV(input, weight, bias, batch_size, sequence_length, hidden_size);
|
||||
auto* qkv_matrix = qkv.data();
|
||||
|
||||
auto softmax_qk_transpose = Softmax_QK_Transpose<float>(qk_transpose.data(), batch_size, number_of_heads,
|
||||
sequence_length, total_sequence_length, head_size);
|
||||
auto pair = MergePastKWithPresentKAndTranspose<float>(kv_cache.data(), qkv_matrix + hidden_size, batch_size,
|
||||
number_of_heads, past_sequence_length,
|
||||
max_sequence_length, head_size);
|
||||
|
||||
auto present = MergeReorderedKVCacheWithK<float>(reordered_kv_cache, qkv_matrix + hidden_size, batch_size,
|
||||
number_of_heads, past_sequence_length, max_sequence_length, head_size);
|
||||
auto k_merged = pair.first;
|
||||
auto k_transpose = pair.second;
|
||||
|
||||
// Validate our test logic
|
||||
// We want to validate if our merged "unordered" K is the same as
|
||||
// the merged "ordered" K so that the QKT we do in our test code
|
||||
// is equivalent to the QKT we do in the kernel
|
||||
ValidateReorderedMergedKWithK<float>(k_merged.data(), present.data(), batch_size, number_of_heads, total_sequence_length, max_sequence_length, head_size);
|
||||
auto qk_transpose = QK_Transpose<float>(qkv_matrix, k_transpose.data(), batch_size, number_of_heads,
|
||||
total_sequence_length, head_size);
|
||||
|
||||
MergeReorderedKVCacheWithV<float>(present.data() + (past_present_size / 2), qkv_matrix + 2 * hidden_size, batch_size,
|
||||
number_of_heads, past_sequence_length, max_sequence_length, head_size);
|
||||
auto softmax_qk_transpose = Softmax_QK_Transpose<float>(qk_transpose.data(), batch_size, number_of_heads,
|
||||
sequence_length, total_sequence_length, head_size);
|
||||
|
||||
auto output = Softmax_QK_Transpose_V<float>(softmax_qk_transpose.data(), present.data() + (past_present_size / 2),
|
||||
batch_size, number_of_heads,
|
||||
sequence_length, total_sequence_length,
|
||||
max_sequence_length, head_size);
|
||||
auto present = MergeReorderedKVCacheWithK<float>(reordered_kv_cache, qkv_matrix + hidden_size, batch_size,
|
||||
number_of_heads, past_sequence_length, max_sequence_length, head_size);
|
||||
|
||||
// Output(s)
|
||||
tester.AddOutput<float>("output", input_dims, output);
|
||||
// Validate our test logic
|
||||
// We want to validate if our merged "unordered" K is the same as
|
||||
// the merged "ordered" K so that the QKT we do in our test code
|
||||
// is equivalent to the QKT we do in the kernel
|
||||
ValidateReorderedMergedKWithK<float>(k_merged.data(), present.data(), batch_size, number_of_heads, total_sequence_length, max_sequence_length, head_size);
|
||||
|
||||
tester.AddOutput<float>("present", past_dims, present);
|
||||
MergeReorderedKVCacheWithV<float>(present.data() + (past_present_size / 2), qkv_matrix + 2 * hidden_size, batch_size,
|
||||
number_of_heads, past_sequence_length, max_sequence_length, head_size);
|
||||
|
||||
// Run - Regular kernel execution path
|
||||
{
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
|
||||
execution_providers.push_back(DefaultCudaExecutionProvider());
|
||||
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
|
||||
}
|
||||
auto output = Softmax_QK_Transpose_V<float>(softmax_qk_transpose.data(), present.data() + (past_present_size / 2),
|
||||
batch_size, number_of_heads,
|
||||
sequence_length, total_sequence_length,
|
||||
max_sequence_length, head_size);
|
||||
|
||||
// Test alternate kernel path of loading more KV data "in flight"
|
||||
{
|
||||
ScopedEnvironmentVariables scoped_env_vars{
|
||||
EnvVarMap{{onnxruntime::contrib::attention::kDecoderMaskedAttentionLoadKVDataInFlight, "1"}}};
|
||||
// Output(s)
|
||||
tester.AddOutput<float>("output", input_dims, output);
|
||||
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
|
||||
execution_providers.push_back(DefaultCudaExecutionProvider());
|
||||
tester.AddOutput<float>("present", past_dims, present);
|
||||
|
||||
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
|
||||
}
|
||||
}
|
||||
// Run - Regular kernel execution path
|
||||
{
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
|
||||
execution_providers.push_back(DefaultCudaExecutionProvider());
|
||||
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
|
||||
}
|
||||
|
||||
// Test alternate kernel path of loading more KV data "in flight"
|
||||
{
|
||||
ScopedEnvironmentVariables scoped_env_vars{
|
||||
EnvVarMap{{onnxruntime::contrib::attention::kDecoderMaskedAttentionLoadKVDataInFlight, "1"}}};
|
||||
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
|
||||
execution_providers.push_back(DefaultCudaExecutionProvider());
|
||||
|
||||
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -766,122 +783,138 @@ TEST(DecoderMaskedSelfAttentionTest, Test_fp16) {
|
|||
return;
|
||||
}
|
||||
|
||||
// Vary batch size
|
||||
for (int batch_size = 1; batch_size <= 5; batch_size += 2) {
|
||||
// Vary kv_lengths
|
||||
for (int past_sequence_length = 1; past_sequence_length <= 3000; past_sequence_length += 150) {
|
||||
int sequence_length = 1;
|
||||
int number_of_heads = 12;
|
||||
// Buckets for test data:
|
||||
// batch_size: 1, >=2
|
||||
// past_sequence_length 0, 1~30, 31~2046, >=2047 (so that total_sequence_length: 1, 2-31, 32~2047, >=2048)
|
||||
// head_size: 32, 64, 128
|
||||
struct MyTestCase {
|
||||
int batch_size;
|
||||
int past_sequence_length;
|
||||
int hidden_size;
|
||||
} test_cases[] = {
|
||||
{1, 0, 768},
|
||||
{1, 1, 768},
|
||||
{3, 30, 384},
|
||||
{8, 31, 1536},
|
||||
{4, 256, 384},
|
||||
{3, 1024, 768},
|
||||
{2, 2046, 1536},
|
||||
{1, 2047, 384},
|
||||
{2, 3000, 768},
|
||||
};
|
||||
|
||||
// Vary head_size / hidden_size
|
||||
int hidden_sizes[3] = {384, 768, 1536};
|
||||
for (int hidden_size : hidden_sizes) {
|
||||
int head_size = (hidden_size / number_of_heads);
|
||||
int total_sequence_length = sequence_length + past_sequence_length;
|
||||
int max_sequence_length = past_sequence_length + 1; // Always keep > past_sequence_length
|
||||
constexpr int sequence_length = 1;
|
||||
constexpr int number_of_heads = 12;
|
||||
|
||||
OpTester tester("DecoderMaskedSelfAttention", 1, onnxruntime::kMSDomain);
|
||||
tester.AddAttribute<int64_t>("num_heads", static_cast<int64_t>(number_of_heads));
|
||||
tester.AddAttribute<int64_t>("past_present_share_buffer", static_cast<int64_t>(1));
|
||||
for (MyTestCase test_case : test_cases) {
|
||||
int batch_size = test_case.batch_size;
|
||||
int past_sequence_length = test_case.past_sequence_length;
|
||||
int hidden_size = test_case.hidden_size;
|
||||
|
||||
std::vector<int64_t> input_dims = {batch_size, sequence_length, hidden_size};
|
||||
std::vector<int64_t> weights_dims = {hidden_size, 3 * hidden_size};
|
||||
std::vector<int64_t> bias_dims = {3 * hidden_size};
|
||||
std::vector<int64_t> output_dims = {batch_size, sequence_length, hidden_size};
|
||||
int head_size = (hidden_size / number_of_heads);
|
||||
int total_sequence_length = sequence_length + past_sequence_length;
|
||||
int max_sequence_length = past_sequence_length + 1; // Always keep > past_sequence_length
|
||||
|
||||
auto input = CreateRandom<MLFloat16>(batch_size * sequence_length * hidden_size);
|
||||
tester.AddInput<MLFloat16>("input", input_dims, input);
|
||||
OpTester tester("DecoderMaskedSelfAttention", 1, onnxruntime::kMSDomain);
|
||||
tester.AddAttribute<int64_t>("num_heads", static_cast<int64_t>(number_of_heads));
|
||||
tester.AddAttribute<int64_t>("past_present_share_buffer", static_cast<int64_t>(1));
|
||||
|
||||
auto weight = CreateRandom<MLFloat16>(hidden_size * 3 * hidden_size);
|
||||
tester.AddInput<MLFloat16>("weight", weights_dims, weight);
|
||||
std::vector<int64_t> input_dims = {batch_size, sequence_length, hidden_size};
|
||||
std::vector<int64_t> weights_dims = {hidden_size, 3 * hidden_size};
|
||||
std::vector<int64_t> bias_dims = {3 * hidden_size};
|
||||
std::vector<int64_t> output_dims = {batch_size, sequence_length, hidden_size};
|
||||
|
||||
auto bias = CreateRandom<MLFloat16>(3 * hidden_size);
|
||||
tester.AddInput<MLFloat16>("bias", bias_dims, bias);
|
||||
auto input = CreateRandom<MLFloat16>(batch_size * sequence_length * hidden_size);
|
||||
tester.AddInput<MLFloat16>("input", input_dims, input);
|
||||
|
||||
// Mask
|
||||
tester.AddOptionalInputEdge<int32_t>();
|
||||
auto weight = CreateRandom<MLFloat16>(hidden_size * 3 * hidden_size);
|
||||
tester.AddInput<MLFloat16>("weight", weights_dims, weight);
|
||||
|
||||
// Past
|
||||
std::vector<int64_t> past_dims = {2, batch_size, number_of_heads, max_sequence_length, head_size};
|
||||
int past_present_size = 2 * batch_size * number_of_heads * max_sequence_length * head_size;
|
||||
auto bias = CreateRandom<MLFloat16>(3 * hidden_size);
|
||||
tester.AddInput<MLFloat16>("bias", bias_dims, bias);
|
||||
|
||||
auto kv_cache = CreateRandom<MLFloat16>(past_present_size);
|
||||
// Mask
|
||||
tester.AddOptionalInputEdge<int32_t>();
|
||||
|
||||
auto reordered_kv_cache = ReorderKVCache<MLFloat16>(kv_cache, batch_size,
|
||||
number_of_heads, past_sequence_length, head_size, max_sequence_length);
|
||||
// Past
|
||||
std::vector<int64_t> past_dims = {2, batch_size, number_of_heads, max_sequence_length, head_size};
|
||||
int past_present_size = 2 * batch_size * number_of_heads * max_sequence_length * head_size;
|
||||
|
||||
// Validate if reordering went well - by transposing and checking equality
|
||||
int chunk_size = 16 / sizeof(MLFloat16);
|
||||
int num_chunks = head_size / chunk_size;
|
||||
auto transposed = Transpose<MLFloat16>(kv_cache.data(), batch_size, number_of_heads, num_chunks, max_sequence_length, chunk_size);
|
||||
CheckEquality<MLFloat16>(transposed.data(), reordered_kv_cache.data(), batch_size, number_of_heads, num_chunks,
|
||||
max_sequence_length, past_sequence_length, chunk_size);
|
||||
auto kv_cache = CreateRandom<MLFloat16>(past_present_size);
|
||||
|
||||
tester.AddInput<MLFloat16>("past", past_dims, reordered_kv_cache);
|
||||
auto reordered_kv_cache = ReorderKVCache<MLFloat16>(kv_cache, batch_size,
|
||||
number_of_heads, past_sequence_length, head_size, max_sequence_length);
|
||||
|
||||
// Rel
|
||||
tester.AddOptionalInputEdge<MLFloat16>();
|
||||
// Validate if reordering went well - by transposing and checking equality
|
||||
int chunk_size = 16 / sizeof(MLFloat16);
|
||||
int num_chunks = head_size / chunk_size;
|
||||
auto transposed = Transpose<MLFloat16>(kv_cache.data(), batch_size, number_of_heads, num_chunks, max_sequence_length, chunk_size);
|
||||
CheckEquality<MLFloat16>(transposed.data(), reordered_kv_cache.data(), batch_size, number_of_heads, num_chunks,
|
||||
max_sequence_length, past_sequence_length, chunk_size);
|
||||
|
||||
// Past sequence length
|
||||
std::vector<int32_t> arr_past_sequence_len(1, past_sequence_length);
|
||||
tester.AddInput<int32_t>("past_sequence_length", {1}, arr_past_sequence_len);
|
||||
tester.AddInput<MLFloat16>("past", past_dims, reordered_kv_cache);
|
||||
|
||||
// QKV MatMul
|
||||
auto qkv = QKV(input, weight, bias, batch_size, sequence_length, hidden_size);
|
||||
auto* qkv_matrix = qkv.data();
|
||||
// Rel
|
||||
tester.AddOptionalInputEdge<MLFloat16>();
|
||||
|
||||
auto pair = MergePastKWithPresentKAndTranspose<MLFloat16>(kv_cache.data(), qkv_matrix + hidden_size, batch_size,
|
||||
number_of_heads, past_sequence_length,
|
||||
max_sequence_length, head_size);
|
||||
// Past sequence length
|
||||
std::vector<int32_t> arr_past_sequence_len(1, past_sequence_length);
|
||||
tester.AddInput<int32_t>("past_sequence_length", {1}, arr_past_sequence_len);
|
||||
|
||||
auto k_merged = pair.first;
|
||||
auto k_transpose = pair.second;
|
||||
// QKV MatMul
|
||||
auto qkv = QKV(input, weight, bias, batch_size, sequence_length, hidden_size);
|
||||
auto* qkv_matrix = qkv.data();
|
||||
|
||||
auto qk_transpose = QK_Transpose<MLFloat16>(qkv_matrix, k_transpose.data(), batch_size, number_of_heads,
|
||||
total_sequence_length, head_size);
|
||||
auto pair = MergePastKWithPresentKAndTranspose<MLFloat16>(kv_cache.data(), qkv_matrix + hidden_size, batch_size,
|
||||
number_of_heads, past_sequence_length,
|
||||
max_sequence_length, head_size);
|
||||
|
||||
auto softmax_qk_transpose = Softmax_QK_Transpose<MLFloat16>(qk_transpose.data(), batch_size, number_of_heads,
|
||||
sequence_length, total_sequence_length, head_size);
|
||||
auto k_merged = pair.first;
|
||||
auto k_transpose = pair.second;
|
||||
|
||||
auto present = MergeReorderedKVCacheWithK<MLFloat16>(reordered_kv_cache, qkv_matrix + hidden_size, batch_size,
|
||||
number_of_heads, past_sequence_length, max_sequence_length, head_size);
|
||||
auto qk_transpose = QK_Transpose<MLFloat16>(qkv_matrix, k_transpose.data(), batch_size, number_of_heads,
|
||||
total_sequence_length, head_size);
|
||||
|
||||
// Validate our test logic
|
||||
// We want to validate if our merged "unordered" K is the same as
|
||||
// the merged "ordered" K so that the QKT we do in our test code
|
||||
// is equivalent to the QKT we do in the kernel
|
||||
ValidateReorderedMergedKWithK<MLFloat16>(k_merged.data(), present.data(), batch_size, number_of_heads, total_sequence_length, max_sequence_length, head_size);
|
||||
auto softmax_qk_transpose = Softmax_QK_Transpose<MLFloat16>(qk_transpose.data(), batch_size, number_of_heads,
|
||||
sequence_length, total_sequence_length, head_size);
|
||||
|
||||
MergeReorderedKVCacheWithV<MLFloat16>(present.data() + (past_present_size / 2), qkv_matrix + 2 * hidden_size, batch_size,
|
||||
number_of_heads, past_sequence_length, max_sequence_length, head_size);
|
||||
auto present = MergeReorderedKVCacheWithK<MLFloat16>(reordered_kv_cache, qkv_matrix + hidden_size, batch_size,
|
||||
number_of_heads, past_sequence_length, max_sequence_length, head_size);
|
||||
|
||||
auto output = Softmax_QK_Transpose_V(softmax_qk_transpose.data(), present.data() + (past_present_size / 2),
|
||||
batch_size, number_of_heads,
|
||||
sequence_length, total_sequence_length,
|
||||
max_sequence_length, head_size);
|
||||
// Validate our test logic
|
||||
// We want to validate if our merged "unordered" K is the same as
|
||||
// the merged "ordered" K so that the QKT we do in our test code
|
||||
// is equivalent to the QKT we do in the kernel
|
||||
ValidateReorderedMergedKWithK<MLFloat16>(k_merged.data(), present.data(), batch_size, number_of_heads, total_sequence_length, max_sequence_length, head_size);
|
||||
|
||||
// Output(s)
|
||||
tester.AddOutput<MLFloat16>("output", input_dims, output);
|
||||
MergeReorderedKVCacheWithV<MLFloat16>(present.data() + (past_present_size / 2), qkv_matrix + 2 * hidden_size, batch_size,
|
||||
number_of_heads, past_sequence_length, max_sequence_length, head_size);
|
||||
|
||||
tester.AddOutput<MLFloat16>("present", past_dims, present);
|
||||
auto output = Softmax_QK_Transpose_V(softmax_qk_transpose.data(), present.data() + (past_present_size / 2),
|
||||
batch_size, number_of_heads,
|
||||
sequence_length, total_sequence_length,
|
||||
max_sequence_length, head_size);
|
||||
|
||||
// Run - Regular kernel execution path
|
||||
{
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
|
||||
execution_providers.push_back(DefaultCudaExecutionProvider());
|
||||
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
|
||||
}
|
||||
// Output(s)
|
||||
tester.AddOutput<MLFloat16>("output", input_dims, output);
|
||||
|
||||
// Test alternate kernel path of loading more KV data "in flight"
|
||||
{
|
||||
ScopedEnvironmentVariables scoped_env_vars{
|
||||
EnvVarMap{{onnxruntime::contrib::attention::kDecoderMaskedAttentionLoadKVDataInFlight, "1"}}};
|
||||
tester.AddOutput<MLFloat16>("present", past_dims, present);
|
||||
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
|
||||
execution_providers.push_back(DefaultCudaExecutionProvider());
|
||||
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
|
||||
}
|
||||
}
|
||||
// Run - Regular kernel execution path
|
||||
{
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
|
||||
execution_providers.push_back(DefaultCudaExecutionProvider());
|
||||
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
|
||||
}
|
||||
|
||||
// Test alternate kernel path of loading more KV data "in flight"
|
||||
{
|
||||
ScopedEnvironmentVariables scoped_env_vars{
|
||||
EnvVarMap{{onnxruntime::contrib::attention::kDecoderMaskedAttentionLoadKVDataInFlight, "1"}}};
|
||||
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
|
||||
execution_providers.push_back(DefaultCudaExecutionProvider());
|
||||
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -889,4 +922,4 @@ TEST(DecoderMaskedSelfAttentionTest, Test_fp16) {
|
|||
#endif
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue