diff --git a/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc b/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc index eb1cd9e85c..904714c87b 100644 --- a/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc @@ -112,10 +112,10 @@ static bool MatchInputToConcatSubgraph( const logging::Logger& logger, const NodeIndex expected_gather_node_1_index) { std::vector expand_parent_path1{ - {0, index, "Concat", {4, 11}, kOnnxDomain}, - {0, 0, "Unsqueeze", {1, 11}, kOnnxDomain}, - {0, 0, "Gather", {1, 11}, kOnnxDomain}, - {0, 0, "Shape", {1}, kOnnxDomain}, + {0, index, "Concat", {4, 11, 13}, kOnnxDomain}, + {0, 0, "Unsqueeze", {1, 11, 13}, kOnnxDomain}, + {0, 0, "Gather", {1, 11, 13}, kOnnxDomain}, + {0, 0, "Shape", {1, 13}, kOnnxDomain}, }; std::vector edges; @@ -145,9 +145,9 @@ static bool MatchInputToConcatSubgraph( } std::vector concat_parent_path{ - {0, 1, "Unsqueeze", {1, 11}, kOnnxDomain}, - {0, 0, "Gather", {1, 11}, kOnnxDomain}, - {0, 0, "Shape", {1}, kOnnxDomain}}; + {0, 1, "Unsqueeze", {1, 11, 13}, kOnnxDomain}, + {0, 0, "Gather", {1, 11, 13}, kOnnxDomain}, + {0, 0, "Shape", {1, 13}, kOnnxDomain}}; if (!graph_utils::FindPath(concat_node, true, concat_parent_path, edges, logger)) { DEBUG_LOG("Failed to find path 2 of position shape."); @@ -316,7 +316,7 @@ static bool MatchPositionEmbeddingSubgraphsFromGather( // Match Shape --> Expand path. std::vector pg_edges_2; - if (!graph_utils::FindPath(expand_node, true, {{0, 1, "Shape", {1}, kOnnxDomain}}, pg_edges_2, logger)) { + if (!graph_utils::FindPath(expand_node, true, {{0, 1, "Shape", {1, 13}, kOnnxDomain}}, pg_edges_2, logger)) { DEBUG_LOG("Failed to match Shape node. "); return false; } diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 9a01db9ee4..18e58e7efa 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -2897,15 +2897,47 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat2) { ASSERT_TRUE(op_to_count["com.microsoft.EmbedLayerNormalization"] == 1); } -TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat3) { - auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format3.onnx"; +static void EmbedLayerNormFusionFormat3(const std::basic_string& file_path, logging::Logger* logger) { std::shared_ptr p_model; - ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + ASSERT_TRUE(Model::Load(file_path, p_model, nullptr, *logger).IsOK()); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); - auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger); + ASSERT_TRUE(ret.IsOK()); + + std::map op_to_count = CountOpsInGraph(graph); + EXPECT_EQ(op_to_count["Shape"], 0); + EXPECT_EQ(op_to_count["Expand"], 0); + EXPECT_EQ(op_to_count["Gather"], 0); + EXPECT_EQ(op_to_count["Unsqueeze"], 0); + EXPECT_EQ(op_to_count["LayerNormalization"], 0); + EXPECT_EQ(op_to_count["com.microsoft.SkipLayerNormalization"], 0); + EXPECT_EQ(op_to_count["ReduceSum"], 1); + EXPECT_EQ(op_to_count["MatMul"], 1); + EXPECT_EQ(op_to_count["Add"], 2); + EXPECT_EQ(op_to_count["Cast"], 3); + EXPECT_EQ(op_to_count["com.microsoft.Attention"], 1); + EXPECT_EQ(op_to_count["com.microsoft.EmbedLayerNormalization"], 1); +} + +TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat3) { + EmbedLayerNormFusionFormat3(MODEL_FOLDER "fusion/embed_layer_norm_format3.onnx", logger_.get()); +} + +TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat3_OpSet13) { + EmbedLayerNormFusionFormat3(MODEL_FOLDER "fusion/embed_layer_norm_format3_opset13.onnx", logger_.get()); +} + +static void EmbedLayerNormFusionFormat3NoCast(const std::basic_string& file_path, logging::Logger* logger) { + std::shared_ptr p_model; + ASSERT_TRUE(Model::Load(file_path, p_model, nullptr, *logger).IsOK()); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger); ASSERT_TRUE(ret.IsOK()); std::map op_to_count = CountOpsInGraph(graph); @@ -2924,29 +2956,11 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat3) { } TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat3NoCast) { - auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format3_no_cast.onnx"; - std::shared_ptr p_model; - ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); - Graph& graph = p_model->MainGraph(); + EmbedLayerNormFusionFormat3NoCast(MODEL_FOLDER "fusion/embed_layer_norm_format3_no_cast.onnx", logger_.get()); +} - onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); - auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_); - ASSERT_TRUE(ret.IsOK()); - - std::map op_to_count = CountOpsInGraph(graph); - EXPECT_EQ(op_to_count["Shape"], 0); - EXPECT_EQ(op_to_count["Expand"], 0); - EXPECT_EQ(op_to_count["Gather"], 0); - EXPECT_EQ(op_to_count["Unsqueeze"], 0); - EXPECT_EQ(op_to_count["LayerNormalization"], 0); - EXPECT_EQ(op_to_count["com.microsoft.SkipLayerNormalization"], 0); - EXPECT_EQ(op_to_count["ReduceSum"], 1); - EXPECT_EQ(op_to_count["MatMul"], 1); - EXPECT_EQ(op_to_count["Add"], 2); - EXPECT_EQ(op_to_count["Cast"], 3); - EXPECT_EQ(op_to_count["com.microsoft.Attention"], 1); - EXPECT_EQ(op_to_count["com.microsoft.EmbedLayerNormalization"], 1); +TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat3NoCast_OpSet13) { + EmbedLayerNormFusionFormat3NoCast(MODEL_FOLDER "fusion/embed_layer_norm_format3_no_cast_opset13.onnx", logger_.get()); } TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat4) { @@ -2977,15 +2991,14 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat4) { ASSERT_TRUE(op_to_count["com.microsoft.EmbedLayerNormalization"] == 1); } -TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat5) { - auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format5.onnx"; +static void EmbedLayerNormFusionFormat5(const std::basic_string& file_path, logging::Logger* logger) { std::shared_ptr p_model; - ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + ASSERT_TRUE(Model::Load(file_path, p_model, nullptr, *logger).IsOK()); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); - auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger); ASSERT_TRUE(ret.IsOK()); std::map op_to_count = CountOpsInGraph(graph); @@ -3019,15 +3032,22 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat5) { } } -TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat6) { - auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format6.onnx"; +TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat5) { + EmbedLayerNormFusionFormat5(MODEL_FOLDER "fusion/embed_layer_norm_format5.onnx", logger_.get()); +} + +TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat5_OpSet13) { + EmbedLayerNormFusionFormat5(MODEL_FOLDER "fusion/embed_layer_norm_format5_opset13.onnx", logger_.get()); +} + +static void EmbedLayerNormFusionFormat6(const std::basic_string& file_path, logging::Logger* logger) { std::shared_ptr p_model; - ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + ASSERT_TRUE(Model::Load(file_path, p_model, nullptr, *logger).IsOK()); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); - auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger); ASSERT_TRUE(ret.IsOK()); std::map op_to_count = CountOpsInGraph(graph); @@ -3048,6 +3068,14 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat6) { EXPECT_EQ(op_to_count["com.microsoft.EmbedLayerNormalization"], 1); } +TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat6) { + EmbedLayerNormFusionFormat6(MODEL_FOLDER "fusion/embed_layer_norm_format6.onnx", logger_.get()); +} + +TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat6_OpSet13) { + EmbedLayerNormFusionFormat6(MODEL_FOLDER "fusion/embed_layer_norm_format6_opset13.onnx", logger_.get()); +} + static void TestEmbedLayerNormFusionDistilBert(const std::basic_string& model_uri, std::map& op_to_count, logging::Logger* logger) { @@ -3076,6 +3104,18 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat7) { EXPECT_EQ(op_to_count["ReduceSum"], 1); } +TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat7_OpSet13) { + std::map op_to_count; + TestEmbedLayerNormFusionDistilBert(MODEL_FOLDER "fusion/embed_layer_norm_format7_opset13.onnx", op_to_count, logger_.get()); + EXPECT_EQ(op_to_count["com.microsoft.EmbedLayerNormalization"], 1); + EXPECT_EQ(op_to_count["com.microsoft.Attention"], 1); + EXPECT_EQ(op_to_count["Cast"], 2); + EXPECT_EQ(op_to_count["Shape"], 0); + EXPECT_EQ(op_to_count["Gather"], 0); + EXPECT_EQ(op_to_count["Unsqueeze"], 0); + EXPECT_EQ(op_to_count["ReduceSum"], 1); +} + TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat8) { std::map op_to_count; TestEmbedLayerNormFusionDistilBert(MODEL_FOLDER "fusion/embed_layer_norm_format8.onnx", op_to_count, logger_.get()); @@ -3088,6 +3128,18 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat8) { EXPECT_EQ(op_to_count["ReduceSum"], 1); } +TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat8_OpSet13) { + std::map op_to_count; + TestEmbedLayerNormFusionDistilBert(MODEL_FOLDER "fusion/embed_layer_norm_format8_opset13.onnx", op_to_count, logger_.get()); + EXPECT_EQ(op_to_count["com.microsoft.EmbedLayerNormalization"], 1); + EXPECT_EQ(op_to_count["com.microsoft.Attention"], 1); + EXPECT_EQ(op_to_count["Cast"], 2); + EXPECT_EQ(op_to_count["Shape"], 0); + EXPECT_EQ(op_to_count["Gather"], 0); + EXPECT_EQ(op_to_count["Unsqueeze"], 0); + EXPECT_EQ(op_to_count["ReduceSum"], 1); +} + TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat9) { std::map op_to_count; TestEmbedLayerNormFusionDistilBert(MODEL_FOLDER "fusion/embed_layer_norm_format9.onnx", op_to_count, logger_.get()); @@ -3100,15 +3152,26 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat9) { EXPECT_EQ(op_to_count["ReduceSum"], 1); } -TEST_F(GraphTransformationTests, EmbedLayerNormFusionMultiple) { - auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_multiple.onnx"; +TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat9_OpSet13) { + std::map op_to_count; + TestEmbedLayerNormFusionDistilBert(MODEL_FOLDER "fusion/embed_layer_norm_format9_opset13.onnx", op_to_count, logger_.get()); + EXPECT_EQ(op_to_count["com.microsoft.EmbedLayerNormalization"], 1); + EXPECT_EQ(op_to_count["com.microsoft.Attention"], 1); + EXPECT_EQ(op_to_count["Cast"], 2); + EXPECT_EQ(op_to_count["Shape"], 1); + EXPECT_EQ(op_to_count["Gather"], 2); + EXPECT_EQ(op_to_count["Unsqueeze"], 2); + EXPECT_EQ(op_to_count["ReduceSum"], 1); +} + +static void EmbedLayerNormFusionFormatMultiple(const std::basic_string& file_path, logging::Logger* logger) { std::shared_ptr p_model; - ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + ASSERT_TRUE(Model::Load(file_path, p_model, nullptr, *logger).IsOK()); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); - auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger); ASSERT_TRUE(ret.IsOK()); std::map op_to_count = CountOpsInGraph(graph); @@ -3126,6 +3189,14 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionMultiple) { EXPECT_EQ(op_to_count["com.microsoft.EmbedLayerNormalization"], 2); } +TEST_F(GraphTransformationTests, EmbedLayerNormFusionMultiple) { + EmbedLayerNormFusionFormatMultiple(MODEL_FOLDER "fusion/embed_layer_norm_multiple.onnx", logger_.get()); +} + +TEST_F(GraphTransformationTests, EmbedLayerNormFusionMultiple_OpSet13) { + EmbedLayerNormFusionFormatMultiple(MODEL_FOLDER "fusion/embed_layer_norm_multiple_opset13.onnx", logger_.get()); +} + TEST_F(GraphTransformationTests, DynamicQuantizeMatMulTest) { auto model_uri = MODEL_FOLDER "fusion/dynamic_quantize_matmul.onnx"; std::shared_ptr p_model; diff --git a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format3_no_cast_opset13.onnx b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format3_no_cast_opset13.onnx new file mode 100644 index 0000000000..bb1528acd2 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format3_no_cast_opset13.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format3_opset13.onnx b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format3_opset13.onnx new file mode 100644 index 0000000000..949fb8aeb2 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format3_opset13.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format5_opset13.onnx b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format5_opset13.onnx new file mode 100644 index 0000000000..980bbfe352 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format5_opset13.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format6_opset13.onnx b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format6_opset13.onnx new file mode 100644 index 0000000000..3ec2cc4714 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format6_opset13.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format7_opset13.onnx b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format7_opset13.onnx new file mode 100644 index 0000000000..7c0e64e1ec Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format7_opset13.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format8_opset13.onnx b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format8_opset13.onnx new file mode 100644 index 0000000000..39d3b686f1 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format8_opset13.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format9_opset13.onnx b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format9_opset13.onnx new file mode 100644 index 0000000000..6501640e08 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format9_opset13.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_multiple_opset13.onnx b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_multiple_opset13.onnx new file mode 100644 index 0000000000..44ca4a4c50 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_multiple_opset13.onnx differ