diff --git a/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc index 6afb61bd1f..8ea37ad054 100644 --- a/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc @@ -640,122 +640,139 @@ TEST(DecoderMaskedSelfAttentionTest, Test_fp32) { return; } - // Vary batch size - for (int batch_size = 1; batch_size <= 5; batch_size += 2) { - // Vary kv_lengths - for (int past_sequence_length = 1; past_sequence_length <= 3000; past_sequence_length += 150) { - int sequence_length = 1; - int number_of_heads = 12; - // Vary head_size / hidden_size - int hidden_sizes[3] = {384, 768, 1536}; - for (int hidden_size : hidden_sizes) { - int head_size = (hidden_size / number_of_heads); - int total_sequence_length = sequence_length + past_sequence_length; - int max_sequence_length = past_sequence_length + 1; // Always keep > past_sequence_length + // Buckets for test data: + // batch_size: 1, >=2 + // past_sequence_length 0~30, 31~2046, >=2047 (so that total_sequence_length: 1~31, 32~2047, >=2048) + // head_size: 32, 64, 128 + struct MyTestCase { + int batch_size; + int past_sequence_length; + int hidden_size; + } test_cases[] = { + {1, 0, 768}, + {1, 1, 384}, + {2, 30, 768}, + {3, 31, 1536}, + {4, 512, 384}, + {1, 1024, 768}, + {1, 2046, 1536}, + {2, 2047, 384}, + {3, 3000, 768}, + }; - OpTester tester("DecoderMaskedSelfAttention", 1, onnxruntime::kMSDomain); - tester.AddAttribute("num_heads", static_cast(number_of_heads)); - tester.AddAttribute("past_present_share_buffer", static_cast(1)); + constexpr int sequence_length = 1; + constexpr int number_of_heads = 12; - std::vector input_dims = {batch_size, sequence_length, hidden_size}; - std::vector weights_dims = {hidden_size, 3 * hidden_size}; - std::vector bias_dims = {3 * hidden_size}; - std::vector output_dims = {batch_size, sequence_length, hidden_size}; + for (MyTestCase test_case : test_cases) { + int batch_size = test_case.batch_size; + int past_sequence_length = test_case.past_sequence_length; + int hidden_size = test_case.hidden_size; - auto input = CreateRandom(batch_size * sequence_length * hidden_size); - tester.AddInput("input", input_dims, input); + int head_size = (hidden_size / number_of_heads); + int total_sequence_length = sequence_length + past_sequence_length; + int max_sequence_length = past_sequence_length + 1; // Always keep > past_sequence_length - auto weight = CreateRandom(hidden_size * 3 * hidden_size); - tester.AddInput("weight", weights_dims, weight); + OpTester tester("DecoderMaskedSelfAttention", 1, onnxruntime::kMSDomain); + tester.AddAttribute("num_heads", static_cast(number_of_heads)); + tester.AddAttribute("past_present_share_buffer", static_cast(1)); - auto bias = CreateRandom(3 * hidden_size); - tester.AddInput("bias", bias_dims, bias); + std::vector input_dims = {batch_size, sequence_length, hidden_size}; + std::vector weights_dims = {hidden_size, 3 * hidden_size}; + std::vector bias_dims = {3 * hidden_size}; + std::vector output_dims = {batch_size, sequence_length, hidden_size}; - // Mask - tester.AddOptionalInputEdge(); + auto input = CreateRandom(batch_size * sequence_length * hidden_size); + tester.AddInput("input", input_dims, input); - // Past - std::vector past_dims = {2, batch_size, number_of_heads, max_sequence_length, head_size}; - int past_present_size = 2 * batch_size * number_of_heads * max_sequence_length * head_size; + auto weight = CreateRandom(hidden_size * 3 * hidden_size); + tester.AddInput("weight", weights_dims, weight); - auto kv_cache = CreateRandom(past_present_size); + auto bias = CreateRandom(3 * hidden_size); + tester.AddInput("bias", bias_dims, bias); - auto reordered_kv_cache = ReorderKVCache(kv_cache, batch_size, - number_of_heads, past_sequence_length, head_size, max_sequence_length); + // Mask + tester.AddOptionalInputEdge(); - // Validate if reordering went well - by transposing and checking equality - int chunk_size = 16 / sizeof(float); - int num_chunks = head_size / chunk_size; - auto transposed = Transpose(kv_cache.data(), batch_size, number_of_heads, num_chunks, max_sequence_length, chunk_size); - CheckEquality(transposed.data(), reordered_kv_cache.data(), batch_size, number_of_heads, num_chunks, - max_sequence_length, past_sequence_length, chunk_size); + // Past + std::vector past_dims = {2, batch_size, number_of_heads, max_sequence_length, head_size}; + int past_present_size = 2 * batch_size * number_of_heads * max_sequence_length * head_size; - tester.AddInput("past", past_dims, reordered_kv_cache); + auto kv_cache = CreateRandom(past_present_size); - // Rel - tester.AddOptionalInputEdge(); + auto reordered_kv_cache = ReorderKVCache(kv_cache, batch_size, + number_of_heads, past_sequence_length, head_size, max_sequence_length); - // Past sequence length - std::vector arr_past_sequence_len(1, past_sequence_length); - tester.AddInput("past_sequence_length", {1}, arr_past_sequence_len); + // Validate if reordering went well - by transposing and checking equality + int chunk_size = 16 / sizeof(float); + int num_chunks = head_size / chunk_size; + auto transposed = Transpose(kv_cache.data(), batch_size, number_of_heads, num_chunks, max_sequence_length, chunk_size); + CheckEquality(transposed.data(), reordered_kv_cache.data(), batch_size, number_of_heads, num_chunks, + max_sequence_length, past_sequence_length, chunk_size); - // QKV MatMul - auto qkv = QKV(input, weight, bias, batch_size, sequence_length, hidden_size); - auto* qkv_matrix = qkv.data(); + tester.AddInput("past", past_dims, reordered_kv_cache); - auto pair = MergePastKWithPresentKAndTranspose(kv_cache.data(), qkv_matrix + hidden_size, batch_size, - number_of_heads, past_sequence_length, - max_sequence_length, head_size); + // Rel + tester.AddOptionalInputEdge(); - auto k_merged = pair.first; - auto k_transpose = pair.second; + // Past sequence length + std::vector arr_past_sequence_len(1, past_sequence_length); + tester.AddInput("past_sequence_length", {1}, arr_past_sequence_len); - auto qk_transpose = QK_Transpose(qkv_matrix, k_transpose.data(), batch_size, number_of_heads, - total_sequence_length, head_size); + // QKV MatMul + auto qkv = QKV(input, weight, bias, batch_size, sequence_length, hidden_size); + auto* qkv_matrix = qkv.data(); - auto softmax_qk_transpose = Softmax_QK_Transpose(qk_transpose.data(), batch_size, number_of_heads, - sequence_length, total_sequence_length, head_size); + auto pair = MergePastKWithPresentKAndTranspose(kv_cache.data(), qkv_matrix + hidden_size, batch_size, + number_of_heads, past_sequence_length, + max_sequence_length, head_size); - auto present = MergeReorderedKVCacheWithK(reordered_kv_cache, qkv_matrix + hidden_size, batch_size, - number_of_heads, past_sequence_length, max_sequence_length, head_size); + auto k_merged = pair.first; + auto k_transpose = pair.second; - // Validate our test logic - // We want to validate if our merged "unordered" K is the same as - // the merged "ordered" K so that the QKT we do in our test code - // is equivalent to the QKT we do in the kernel - ValidateReorderedMergedKWithK(k_merged.data(), present.data(), batch_size, number_of_heads, total_sequence_length, max_sequence_length, head_size); + auto qk_transpose = QK_Transpose(qkv_matrix, k_transpose.data(), batch_size, number_of_heads, + total_sequence_length, head_size); - MergeReorderedKVCacheWithV(present.data() + (past_present_size / 2), qkv_matrix + 2 * hidden_size, batch_size, - number_of_heads, past_sequence_length, max_sequence_length, head_size); + auto softmax_qk_transpose = Softmax_QK_Transpose(qk_transpose.data(), batch_size, number_of_heads, + sequence_length, total_sequence_length, head_size); - auto output = Softmax_QK_Transpose_V(softmax_qk_transpose.data(), present.data() + (past_present_size / 2), - batch_size, number_of_heads, - sequence_length, total_sequence_length, - max_sequence_length, head_size); + auto present = MergeReorderedKVCacheWithK(reordered_kv_cache, qkv_matrix + hidden_size, batch_size, + number_of_heads, past_sequence_length, max_sequence_length, head_size); - // Output(s) - tester.AddOutput("output", input_dims, output); + // Validate our test logic + // We want to validate if our merged "unordered" K is the same as + // the merged "ordered" K so that the QKT we do in our test code + // is equivalent to the QKT we do in the kernel + ValidateReorderedMergedKWithK(k_merged.data(), present.data(), batch_size, number_of_heads, total_sequence_length, max_sequence_length, head_size); - tester.AddOutput("present", past_dims, present); + MergeReorderedKVCacheWithV(present.data() + (past_present_size / 2), qkv_matrix + 2 * hidden_size, batch_size, + number_of_heads, past_sequence_length, max_sequence_length, head_size); - // Run - Regular kernel execution path - { - std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } + auto output = Softmax_QK_Transpose_V(softmax_qk_transpose.data(), present.data() + (past_present_size / 2), + batch_size, number_of_heads, + sequence_length, total_sequence_length, + max_sequence_length, head_size); - // Test alternate kernel path of loading more KV data "in flight" - { - ScopedEnvironmentVariables scoped_env_vars{ - EnvVarMap{{onnxruntime::contrib::attention::kDecoderMaskedAttentionLoadKVDataInFlight, "1"}}}; + // Output(s) + tester.AddOutput("output", input_dims, output); - std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); + tester.AddOutput("present", past_dims, present); - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } - } + // Run - Regular kernel execution path + { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + + // Test alternate kernel path of loading more KV data "in flight" + { + ScopedEnvironmentVariables scoped_env_vars{ + EnvVarMap{{onnxruntime::contrib::attention::kDecoderMaskedAttentionLoadKVDataInFlight, "1"}}}; + + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } } } @@ -766,122 +783,138 @@ TEST(DecoderMaskedSelfAttentionTest, Test_fp16) { return; } - // Vary batch size - for (int batch_size = 1; batch_size <= 5; batch_size += 2) { - // Vary kv_lengths - for (int past_sequence_length = 1; past_sequence_length <= 3000; past_sequence_length += 150) { - int sequence_length = 1; - int number_of_heads = 12; + // Buckets for test data: + // batch_size: 1, >=2 + // past_sequence_length 0, 1~30, 31~2046, >=2047 (so that total_sequence_length: 1, 2-31, 32~2047, >=2048) + // head_size: 32, 64, 128 + struct MyTestCase { + int batch_size; + int past_sequence_length; + int hidden_size; + } test_cases[] = { + {1, 0, 768}, + {1, 1, 768}, + {3, 30, 384}, + {8, 31, 1536}, + {4, 256, 384}, + {3, 1024, 768}, + {2, 2046, 1536}, + {1, 2047, 384}, + {2, 3000, 768}, + }; - // Vary head_size / hidden_size - int hidden_sizes[3] = {384, 768, 1536}; - for (int hidden_size : hidden_sizes) { - int head_size = (hidden_size / number_of_heads); - int total_sequence_length = sequence_length + past_sequence_length; - int max_sequence_length = past_sequence_length + 1; // Always keep > past_sequence_length + constexpr int sequence_length = 1; + constexpr int number_of_heads = 12; - OpTester tester("DecoderMaskedSelfAttention", 1, onnxruntime::kMSDomain); - tester.AddAttribute("num_heads", static_cast(number_of_heads)); - tester.AddAttribute("past_present_share_buffer", static_cast(1)); + for (MyTestCase test_case : test_cases) { + int batch_size = test_case.batch_size; + int past_sequence_length = test_case.past_sequence_length; + int hidden_size = test_case.hidden_size; - std::vector input_dims = {batch_size, sequence_length, hidden_size}; - std::vector weights_dims = {hidden_size, 3 * hidden_size}; - std::vector bias_dims = {3 * hidden_size}; - std::vector output_dims = {batch_size, sequence_length, hidden_size}; + int head_size = (hidden_size / number_of_heads); + int total_sequence_length = sequence_length + past_sequence_length; + int max_sequence_length = past_sequence_length + 1; // Always keep > past_sequence_length - auto input = CreateRandom(batch_size * sequence_length * hidden_size); - tester.AddInput("input", input_dims, input); + OpTester tester("DecoderMaskedSelfAttention", 1, onnxruntime::kMSDomain); + tester.AddAttribute("num_heads", static_cast(number_of_heads)); + tester.AddAttribute("past_present_share_buffer", static_cast(1)); - auto weight = CreateRandom(hidden_size * 3 * hidden_size); - tester.AddInput("weight", weights_dims, weight); + std::vector input_dims = {batch_size, sequence_length, hidden_size}; + std::vector weights_dims = {hidden_size, 3 * hidden_size}; + std::vector bias_dims = {3 * hidden_size}; + std::vector output_dims = {batch_size, sequence_length, hidden_size}; - auto bias = CreateRandom(3 * hidden_size); - tester.AddInput("bias", bias_dims, bias); + auto input = CreateRandom(batch_size * sequence_length * hidden_size); + tester.AddInput("input", input_dims, input); - // Mask - tester.AddOptionalInputEdge(); + auto weight = CreateRandom(hidden_size * 3 * hidden_size); + tester.AddInput("weight", weights_dims, weight); - // Past - std::vector past_dims = {2, batch_size, number_of_heads, max_sequence_length, head_size}; - int past_present_size = 2 * batch_size * number_of_heads * max_sequence_length * head_size; + auto bias = CreateRandom(3 * hidden_size); + tester.AddInput("bias", bias_dims, bias); - auto kv_cache = CreateRandom(past_present_size); + // Mask + tester.AddOptionalInputEdge(); - auto reordered_kv_cache = ReorderKVCache(kv_cache, batch_size, - number_of_heads, past_sequence_length, head_size, max_sequence_length); + // Past + std::vector past_dims = {2, batch_size, number_of_heads, max_sequence_length, head_size}; + int past_present_size = 2 * batch_size * number_of_heads * max_sequence_length * head_size; - // Validate if reordering went well - by transposing and checking equality - int chunk_size = 16 / sizeof(MLFloat16); - int num_chunks = head_size / chunk_size; - auto transposed = Transpose(kv_cache.data(), batch_size, number_of_heads, num_chunks, max_sequence_length, chunk_size); - CheckEquality(transposed.data(), reordered_kv_cache.data(), batch_size, number_of_heads, num_chunks, - max_sequence_length, past_sequence_length, chunk_size); + auto kv_cache = CreateRandom(past_present_size); - tester.AddInput("past", past_dims, reordered_kv_cache); + auto reordered_kv_cache = ReorderKVCache(kv_cache, batch_size, + number_of_heads, past_sequence_length, head_size, max_sequence_length); - // Rel - tester.AddOptionalInputEdge(); + // Validate if reordering went well - by transposing and checking equality + int chunk_size = 16 / sizeof(MLFloat16); + int num_chunks = head_size / chunk_size; + auto transposed = Transpose(kv_cache.data(), batch_size, number_of_heads, num_chunks, max_sequence_length, chunk_size); + CheckEquality(transposed.data(), reordered_kv_cache.data(), batch_size, number_of_heads, num_chunks, + max_sequence_length, past_sequence_length, chunk_size); - // Past sequence length - std::vector arr_past_sequence_len(1, past_sequence_length); - tester.AddInput("past_sequence_length", {1}, arr_past_sequence_len); + tester.AddInput("past", past_dims, reordered_kv_cache); - // QKV MatMul - auto qkv = QKV(input, weight, bias, batch_size, sequence_length, hidden_size); - auto* qkv_matrix = qkv.data(); + // Rel + tester.AddOptionalInputEdge(); - auto pair = MergePastKWithPresentKAndTranspose(kv_cache.data(), qkv_matrix + hidden_size, batch_size, - number_of_heads, past_sequence_length, - max_sequence_length, head_size); + // Past sequence length + std::vector arr_past_sequence_len(1, past_sequence_length); + tester.AddInput("past_sequence_length", {1}, arr_past_sequence_len); - auto k_merged = pair.first; - auto k_transpose = pair.second; + // QKV MatMul + auto qkv = QKV(input, weight, bias, batch_size, sequence_length, hidden_size); + auto* qkv_matrix = qkv.data(); - auto qk_transpose = QK_Transpose(qkv_matrix, k_transpose.data(), batch_size, number_of_heads, - total_sequence_length, head_size); + auto pair = MergePastKWithPresentKAndTranspose(kv_cache.data(), qkv_matrix + hidden_size, batch_size, + number_of_heads, past_sequence_length, + max_sequence_length, head_size); - auto softmax_qk_transpose = Softmax_QK_Transpose(qk_transpose.data(), batch_size, number_of_heads, - sequence_length, total_sequence_length, head_size); + auto k_merged = pair.first; + auto k_transpose = pair.second; - auto present = MergeReorderedKVCacheWithK(reordered_kv_cache, qkv_matrix + hidden_size, batch_size, - number_of_heads, past_sequence_length, max_sequence_length, head_size); + auto qk_transpose = QK_Transpose(qkv_matrix, k_transpose.data(), batch_size, number_of_heads, + total_sequence_length, head_size); - // Validate our test logic - // We want to validate if our merged "unordered" K is the same as - // the merged "ordered" K so that the QKT we do in our test code - // is equivalent to the QKT we do in the kernel - ValidateReorderedMergedKWithK(k_merged.data(), present.data(), batch_size, number_of_heads, total_sequence_length, max_sequence_length, head_size); + auto softmax_qk_transpose = Softmax_QK_Transpose(qk_transpose.data(), batch_size, number_of_heads, + sequence_length, total_sequence_length, head_size); - MergeReorderedKVCacheWithV(present.data() + (past_present_size / 2), qkv_matrix + 2 * hidden_size, batch_size, - number_of_heads, past_sequence_length, max_sequence_length, head_size); + auto present = MergeReorderedKVCacheWithK(reordered_kv_cache, qkv_matrix + hidden_size, batch_size, + number_of_heads, past_sequence_length, max_sequence_length, head_size); - auto output = Softmax_QK_Transpose_V(softmax_qk_transpose.data(), present.data() + (past_present_size / 2), - batch_size, number_of_heads, - sequence_length, total_sequence_length, - max_sequence_length, head_size); + // Validate our test logic + // We want to validate if our merged "unordered" K is the same as + // the merged "ordered" K so that the QKT we do in our test code + // is equivalent to the QKT we do in the kernel + ValidateReorderedMergedKWithK(k_merged.data(), present.data(), batch_size, number_of_heads, total_sequence_length, max_sequence_length, head_size); - // Output(s) - tester.AddOutput("output", input_dims, output); + MergeReorderedKVCacheWithV(present.data() + (past_present_size / 2), qkv_matrix + 2 * hidden_size, batch_size, + number_of_heads, past_sequence_length, max_sequence_length, head_size); - tester.AddOutput("present", past_dims, present); + auto output = Softmax_QK_Transpose_V(softmax_qk_transpose.data(), present.data() + (past_present_size / 2), + batch_size, number_of_heads, + sequence_length, total_sequence_length, + max_sequence_length, head_size); - // Run - Regular kernel execution path - { - std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } + // Output(s) + tester.AddOutput("output", input_dims, output); - // Test alternate kernel path of loading more KV data "in flight" - { - ScopedEnvironmentVariables scoped_env_vars{ - EnvVarMap{{onnxruntime::contrib::attention::kDecoderMaskedAttentionLoadKVDataInFlight, "1"}}}; + tester.AddOutput("present", past_dims, present); - std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } - } + // Run - Regular kernel execution path + { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + + // Test alternate kernel path of loading more KV data "in flight" + { + ScopedEnvironmentVariables scoped_env_vars{ + EnvVarMap{{onnxruntime::contrib::attention::kDecoderMaskedAttentionLoadKVDataInFlight, "1"}}}; + + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } } } @@ -889,4 +922,4 @@ TEST(DecoderMaskedSelfAttentionTest, Test_fp16) { #endif } // namespace test -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime