diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.cc b/onnxruntime/core/optimizer/layer_norm_fusion.cc index 9bbb3c4d57..fe1efc1320 100644 --- a/onnxruntime/core/optimizer/layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/layer_norm_fusion.cc @@ -6,6 +6,7 @@ #include "core/optimizer/utils.h" #include "core/optimizer/transpose_optimizer/optimizer_api.h" #include "float.h" +#include #include using namespace ONNX_NAMESPACE; @@ -32,6 +33,52 @@ static bool IsSupportedDataType(const Node& node, int first_n_inputs = -1) { return true; } +static bool CheckAxesOnReduceMean(std::vector& axes_values, int64_t rank) { + // axes have to be consecutive and contain the last dim. + std::sort(axes_values.begin(), axes_values.end()); + if (axes_values.back() > 0) { + // if reduce_mean node has input shape [N, C1, C2, C3] and axes_values = [1, 2], it's invalid. + // handle axes_values with both positive and negative values. + if (rank == -1) { + return false; + } + std::transform(axes_values.begin(), axes_values.end(), axes_values.begin(), + [rank](int64_t v) { return v >= 0 ? v - rank : v; }); + std::sort(axes_values.begin(), axes_values.end()); + } + // check if axes are consecutive + for (size_t i = 1; i < axes_values.size(); i++) { + if (axes_values[i] != axes_values[i - 1] + 1) { + axes_values.clear(); + break; + } + } + + if (axes_values.empty() || axes_values.back() != -1) { + // axes_values should contain the last dim. + return false; + } + return true; +} + +static std::vector GetAxesFromReduceMeanNode(Node& reduce_mean_node, const Graph& graph) { + const onnxruntime::NodeAttributes& attributes = reduce_mean_node.GetAttributes(); + std::vector axes_values; + // TODO: modify this code when opset >= 18 (axes is an input). 
+ if (attributes.find("axes") != attributes.end()) { + axes_values = RetrieveValues(attributes.at("axes")); + } else if (reduce_mean_node.InputDefs().size() == 2) { + const auto* axes = reduce_mean_node.InputDefs()[1]; + const auto* axes_const = graph.GetConstantInitializer(axes->Name(), true); + if (axes_const != nullptr) { + Initializer initializer{*axes_const, graph.ModelPath()}; + auto span_axes = initializer.DataAsSpan(); + axes_values.insert(axes_values.end(), span_axes.begin(), span_axes.end()); + } + } + return axes_values; +}; + /** Layer Normalization will fuse LayerNormalization into one node : +---------------------+ @@ -337,20 +384,31 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, nodes_to_remove.push_back(last_add_node); // get axes attributes - const onnxruntime::NodeAttributes& attributes = reduce_mean_node.GetAttributes(); - std::vector axes_values; - // TODO: modify this codes when opset >= 18 (axes is an input). - if (attributes.find("axes") != attributes.end()) { - axes_values = RetrieveValues(attributes.at("axes")); - } else if (reduce_mean_node.InputDefs().size() == 2) { - auto axes = reduce_mean_node.InputDefs()[1]; - auto axes_const = graph.GetConstantInitializer(axes->Name(), true); - if (axes_const != nullptr) { - Initializer initializer{*axes_const, graph.ModelPath()}; - axes_values.insert(axes_values.end(), initializer.DataAsSpan().begin(), initializer.DataAsSpan().end()); - } + + auto axes_values = GetAxesFromReduceMeanNode(reduce_mean_node, graph); + auto axes2_values = GetAxesFromReduceMeanNode(reduce_mean2_node, graph); + + // empty axes means reduce over all axes, which is not supported on layer-norm + if (axes_values.empty() || axes2_values.empty()) { + continue; } + auto input_shape = reduce_mean_node.MutableInputDefs()[0]->Shape(); + auto rank = input_shape ? 
input_shape->dim().size() : -1; + if (!CheckAxesOnReduceMean(axes_values, rank) || + !CheckAxesOnReduceMean(axes2_values, rank) || + axes_values != axes2_values) { + continue; + } + +#ifdef ENABLE_TRAINING_CORE +#else + // scale as 1D + if (axes_values.size() != 1) { + continue; + } +#endif + // Get the inputs for the new LayerNormalization node. // scale and bias could be multi-dims; we only support it for training at the moment // because SkipLayerNorm kernel, for example, has dependency on single dim size @@ -359,34 +417,18 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, for (size_t i = 0; i < mul_node.MutableInputDefs().size(); i++) { if (graph_utils::NodeArgIsConstant(graph, *(mul_node.MutableInputDefs()[i])) || graph_utils::IsGraphInput(graph, mul_node.MutableInputDefs()[i])) { -#ifdef ENABLE_TRAINING_CORE - if (axes_values.empty() || - mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast(axes_values.size())) { + if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast(axes_values.size())) { scale = mul_node.MutableInputDefs()[i]; } -#else - // Scale must be 1d. - if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == 1) { - scale = mul_node.MutableInputDefs()[i]; - } -#endif } } for (size_t i = 0; i < last_add_node.MutableInputDefs().size(); i++) { if (graph_utils::NodeArgIsConstant(graph, *(last_add_node.MutableInputDefs()[i])) || graph_utils::IsGraphInput(graph, last_add_node.MutableInputDefs()[i])) { -#ifdef ENABLE_TRAINING_CORE - if (axes_values.empty() || - last_add_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast(axes_values.size())) { + if (last_add_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast(axes_values.size())) { bias = last_add_node.MutableInputDefs()[i]; } -#else - // Bias must be 1d. 
- if (last_add_node.MutableInputDefs()[i]->Shape()->dim_size() == 1) { - bias = last_add_node.MutableInputDefs()[i]; - } -#endif } } if (scale == nullptr || bias == nullptr) { @@ -423,6 +465,9 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, layer_norm_node.AddAttribute("epsilon", DEFAULT_LAYERNORM_EPSILON); } + // The axis definition of layer_norm is ranging from axis to the last dim + layer_norm_node.AddAttribute("axis", static_cast(axes_values[0])); + // Set stash_type to double if any input is double, default value if float. if (x_input->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_DOUBLE || scale->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_DOUBLE) { @@ -598,19 +643,26 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr nodes_to_remove.push_back(mul_node); // get axes attributes - const onnxruntime::NodeAttributes& attributes = reduce_mean_node.GetAttributes(); - std::vector axes_values; - if (attributes.find("axes") != attributes.end()) { - axes_values = RetrieveValues(attributes.at("axes")); - } else if (reduce_mean_node.InputDefs().size() == 2) { - auto axes = reduce_mean_node.InputDefs()[1]; - auto axes_const = graph.GetConstantInitializer(axes->Name(), true); - if (axes_const != nullptr && axes_const->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64) { - Initializer initializer{*axes_const, graph.ModelPath()}; - axes_values.insert(axes_values.end(), initializer.DataAsSpan().begin(), initializer.DataAsSpan().end()); - } + std::vector axes_values = GetAxesFromReduceMeanNode(reduce_mean_node, graph); + + if (axes_values.empty()) { + continue; } + auto rmean_input_shape = reduce_mean_node.MutableInputDefs()[0]->Shape(); + auto rank = rmean_input_shape ? 
rmean_input_shape->dim().size() : -1; + if (!CheckAxesOnReduceMean(axes_values, rank)) { + continue; + } + +#ifdef ENABLE_TRAINING_CORE +#else + // scale as 1D + if (axes_values.size() != 1) { + continue; + } +#endif + // Get the inputs for the new LayerNormalization node. // scale and bias could be multi-dims; we only support it for training at the moment // because SkipLayerNorm kernel, for example, has dependency on single dim size @@ -659,6 +711,8 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr layer_norm_node.AddAttribute("stash_type", static_cast(ONNX_NAMESPACE::TensorProto_DataType_DOUBLE)); } + layer_norm_node.AddAttribute("axis", static_cast(axes_values[0])); + // Assign provider to this new node. Provider should be same as the provider for old node. layer_norm_node.SetExecutionProviderType(reduce_mean_node.GetExecutionProviderType()); diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index bd2b733225..c7af8701ee 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -35,7 +35,6 @@ #include "core/optimizer/div_mul_fusion.h" #include "core/optimizer/dropout_elimination.h" #include "core/optimizer/dynamic_quantize_matmul_fusion.h" -#include "core/optimizer/embed_layer_norm_fusion.h" #include "core/optimizer/expand_elimination.h" #include "core/optimizer/fast_gelu_fusion.h" #include "core/optimizer/gather_fusion.h" @@ -51,7 +50,6 @@ #include "core/optimizer/identity_elimination.h" #include "core/optimizer/initializer.h" #include "core/optimizer/isinf_reducesum_fusion.h" -#include "core/optimizer/layer_norm_fusion.h" #include "core/optimizer/matmul_add_fusion.h" #include "core/optimizer/matmul_integer_to_float.h" #include "core/optimizer/matmul_scale_fusion.h" @@ -63,7 +61,6 @@ #include "core/optimizer/relu_clip_fusion.h" #include "core/optimizer/reshape_fusion.h" #include 
"core/optimizer/rule_based_graph_transformer.h" -#include "core/optimizer/skip_layer_norm_fusion.h" #include "core/optimizer/slice_elimination.h" #include "core/optimizer/unsqueeze_elimination.h" #include "core/optimizer/utils.h" @@ -4610,762 +4607,6 @@ TEST_F(GraphTransformationTests, ReshapeFusionOpsetTest) { } #endif -TEST_F(GraphTransformationTests, LayerNormFusionTest) { - constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/layer_norm.onnx"; - std::shared_ptr p_model; - ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); - Graph& graph = p_model->MainGraph(); - - onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); - ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); - - std::map op_to_count = CountOpsInGraph(graph); - ASSERT_TRUE(op_to_count["Div"] == 0); - ASSERT_TRUE(op_to_count["Add"] == 0); - ASSERT_TRUE(op_to_count["Sub"] == 0); - ASSERT_TRUE(op_to_count["ReduceMean"] == 0); - ASSERT_TRUE(op_to_count["Pow"] == 0); - ASSERT_TRUE(op_to_count["Sqrt"] == 0); - ASSERT_TRUE(op_to_count["LayerNormalization"] == 1); - - for (const Node& node : graph.Nodes()) { - if (node.OpType() == "LayerNormalization") { - // LayerNormalization should have three inputs. - EXPECT_EQ(node.InputDefs().size(), 3u) << "LayerNormalization number of inputs does not equal to 3. Got:" << node.InputDefs().size(); - // LayerNormalization input "scale" and "bias" should have the same dimension. - const TensorShapeProto* scale_shape = node.InputDefs()[1]->Shape(); - const TensorShapeProto* bias_shape = node.InputDefs()[2]->Shape(); - EXPECT_EQ(scale_shape->dim_size(), 1) << "LayerNormalization scale should be 1D. Got: " << scale_shape->dim_size(); - EXPECT_EQ(bias_shape->dim_size(), 1) << "LayerNormalization bias should be 1D. 
Got: " << bias_shape->dim_size(); - EXPECT_EQ(scale_shape->dim(0).dim_value(), bias_shape->dim(0).dim_value()); - } else { - EXPECT_TRUE(false) << "Unexpected node " << node.Name(); - } - } -} - -TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest) { - constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/layer_norm_with_cast.onnx"; - std::shared_ptr p_model; - ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); - Graph& graph = p_model->MainGraph(); - - onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); - ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); - - std::map op_to_count = CountOpsInGraph(graph); - -#ifdef ENABLE_TRAINING_CORE - ASSERT_TRUE(op_to_count["Cast"] == 0); - ASSERT_TRUE(op_to_count["LayerNormalization"] == 1); -#else - ASSERT_TRUE(op_to_count["Cast"] == 1); - ASSERT_TRUE(op_to_count["LayerNormalization"] == 0); -#endif -} - -TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_2) { - constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/layer_norm_with_cast_2.onnx"; - std::shared_ptr p_model; - ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); - Graph& graph = p_model->MainGraph(); - - onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); - ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); - - std::map op_to_count = CountOpsInGraph(graph); - - ASSERT_TRUE(op_to_count["Cast"] == 0); - ASSERT_TRUE(op_to_count["LayerNormalization"] == 1); -} - -TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_3) { - constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/layer_norm_with_cast_3.onnx"; - std::shared_ptr p_model; - 
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); - Graph& graph = p_model->MainGraph(); - - onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); - ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); - - std::map op_to_count = CountOpsInGraph(graph); - - ASSERT_TRUE(op_to_count["Cast"] == 0); - ASSERT_TRUE(op_to_count["LayerNormalization"] == 1); -} - -TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_4) { - constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/layer_norm_with_cast_4.onnx"; - std::shared_ptr p_model; - ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); - Graph& graph = p_model->MainGraph(); - - onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); - ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); - - std::map op_to_count = CountOpsInGraph(graph); - - ASSERT_TRUE(op_to_count["Cast"] == 0); - ASSERT_TRUE(op_to_count["LayerNormalization"] == 1); -} - -TEST_F(GraphTransformationTests, LayerNormWithSubDupFusionTest) { - constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/layer_norm_sub_dup.onnx"; - std::shared_ptr p_model; - ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); - Graph& graph = p_model->MainGraph(); - - onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); - ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); - - std::map op_to_count = CountOpsInGraph(graph); - ASSERT_TRUE(op_to_count["Div"] == 0); - ASSERT_TRUE(op_to_count["Add"] == 0); - 
ASSERT_TRUE(op_to_count["Sub"] == 0); - ASSERT_TRUE(op_to_count["ReduceMean"] == 0); - ASSERT_TRUE(op_to_count["Pow"] == 0); - ASSERT_TRUE(op_to_count["Sqrt"] == 0); - ASSERT_TRUE(op_to_count["LayerNormalization"] == 1); - - for (const Node& node : graph.Nodes()) { - if (node.OpType() == "LayerNormalization") { - // LayerNormalization should have three inputs. - EXPECT_EQ(node.InputDefs().size(), 3u) << "LayerNormalization number of inputs does not equal to 3. Got:" << node.InputDefs().size(); - // LayerNormalization input "scale" and "bias" should have the same dimension. - const TensorShapeProto* scale_shape = node.InputDefs()[1]->Shape(); - const TensorShapeProto* bias_shape = node.InputDefs()[2]->Shape(); - EXPECT_EQ(scale_shape->dim_size(), 1) << "LayerNormalization scale should be 1D. Got: " << scale_shape->dim_size(); - EXPECT_EQ(bias_shape->dim_size(), 1) << "LayerNormalization bias should be 1D. Got: " << bias_shape->dim_size(); - EXPECT_EQ(scale_shape->dim(0).dim_value(), bias_shape->dim(0).dim_value()); - } else { - EXPECT_TRUE(false) << "Unexpected node " << node.Name(); - } - } -} - -TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_5) { - auto build_test_case = [&](ModelTestBuilder& builder) { - auto* data_arg = builder.MakeInput({{2, 3, 3, 3}}); - auto* pow_initializer = builder.MakeInitializer({}, {2.0f}); - auto* add_initializer = builder.MakeInitializer({}, {1e-5f}); - auto* weight_initializer = builder.MakeInitializer({3}, std::vector(3, MLFloat16(1.0f))); - auto* bias_initializer = builder.MakeInitializer({3}, std::vector(3, MLFloat16(0.0f))); - auto* reduce_mean_out_1 = builder.MakeIntermediate(); - auto* sub_out = builder.MakeIntermediate(); - auto* cast_out_1 = builder.MakeIntermediate(); - auto* pow_out = builder.MakeIntermediate(); - auto* reduce_mean_out_2 = builder.MakeIntermediate(); - auto* add_out_1 = builder.MakeIntermediate(); - auto* sqrt_out = builder.MakeIntermediate(); - auto* div_out = builder.MakeIntermediate(); - 
auto* cast_out_2 = builder.MakeIntermediate(); - auto* mul_out = builder.MakeIntermediate(); - auto* add_out_2 = builder.MakeOutput(); - auto opset = builder.DomainToVersionMap().find(kOnnxDomain)->second; - onnxruntime::NodeArg* axes = nullptr; - - if (opset >= 18) { - axes = builder.MakeInitializer({1}, {-1}); - builder.AddNode("ReduceMean", {data_arg, axes}, {reduce_mean_out_1}); - } else { - builder.AddNode("ReduceMean", {data_arg}, {reduce_mean_out_1}).AddAttribute("axes", std::vector{-1}); - } - builder.AddNode("Sub", {data_arg, reduce_mean_out_1}, {sub_out}); - builder.AddNode("Cast", {sub_out}, {cast_out_1}) - .AddAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT)); - builder.AddNode("Pow", {cast_out_1, pow_initializer}, {pow_out}); - if (opset >= 18) { - builder.AddNode("ReduceMean", {pow_out, axes}, {reduce_mean_out_2}); - } else { - builder.AddNode("ReduceMean", {pow_out}, {reduce_mean_out_2}).AddAttribute("axes", std::vector{-1}); - } - builder.AddNode("Add", {reduce_mean_out_2, add_initializer}, {add_out_1}); - builder.AddNode("Sqrt", {add_out_1}, {sqrt_out}); - builder.AddNode("Div", {cast_out_1, sqrt_out}, {div_out}); - builder.AddNode("Cast", {div_out}, {cast_out_2}) - .AddAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)); - builder.AddNode("Mul", {cast_out_2, weight_initializer}, {mul_out}); - builder.AddNode("Add", {mul_out, bias_initializer}, {add_out_2}); - }; - - auto pre_graph_checker = [&](Graph& graph) { - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["ReduceMean"] == 2); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Sub"] == 1); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Cast"] == 2); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Pow"] == 1); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Add"] == 2); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Sqrt"] == 1); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Div"] == 1); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Mul"] == 1); - return Status::OK(); - }; 
- - auto post_graph_checker = [&](Graph& graph) { - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["ReduceMean"] == 0); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Sub"] == 0); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Cast"] == 0); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Pow"] == 0); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Add"] == 0); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Sqrt"] == 0); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Div"] == 0); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Mul"] == 0); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["LayerNormalization"] == 1); - return Status::OK(); - }; - - std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, {14, 18}, *logger_, std::move(transformer), TransformerLevel::Level1, - 1, pre_graph_checker, post_graph_checker)); -} - -TEST_F(GraphTransformationTests, SimplifiedLayerNormFusionTest) { - constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/layer_norm_t5.onnx"; - std::shared_ptr p_model; - ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); - Graph& graph = p_model->MainGraph(); - - onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); - ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); - - std::map op_to_count = CountOpsInGraph(graph); - ASSERT_TRUE(op_to_count["Div"] == 0); - ASSERT_TRUE(op_to_count["Add"] == 0); - ASSERT_TRUE(op_to_count["ReduceMean"] == 0); - ASSERT_TRUE(op_to_count["Pow"] == 0); - ASSERT_TRUE(op_to_count["Sqrt"] == 0); - ASSERT_TRUE(op_to_count["SimplifiedLayerNormalization"] == 1); - - for (const Node& node : graph.Nodes()) { - if (node.OpType() == "SimplifiedLayerNormalization") { - // LayerNormalization should have two inputs. - EXPECT_EQ(node.InputDefs().size(), 2u) << "LayerNormalization number of inputs does not equal to 2. 
Got:" << node.InputDefs().size(); - // LayerNormalization input "scale" and "bias" should have the same dimension. - const TensorShapeProto* scale_shape = node.InputDefs()[1]->Shape(); - EXPECT_EQ(scale_shape->dim_size(), 1) << "LayerNormalization scale should be 1D. Got: " << scale_shape->dim_size(); - } else { - EXPECT_TRUE(false) << "Unexpected node " << node.Name(); - } - } -} - -// If EP is non-GPU EP or unknown, the sub-graph will be not fused because CPU impl for SimplifiedLayerNormalization -// doesn't support input and scale having different data types. -TEST_F(GraphTransformationTests, SimplifiedLayerNormWithCastsFusionTest) { - constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/simplified_layer_norm_with_casts.onnx"; - std::shared_ptr p_model; - ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); - Graph& graph = p_model->MainGraph(); - - InlinedHashSet compatible_eps; - onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(compatible_eps), - TransformerLevel::Level2)); - ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); - - std::map op_to_count = CountOpsInGraph(graph); - ASSERT_TRUE(op_to_count["SimplifiedLayerNormalization"] == 0); -} - -TEST_F(GraphTransformationTests, SimplifiedLayerNormWithCastsFusionTestCudaEp) { - constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/simplified_layer_norm_with_casts.onnx"; - std::shared_ptr p_model; - ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); - Graph& graph = p_model->MainGraph(); - for (auto& node : graph.Nodes()) { - node.SetExecutionProviderType(kCudaExecutionProvider); - } - - InlinedHashSet compatible_eps; - onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(compatible_eps), - TransformerLevel::Level2)); - 
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); - - std::map op_to_count = CountOpsInGraph(graph); - ASSERT_TRUE(op_to_count["Div"] == 0); - ASSERT_TRUE(op_to_count["Add"] == 0); - ASSERT_TRUE(op_to_count["ReduceMean"] == 0); - ASSERT_TRUE(op_to_count["Pow"] == 0); - ASSERT_TRUE(op_to_count["Sqrt"] == 0); - ASSERT_TRUE(op_to_count["Cast"] == 0); - ASSERT_TRUE(op_to_count["SimplifiedLayerNormalization"] == 1); - - for (const Node& node : graph.Nodes()) { - if (node.OpType() == "SimplifiedLayerNormalization") { - // LayerNormalization should have two inputs. - EXPECT_EQ(node.InputDefs().size(), 2u) - << "LayerNormalization number of inputs does not equal to 2. Got:" << node.InputDefs().size(); - // LayerNormalization input "scale" and "bias" should have the same dimension. - const TensorShapeProto* scale_shape = node.InputDefs()[1]->Shape(); - EXPECT_EQ(scale_shape->dim_size(), 1) - << "LayerNormalization scale should be 1D. Got: " << scale_shape->dim_size(); - } else if (node.OpType() == "Cast") { - continue; - } else { - EXPECT_TRUE(false) << "Unexpected node " << node.Name(); - } - } -} - -static void TestSkipLayerNormFusion(const std::basic_string& file_path, int add_count, int ln_count, - int skip_ln_count, int cast_count, logging::Logger* logger) { - std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(file_path, p_model, nullptr, *logger).IsOK()); - Graph& graph = p_model->MainGraph(); - - onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); - ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger)); - - std::map op_to_count = CountOpsInGraph(graph); - ASSERT_TRUE(op_to_count["Div"] == 0); - ASSERT_TRUE(op_to_count["Add"] == add_count); - 
ASSERT_TRUE(op_to_count["Sub"] == 0); - ASSERT_TRUE(op_to_count["ReduceMean"] == 0); - ASSERT_TRUE(op_to_count["Pow"] == 0); - ASSERT_TRUE(op_to_count["Sqrt"] == 0); - ASSERT_TRUE(op_to_count["LayerNormalization"] == ln_count); - ASSERT_TRUE(op_to_count["com.microsoft.SkipLayerNormalization"] == skip_ln_count); - ASSERT_TRUE(op_to_count["Cast"] == cast_count); -} - -TEST_F(GraphTransformationTests, SkipLayerNormFusionTest) { - TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1.onnx", 0, 0, 1, 0, logger_.get()); - TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2.onnx", 0, 0, 1, 0, logger_.get()); - TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3.onnx", 0, 0, 1, 0, logger_.get()); - - TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1_partial.onnx", 1, 0, 1, 0, logger_.get()); - TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2_partial.onnx", 1, 0, 1, 0, logger_.get()); - TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3_no_fusion.onnx", 1, 1, 0, 0, logger_.get()); - - TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1_graph_output.onnx", 1, 0, 1, 0, logger_.get()); - TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2_graph_output.onnx", 1, 0, 1, 0, logger_.get()); - TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3_graph_output.onnx", 1, 1, 0, 0, logger_.get()); -} - -TEST_F(GraphTransformationTests, SkipLayerNormFusionWithCastTest) { - TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1_with_cast.onnx", 0, 0, 1, 3, logger_.get()); - TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2_with_cast.onnx", 0, 0, 1, 3, logger_.get()); - TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3_with_cast.onnx", 0, 0, 1, 2, logger_.get()); - - TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1_partial_with_cast.onnx", 1, 0, 1, 2, logger_.get()); 
- TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2_partial_with_cast.onnx", 1, 0, 1, 2, logger_.get()); - TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3_no_fusion_with_cast.onnx", 1, 1, 0, 0, logger_.get()); - - TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1_graph_output_with_cast.onnx", 1, 0, 1, 2, logger_.get()); - TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2_graph_output_with_cast.onnx", 1, 0, 1, 2, logger_.get()); - TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3_graph_output_with_cast.onnx", 1, 1, 0, 0, logger_.get()); -} - -static void TestSkipLayerNormFusionInputOutputCheck(const std::basic_string& model_uri, bool with_cast, logging::Logger* logger) { - std::shared_ptr p_model; - ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger)); - Graph& graph = p_model->MainGraph(); - - onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); - ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger)); - - for (Node& node : graph.Nodes()) { - if (node.OpType() == "SkipLayerNormalization") { - // check inputs - std::vector& input_defs = node.MutableInputDefs(); - EXPECT_EQ(input_defs.size(), 5u) << "SkipLayerNormalization number of inputs does not equal to 5. Got:" << node.InputDefs().size(); - EXPECT_EQ(input_defs[0]->Name(), ((with_cast) ? "input.1_Float" : "input.1")); - EXPECT_EQ(input_defs[1]->Name(), ((with_cast) ? "6_Float" : "6")); - EXPECT_EQ(input_defs[2]->Name(), "1"); - EXPECT_EQ(input_defs[3]->Name(), "2"); - EXPECT_EQ(input_defs[4]->Name(), ((with_cast) ? 
"4_Float" : "4")); - - // check outputs - std::vector& output_defs = node.MutableOutputDefs(); -#ifdef ENABLE_TRAINING_CORE - EXPECT_EQ(node.OutputDefs().size(), 3u) << "SkipLayerNormalization number of outputs does not equal to 3. Got:" << node.OutputDefs().size(); -#else - EXPECT_EQ(node.OutputDefs().size(), 1u) << "SkipLayerNormalization number of outputs does not equal to 1. Got:" << node.OutputDefs().size(); -#endif - EXPECT_EQ(output_defs[0]->Name(), "19"); - } else if (node.OpType() == "Cast") { - EXPECT_TRUE(with_cast) << "Unexpected node: " << node.OpType() << "," << node.Name(); - } else { - EXPECT_EQ(node.OpType(), "MatMul") << "Unexpected node: " << node.OpType() << "," << node.Name(); - } - } -} - -TEST_F(GraphTransformationTests, SkipLayerNormFusion_Input_Output_Check) { - TestSkipLayerNormFusionInputOutputCheck(MODEL_FOLDER "fusion/skip_layer_norm_input_output_check.onnx", false, logger_.get()); - TestSkipLayerNormFusionInputOutputCheck(MODEL_FOLDER "fusion/skip_layer_norm_input_output_with_cast_check.onnx", true, logger_.get()); -} - -static void TestSkipLayerNormFusionNoBeta(const std::basic_string& model_uri, bool with_cast, logging::Logger* logger) { - std::shared_ptr p_model; - ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger)); - Graph& graph = p_model->MainGraph(); - - onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); - ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger)); - - std::map op_to_count = CountOpsInGraph(graph); - ASSERT_TRUE(op_to_count["Add"] == 0); - ASSERT_TRUE(op_to_count["LayerNormalization"] == 0); - ASSERT_TRUE(op_to_count["com.microsoft.SkipLayerNormalization"] == 1); - ASSERT_TRUE(op_to_count["Cast"] == ((with_cast) ? 
2 : 0)); -} - -TEST_F(GraphTransformationTests, SkipLayerNormFusion_NoBeta) { - TestSkipLayerNormFusionNoBeta(MODEL_FOLDER "fusion/skip_layer_norm_no_beta.onnx", false, logger_.get()); - TestSkipLayerNormFusionNoBeta(MODEL_FOLDER "fusion/skip_layer_norm_no_beta_with_cast.onnx", true, logger_.get()); -} - -TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat1) { - constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format1.onnx"; - std::shared_ptr p_model; - ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); - Graph& graph = p_model->MainGraph(); - - onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); - ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); - - std::map op_to_count = CountOpsInGraph(graph); - ASSERT_TRUE(op_to_count["Gather"] == 0); - ASSERT_TRUE(op_to_count["Add"] == 0); - ASSERT_TRUE(op_to_count["ReduceSum"] == 1); - ASSERT_TRUE(op_to_count["com.microsoft.Attention"] == 1); - ASSERT_TRUE(op_to_count["com.microsoft.SkipLayerNormalization"] == 0); - ASSERT_TRUE(op_to_count["com.microsoft.EmbedLayerNormalization"] == 1); -} - -TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat2) { - constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format2.onnx"; - std::shared_ptr p_model; - ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); - Graph& graph = p_model->MainGraph(); - - onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); - ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); - - std::map op_to_count = CountOpsInGraph(graph); - ASSERT_TRUE(op_to_count["Shape"] == 0); - ASSERT_TRUE(op_to_count["Expand"] == 0); - 
ASSERT_TRUE(op_to_count["Gather"] == 0); - ASSERT_TRUE(op_to_count["Unsqueeze"] == 0); - ASSERT_TRUE(op_to_count["ConstantOfShape"] == 0); - ASSERT_TRUE(op_to_count["NonZero"] == 0); - ASSERT_TRUE(op_to_count["Transpose"] == 0); - ASSERT_TRUE(op_to_count["Squeeze"] == 0); - ASSERT_TRUE(op_to_count["Add"] == 0); - ASSERT_TRUE(op_to_count["ReduceSum"] == 1); - ASSERT_TRUE(op_to_count["com.microsoft.Attention"] == 1); - ASSERT_TRUE(op_to_count["com.microsoft.SkipLayerNormalization"] == 0); - ASSERT_TRUE(op_to_count["com.microsoft.EmbedLayerNormalization"] == 1); -} - -static void EmbedLayerNormFusionFormat3(const std::basic_string& file_path, logging::Logger* logger) { - std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(file_path, p_model, nullptr, *logger).IsOK()); - Graph& graph = p_model->MainGraph(); - - onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); - ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger)); - - std::map op_to_count = CountOpsInGraph(graph); - EXPECT_EQ(op_to_count["Shape"], 0); - EXPECT_EQ(op_to_count["Expand"], 0); - EXPECT_EQ(op_to_count["Gather"], 0); - EXPECT_EQ(op_to_count["Unsqueeze"], 0); - EXPECT_EQ(op_to_count["LayerNormalization"], 0); - EXPECT_EQ(op_to_count["com.microsoft.SkipLayerNormalization"], 0); - EXPECT_EQ(op_to_count["ReduceSum"], 1); - EXPECT_EQ(op_to_count["MatMul"], 1); - EXPECT_EQ(op_to_count["Add"], 2); - EXPECT_EQ(op_to_count["Cast"], 3); - EXPECT_EQ(op_to_count["com.microsoft.Attention"], 1); - EXPECT_EQ(op_to_count["com.microsoft.EmbedLayerNormalization"], 1); -} - -TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat3) { - EmbedLayerNormFusionFormat3(MODEL_FOLDER "fusion/embed_layer_norm_format3.onnx", logger_.get()); -} - -TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat3_OpSet13) { - EmbedLayerNormFusionFormat3(MODEL_FOLDER 
"fusion/embed_layer_norm_format3_opset13.onnx", logger_.get()); -} - -static void EmbedLayerNormFusionFormat3NoCast(const std::basic_string& file_path, logging::Logger* logger) { - std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(file_path, p_model, nullptr, *logger).IsOK()); - Graph& graph = p_model->MainGraph(); - - onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); - ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger)); - - std::map op_to_count = CountOpsInGraph(graph); - EXPECT_EQ(op_to_count["Shape"], 0); - EXPECT_EQ(op_to_count["Expand"], 0); - EXPECT_EQ(op_to_count["Gather"], 0); - EXPECT_EQ(op_to_count["Unsqueeze"], 0); - EXPECT_EQ(op_to_count["LayerNormalization"], 0); - EXPECT_EQ(op_to_count["com.microsoft.SkipLayerNormalization"], 0); - EXPECT_EQ(op_to_count["ReduceSum"], 1); - EXPECT_EQ(op_to_count["MatMul"], 1); - EXPECT_EQ(op_to_count["Add"], 2); - EXPECT_EQ(op_to_count["Cast"], 3); - EXPECT_EQ(op_to_count["com.microsoft.Attention"], 1); - EXPECT_EQ(op_to_count["com.microsoft.EmbedLayerNormalization"], 1); -} - -TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat3NoCast) { - EmbedLayerNormFusionFormat3NoCast(MODEL_FOLDER "fusion/embed_layer_norm_format3_no_cast.onnx", logger_.get()); -} - -TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat3NoCast_OpSet13) { - EmbedLayerNormFusionFormat3NoCast(MODEL_FOLDER "fusion/embed_layer_norm_format3_no_cast_opset13.onnx", logger_.get()); -} - -TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat4) { - constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format4.onnx"; - std::shared_ptr p_model; - ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); - Graph& graph = p_model->MainGraph(); - - onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - 
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); - ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); - - std::map op_to_count = CountOpsInGraph(graph); - ASSERT_TRUE(op_to_count["Shape"] == 0); - ASSERT_TRUE(op_to_count["Expand"] == 0); - ASSERT_TRUE(op_to_count["Gather"] == 0); - ASSERT_TRUE(op_to_count["Concat"] == 0); - ASSERT_TRUE(op_to_count["Unsqueeze"] == 0); - ASSERT_TRUE(op_to_count["ConstantOfShape"] == 0); - ASSERT_TRUE(op_to_count["NonZero"] == 0); - ASSERT_TRUE(op_to_count["Transpose"] == 0); - ASSERT_TRUE(op_to_count["Squeeze"] == 0); - ASSERT_TRUE(op_to_count["Add"] == 0); - ASSERT_TRUE(op_to_count["ReduceSum"] == 1); - ASSERT_TRUE(op_to_count["com.microsoft.Attention"] == 1); - ASSERT_TRUE(op_to_count["com.microsoft.SkipLayerNormalization"] == 0); - ASSERT_TRUE(op_to_count["com.microsoft.EmbedLayerNormalization"] == 1); -} - -static void EmbedLayerNormFusionFormat5(const std::basic_string& file_path, logging::Logger* logger) { - std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(file_path, p_model, nullptr, *logger).IsOK()); - Graph& graph = p_model->MainGraph(); - - onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); - ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger)); - - std::map op_to_count = CountOpsInGraph(graph); - EXPECT_EQ(op_to_count["Gather"], 0); - EXPECT_EQ(op_to_count["LayerNormalization"], 0); - EXPECT_EQ(op_to_count["com.microsoft.SkipLayerNormalization"], 0); - EXPECT_EQ(op_to_count["ReduceSum"], 1); - EXPECT_EQ(op_to_count["MatMul"], 1); - EXPECT_EQ(op_to_count["Add"], 2); - EXPECT_EQ(op_to_count["Cast"], 3); - EXPECT_EQ(op_to_count["com.microsoft.Attention"], 1); - EXPECT_EQ(op_to_count["com.microsoft.EmbedLayerNormalization"], 1); - - // Validate the position 
embedding input. - for (const Node& node : graph.Nodes()) { - if (node.OpType() == "EmbedLayerNormalization") { - const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[3]->Name()); - ASSERT_TRUE(tensor_proto != nullptr); - EXPECT_EQ(tensor_proto->data_type(), ONNX_NAMESPACE::TensorProto_DataType_FLOAT); - - auto initializer = std::make_unique(*tensor_proto, graph.ModelPath()); - EXPECT_EQ(initializer->size(), 12); - - std::vector expected_value = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 8.0, 7.0, 6.0}; - - const float* data = initializer->data(); - for (size_t i = 0; i < expected_value.size(); i++) { - EXPECT_EQ(data[i], static_cast(expected_value[i])); - } - } - } -} - -TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat5) { - EmbedLayerNormFusionFormat5(MODEL_FOLDER "fusion/embed_layer_norm_format5.onnx", logger_.get()); -} - -TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat5_OpSet13) { - EmbedLayerNormFusionFormat5(MODEL_FOLDER "fusion/embed_layer_norm_format5_opset13.onnx", logger_.get()); -} - -static void EmbedLayerNormFusionFormat6(const std::basic_string& file_path, logging::Logger* logger) { - std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(file_path, p_model, nullptr, *logger).IsOK()); - Graph& graph = p_model->MainGraph(); - - onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); - ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger)); - - std::map op_to_count = CountOpsInGraph(graph); - EXPECT_EQ(op_to_count["Shape"], 0); - EXPECT_EQ(op_to_count["Expand"], 0); - EXPECT_EQ(op_to_count["Gather"], 0); - EXPECT_EQ(op_to_count["Unsqueeze"], 0); - EXPECT_EQ(op_to_count["Reshape"], 0); - EXPECT_EQ(op_to_count["Equal"], 0); - EXPECT_EQ(op_to_count["Where"], 0); - EXPECT_EQ(op_to_count["LayerNormalization"], 0); - 
EXPECT_EQ(op_to_count["com.microsoft.SkipLayerNormalization"], 0); - EXPECT_EQ(op_to_count["ReduceSum"], 1); - EXPECT_EQ(op_to_count["MatMul"], 1); - EXPECT_EQ(op_to_count["Add"], 2); - EXPECT_EQ(op_to_count["Cast"], 3); - EXPECT_EQ(op_to_count["com.microsoft.Attention"], 1); - EXPECT_EQ(op_to_count["com.microsoft.EmbedLayerNormalization"], 1); -} - -TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat6) { - EmbedLayerNormFusionFormat6(MODEL_FOLDER "fusion/embed_layer_norm_format6.onnx", logger_.get()); -} - -TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat6_OpSet13) { - EmbedLayerNormFusionFormat6(MODEL_FOLDER "fusion/embed_layer_norm_format6_opset13.onnx", logger_.get()); -} - -static void TestEmbedLayerNormFusionDistilBert(const std::basic_string& model_uri, - std::map& op_to_count, - logging::Logger* logger) { - std::shared_ptr p_model; - ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger)); - Graph& graph = p_model->MainGraph(); - - onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); - ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger)); - - op_to_count = CountOpsInGraph(graph); -} - -// DistilBert -TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat7) { - std::map op_to_count; - TestEmbedLayerNormFusionDistilBert(MODEL_FOLDER "fusion/embed_layer_norm_format7.onnx", op_to_count, logger_.get()); - EXPECT_EQ(op_to_count["com.microsoft.EmbedLayerNormalization"], 1); - EXPECT_EQ(op_to_count["com.microsoft.Attention"], 1); - EXPECT_EQ(op_to_count["Cast"], 2); - EXPECT_EQ(op_to_count["Shape"], 0); - EXPECT_EQ(op_to_count["Gather"], 0); - EXPECT_EQ(op_to_count["Unsqueeze"], 0); - EXPECT_EQ(op_to_count["ReduceSum"], 1); -} - -TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat7_OpSet13) { - std::map op_to_count; - 
TestEmbedLayerNormFusionDistilBert(MODEL_FOLDER "fusion/embed_layer_norm_format7_opset13.onnx", op_to_count, logger_.get()); - EXPECT_EQ(op_to_count["com.microsoft.EmbedLayerNormalization"], 1); - EXPECT_EQ(op_to_count["com.microsoft.Attention"], 1); - EXPECT_EQ(op_to_count["Cast"], 2); - EXPECT_EQ(op_to_count["Shape"], 0); - EXPECT_EQ(op_to_count["Gather"], 0); - EXPECT_EQ(op_to_count["Unsqueeze"], 0); - EXPECT_EQ(op_to_count["ReduceSum"], 1); -} - -TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat8) { - std::map op_to_count; - TestEmbedLayerNormFusionDistilBert(MODEL_FOLDER "fusion/embed_layer_norm_format8.onnx", op_to_count, logger_.get()); - EXPECT_EQ(op_to_count["com.microsoft.EmbedLayerNormalization"], 1); - EXPECT_EQ(op_to_count["com.microsoft.Attention"], 1); - EXPECT_EQ(op_to_count["Cast"], 2); - EXPECT_EQ(op_to_count["Shape"], 0); - EXPECT_EQ(op_to_count["Gather"], 0); - EXPECT_EQ(op_to_count["Unsqueeze"], 0); - EXPECT_EQ(op_to_count["ReduceSum"], 1); -} - -TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat8_OpSet13) { - std::map op_to_count; - TestEmbedLayerNormFusionDistilBert(MODEL_FOLDER "fusion/embed_layer_norm_format8_opset13.onnx", op_to_count, logger_.get()); - EXPECT_EQ(op_to_count["com.microsoft.EmbedLayerNormalization"], 1); - EXPECT_EQ(op_to_count["com.microsoft.Attention"], 1); - EXPECT_EQ(op_to_count["Cast"], 2); - EXPECT_EQ(op_to_count["Shape"], 0); - EXPECT_EQ(op_to_count["Gather"], 0); - EXPECT_EQ(op_to_count["Unsqueeze"], 0); - EXPECT_EQ(op_to_count["ReduceSum"], 1); -} - -TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat9) { - std::map op_to_count; - TestEmbedLayerNormFusionDistilBert(MODEL_FOLDER "fusion/embed_layer_norm_format9.onnx", op_to_count, logger_.get()); - EXPECT_EQ(op_to_count["com.microsoft.EmbedLayerNormalization"], 1); - EXPECT_EQ(op_to_count["com.microsoft.Attention"], 1); - EXPECT_EQ(op_to_count["Cast"], 2); - EXPECT_EQ(op_to_count["Shape"], 1); - EXPECT_EQ(op_to_count["Gather"], 2); - 
EXPECT_EQ(op_to_count["Unsqueeze"], 2); - EXPECT_EQ(op_to_count["ReduceSum"], 1); -} - -TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat9_OpSet13) { - std::map op_to_count; - TestEmbedLayerNormFusionDistilBert(MODEL_FOLDER "fusion/embed_layer_norm_format9_opset13.onnx", op_to_count, logger_.get()); - EXPECT_EQ(op_to_count["com.microsoft.EmbedLayerNormalization"], 1); - EXPECT_EQ(op_to_count["com.microsoft.Attention"], 1); - EXPECT_EQ(op_to_count["Cast"], 2); - EXPECT_EQ(op_to_count["Shape"], 1); - EXPECT_EQ(op_to_count["Gather"], 2); - EXPECT_EQ(op_to_count["Unsqueeze"], 2); - EXPECT_EQ(op_to_count["ReduceSum"], 1); -} - -static void EmbedLayerNormFusionFormatMultiple(const std::basic_string& file_path, logging::Logger* logger) { - std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(file_path, p_model, nullptr, *logger).IsOK()); - Graph& graph = p_model->MainGraph(); - - onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); - ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger)); - - std::map op_to_count = CountOpsInGraph(graph); - EXPECT_EQ(op_to_count["Shape"], 0); - EXPECT_EQ(op_to_count["Expand"], 0); - EXPECT_EQ(op_to_count["Gather"], 0); - EXPECT_EQ(op_to_count["Unsqueeze"], 0); - EXPECT_EQ(op_to_count["LayerNormalization"], 0); - EXPECT_EQ(op_to_count["com.microsoft.SkipLayerNormalization"], 0); - EXPECT_EQ(op_to_count["ReduceSum"], 2); - EXPECT_EQ(op_to_count["MatMul"], 2); - EXPECT_EQ(op_to_count["Add"], 5); - EXPECT_EQ(op_to_count["Cast"], 6); - EXPECT_EQ(op_to_count["com.microsoft.Attention"], 2); - EXPECT_EQ(op_to_count["com.microsoft.EmbedLayerNormalization"], 2); -} - -TEST_F(GraphTransformationTests, EmbedLayerNormFusionMultiple) { - EmbedLayerNormFusionFormatMultiple(MODEL_FOLDER "fusion/embed_layer_norm_multiple.onnx", logger_.get()); -} - -TEST_F(GraphTransformationTests, 
EmbedLayerNormFusionMultiple_OpSet13) { - EmbedLayerNormFusionFormatMultiple(MODEL_FOLDER "fusion/embed_layer_norm_multiple_opset13.onnx", logger_.get()); -} - TEST_F(GraphTransformationTests, DynamicQuantizeMatMulTest) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/dynamic_quantize_matmul.onnx"; std::shared_ptr p_model; diff --git a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc new file mode 100755 index 0000000000..179d4dcfd0 --- /dev/null +++ b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc @@ -0,0 +1,927 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4244) +#endif + +#include + +#include "gtest/gtest.h" + +#include "core/graph/graph_utils.h" +#include "core/graph/graph_viewer.h" +#include "core/graph/model.h" +#include "core/optimizer/initializer.h" + +#include "core/optimizer/embed_layer_norm_fusion.h" +#include "core/optimizer/layer_norm_fusion.h" +#include "core/optimizer/skip_layer_norm_fusion.h" + +#include "test/capturing_sink.h" +#include "test/framework/test_utils.h" +#include "test/optimizer/graph_transform_test_builder.h" +#include "test/optimizer/graph_transform_test_fixture.h" +#include "test/providers/provider_test_utils.h" + +using namespace std; +using namespace ONNX_NAMESPACE; + +namespace onnxruntime { +namespace test { + +#define MODEL_FOLDER ORT_TSTR("testdata/transform/") + +TEST_F(GraphTransformationTests, LayerNormFusionTest) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/layer_norm.onnx"; + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + 
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Div"] == 0); + ASSERT_TRUE(op_to_count["Add"] == 0); + ASSERT_TRUE(op_to_count["Sub"] == 0); + ASSERT_TRUE(op_to_count["ReduceMean"] == 0); + ASSERT_TRUE(op_to_count["Pow"] == 0); + ASSERT_TRUE(op_to_count["Sqrt"] == 0); + ASSERT_TRUE(op_to_count["LayerNormalization"] == 1); + + for (const Node& node : graph.Nodes()) { + if (node.OpType() == "LayerNormalization") { + // LayerNormalization should have three inputs. + EXPECT_EQ(node.InputDefs().size(), 3u) + << "LayerNormalization number of inputs does not equal to 3. Got:" << node.InputDefs().size(); + // LayerNormalization input "scale" and "bias" should have the same dimension. + const TensorShapeProto* scale_shape = node.InputDefs()[1]->Shape(); + const TensorShapeProto* bias_shape = node.InputDefs()[2]->Shape(); + EXPECT_EQ(scale_shape->dim_size(), 1) + << "LayerNormalization scale should be 1D. Got: " << scale_shape->dim_size(); + EXPECT_EQ(bias_shape->dim_size(), 1) + << "LayerNormalization bias should be 1D. 
Got: " << bias_shape->dim_size(); + EXPECT_EQ(scale_shape->dim(0).dim_value(), bias_shape->dim(0).dim_value()); + } else { + EXPECT_TRUE(false) << "Unexpected node " << node.Name(); + } + } +} + +TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/layer_norm_with_cast.onnx"; + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + +#ifdef ENABLE_TRAINING_CORE + ASSERT_TRUE(op_to_count["Cast"] == 0); + ASSERT_TRUE(op_to_count["LayerNormalization"] == 1); +#else + ASSERT_TRUE(op_to_count["Cast"] == 1); + ASSERT_TRUE(op_to_count["LayerNormalization"] == 0); +#endif +} + +TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_2) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/layer_norm_with_cast_2.onnx"; + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + + ASSERT_TRUE(op_to_count["Cast"] == 0); + ASSERT_TRUE(op_to_count["LayerNormalization"] == 1); +} + +TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_3) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/layer_norm_with_cast_3.onnx"; + std::shared_ptr p_model; + 
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + + ASSERT_TRUE(op_to_count["Cast"] == 0); + ASSERT_TRUE(op_to_count["LayerNormalization"] == 1); +} + +TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_4) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/layer_norm_with_cast_4.onnx"; + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + + ASSERT_TRUE(op_to_count["Cast"] == 0); + ASSERT_TRUE(op_to_count["LayerNormalization"] == 1); +} + +/* +ReduceMean: + axes - INTS : A list of integers, along which to reduce. + The default is to reduce over all the dimensions of the input tensor. + Accepted range is [-r, r-1] where r = rank(data). 
+*/ +TEST_F(GraphTransformationTests, LayerNormWithSubDupFusionTest) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/layer_norm_sub_dup.onnx"; + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Div"] > 0); + ASSERT_TRUE(op_to_count["Add"] > 0); + ASSERT_TRUE(op_to_count["Sub"] > 0); + ASSERT_TRUE(op_to_count["ReduceMean"] > 0); + ASSERT_TRUE(op_to_count["Pow"] > 0); + ASSERT_TRUE(op_to_count["Sqrt"] > 0); + ASSERT_TRUE(op_to_count["LayerNormalization"] == 0); + /* + for (const Node& node : graph.Nodes()) { + if (node.OpType() == "LayerNormalization") { + // LayerNormalization should have three inputs. + EXPECT_EQ(node.InputDefs().size(), 3u) << "LayerNormalization number of inputs does not equal to 3. Got:" << node.InputDefs().size(); + // LayerNormalization input "scale" and "bias" should have the same dimension. + const TensorShapeProto* scale_shape = node.InputDefs()[1]->Shape(); + const TensorShapeProto* bias_shape = node.InputDefs()[2]->Shape(); + EXPECT_EQ(scale_shape->dim_size(), 1) << "LayerNormalization scale should be 1D. Got: " << scale_shape->dim_size(); + EXPECT_EQ(bias_shape->dim_size(), 1) << "LayerNormalization bias should be 1D. 
Got: " << bias_shape->dim_size(); + EXPECT_EQ(scale_shape->dim(0).dim_value(), bias_shape->dim(0).dim_value()); + } else { + EXPECT_TRUE(false) << "Unexpected node " << node.Name(); + } + } + */ +} + +void BuildLayerNorm(ModelTestBuilder& builder, std::vector reduce1_axes = {-1}, + std::vector reduce2_axes = {-1}) { + std::vector input_shape = {2, 3, 3, 3}; + auto* data_arg = builder.MakeInput(input_shape); + auto* pow_initializer = builder.MakeInitializer({}, {2.0f}); + auto* add_initializer = builder.MakeInitializer({}, {1e-5f}); + std::vector normalized_shape = {}; + int64_t normalized_shape_size = 1; + auto raxes = reduce1_axes; + std::transform(raxes.begin(), raxes.end(), raxes.begin(), [&input_shape](int64_t i) { + return i < 0 ? i + input_shape.size() : i; + }); + sort(raxes.begin(), raxes.end()); + for (auto axis : raxes) { + normalized_shape.push_back(input_shape[axis]); + normalized_shape_size *= input_shape[axis]; + } + + auto* weight_initializer = builder.MakeInitializer( + normalized_shape, std::vector(normalized_shape_size, MLFloat16(1.0f))); + auto* bias_initializer = builder.MakeInitializer( + normalized_shape, std::vector(normalized_shape_size, MLFloat16(0.0f))); + auto* reduce_mean_out_1 = builder.MakeIntermediate(); + auto* sub_out = builder.MakeIntermediate(); + auto* cast_out_1 = builder.MakeIntermediate(); + auto* pow_out = builder.MakeIntermediate(); + auto* reduce_mean_out_2 = builder.MakeIntermediate(); + auto* add_out_1 = builder.MakeIntermediate(); + auto* sqrt_out = builder.MakeIntermediate(); + auto* div_out = builder.MakeIntermediate(); + auto* cast_out_2 = builder.MakeIntermediate(); + auto* mul_out = builder.MakeIntermediate(); + auto* add_out_2 = builder.MakeOutput(); + auto opset = builder.DomainToVersionMap().find(kOnnxDomain)->second; + + if (opset >= 18) { + int64_t rsize = static_cast(reduce1_axes.size()); + onnxruntime::NodeArg* axes = builder.MakeInitializer({rsize}, reduce1_axes); + builder.AddNode("ReduceMean", {data_arg, 
axes}, {reduce_mean_out_1}); + } else { + builder.AddNode("ReduceMean", {data_arg}, {reduce_mean_out_1}).AddAttribute("axes", reduce1_axes); + } + builder.AddNode("Sub", {data_arg, reduce_mean_out_1}, {sub_out}); + builder.AddNode("Cast", {sub_out}, {cast_out_1}) + .AddAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT)); + builder.AddNode("Pow", {cast_out_1, pow_initializer}, {pow_out}); + if (opset >= 18) { + int64_t rsize = static_cast(reduce2_axes.size()); + onnxruntime::NodeArg* axes = builder.MakeInitializer({rsize}, reduce2_axes); + builder.AddNode("ReduceMean", {pow_out, axes}, {reduce_mean_out_2}); + } else { + builder.AddNode("ReduceMean", {pow_out}, {reduce_mean_out_2}).AddAttribute("axes", reduce2_axes); + } + builder.AddNode("Add", {reduce_mean_out_2, add_initializer}, {add_out_1}); + builder.AddNode("Sqrt", {add_out_1}, {sqrt_out}); + builder.AddNode("Div", {cast_out_1, sqrt_out}, {div_out}); + builder.AddNode("Cast", {div_out}, {cast_out_2}) + .AddAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)); + builder.AddNode("Mul", {cast_out_2, weight_initializer}, {mul_out}); + builder.AddNode("Add", {mul_out, bias_initializer}, {add_out_2}); +} + +TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_5) { + auto build_test_case = [](ModelTestBuilder& builder) { + BuildLayerNorm(builder, {-1}, {-1}); + }; + + auto pre_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["ReduceMean"] == 2); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Sub"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Cast"] == 2); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Pow"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Add"] == 2); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Sqrt"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Div"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Mul"] == 1); + return Status::OK(); + }; + + auto post_graph_checker = [&](Graph& graph) { + 
TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["ReduceMean"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Sub"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Cast"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Pow"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Add"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Sqrt"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Div"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Mul"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["LayerNormalization"] == 1); + return Status::OK(); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, {14, 18}, *logger_, std::move(transformer), TransformerLevel::Level1, + 1, pre_graph_checker, post_graph_checker)); +} + +TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_6) { + auto build_test_case = [](ModelTestBuilder& builder) { + BuildLayerNorm(builder, {-2}, {-1}); + }; + + int num_of_layer_norm = 0; + auto post_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["ReduceMean"] == 2 - 2 * num_of_layer_norm); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Sub"] == 1 - num_of_layer_norm); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Cast"] == 2 - 2 * num_of_layer_norm); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Pow"] == 1 - num_of_layer_norm); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Add"] == 2 - 2 * num_of_layer_norm); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Sqrt"] == 1 - num_of_layer_norm); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Div"] == 1 - num_of_layer_norm); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Mul"] == 1 - num_of_layer_norm); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["LayerNormalization"] == num_of_layer_norm); + return Status::OK(); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, {14, 18}, *logger_, std::move(transformer), TransformerLevel::Level1, + 1, 
nullptr, post_graph_checker)); +} + +TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_7) { + auto build_test_case = [](ModelTestBuilder& builder) { + BuildLayerNorm(builder, {-2, -1}, {-1, -2}); + }; +#ifdef ENABLE_TRAINING_CORE + int num_of_layer_norm = 1; +#else + int num_of_layer_norm = 0; +#endif + auto post_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["ReduceMean"] == 2 - 2 * num_of_layer_norm); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Sub"] == 1 - num_of_layer_norm); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Cast"] == 2 - 2 * num_of_layer_norm); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Pow"] == 1 - num_of_layer_norm); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Add"] == 2 - 2 * num_of_layer_norm); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Sqrt"] == 1 - num_of_layer_norm); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Div"] == 1 - num_of_layer_norm); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Mul"] == 1 - num_of_layer_norm); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["LayerNormalization"] == num_of_layer_norm); + return Status::OK(); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, {14, 18}, *logger_, std::move(transformer), TransformerLevel::Level1, + 1, nullptr, post_graph_checker)); +} + +TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_8) { + auto build_test_case = [](ModelTestBuilder& builder) { + BuildLayerNorm(builder, {-3, -2, -1}, {-1, -2}); + }; + + int num_of_layer_norm = 0; + auto post_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["ReduceMean"] == 2 - 2 * num_of_layer_norm); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Sub"] == 1 - num_of_layer_norm); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Cast"] == 2 - 2 * num_of_layer_norm); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Pow"] == 1 - num_of_layer_norm); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Add"] == 2 - 2 * 
num_of_layer_norm); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Sqrt"] == 1 - num_of_layer_norm); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Div"] == 1 - num_of_layer_norm); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Mul"] == 1 - num_of_layer_norm); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["LayerNormalization"] == num_of_layer_norm); + return Status::OK(); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, {14, 18}, *logger_, std::move(transformer), TransformerLevel::Level1, + 1, nullptr, post_graph_checker)); +} + +TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_9) { + auto build_test_case = [](ModelTestBuilder& builder) { + BuildLayerNorm(builder, {2, -1}, {-1, -2}); + }; + +#ifdef ENABLE_TRAINING_CORE + int num_of_layer_norm = 1; +#else + int num_of_layer_norm = 0; +#endif + auto post_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["ReduceMean"] == 2 - 2 * num_of_layer_norm); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Sub"] == 1 - num_of_layer_norm); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Cast"] == 2 - 2 * num_of_layer_norm); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Pow"] == 1 - num_of_layer_norm); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Add"] == 2 - 2 * num_of_layer_norm); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Sqrt"] == 1 - num_of_layer_norm); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Div"] == 1 - num_of_layer_norm); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Mul"] == 1 - num_of_layer_norm); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["LayerNormalization"] == num_of_layer_norm); + return Status::OK(); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, {14, 18}, *logger_, std::move(transformer), TransformerLevel::Level1, + 1, nullptr, post_graph_checker)); +} + +TEST_F(GraphTransformationTests, SimplifiedLayerNormFusionTest) { + constexpr const 
ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/layer_norm_t5.onnx"; + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Div"] == 0); + ASSERT_TRUE(op_to_count["Add"] == 0); + ASSERT_TRUE(op_to_count["ReduceMean"] == 0); + ASSERT_TRUE(op_to_count["Pow"] == 0); + ASSERT_TRUE(op_to_count["Sqrt"] == 0); + ASSERT_TRUE(op_to_count["SimplifiedLayerNormalization"] == 1); + + for (const Node& node : graph.Nodes()) { + if (node.OpType() == "SimplifiedLayerNormalization") { + // LayerNormalization should have two inputs. + EXPECT_EQ(node.InputDefs().size(), 2u) << "LayerNormalization number of inputs does not equal to 2. Got:" << node.InputDefs().size(); + // LayerNormalization input "scale" and "bias" should have the same dimension. + const TensorShapeProto* scale_shape = node.InputDefs()[1]->Shape(); + EXPECT_EQ(scale_shape->dim_size(), 1) << "LayerNormalization scale should be 1D. Got: " << scale_shape->dim_size(); + } else { + EXPECT_TRUE(false) << "Unexpected node " << node.Name(); + } + } +} + +// If EP is non-GPU EP or unknown, the sub-graph will be not fused because CPU impl for SimplifiedLayerNormalization +// doesn't support input and scale having different data types. 
+TEST_F(GraphTransformationTests, SimplifiedLayerNormWithCastsFusionTest) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/simplified_layer_norm_with_casts.onnx"; + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + InlinedHashSet compatible_eps; + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(compatible_eps), + TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["SimplifiedLayerNormalization"] == 0); +} + +TEST_F(GraphTransformationTests, SimplifiedLayerNormWithCastsFusionTestCudaEp) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/simplified_layer_norm_with_casts.onnx"; + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + for (auto& node : graph.Nodes()) { + node.SetExecutionProviderType(kCudaExecutionProvider); + } + + InlinedHashSet compatible_eps; + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(compatible_eps), + TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Div"] == 0); + ASSERT_TRUE(op_to_count["Add"] == 0); + ASSERT_TRUE(op_to_count["ReduceMean"] == 0); + ASSERT_TRUE(op_to_count["Pow"] == 0); + ASSERT_TRUE(op_to_count["Sqrt"] == 0); + ASSERT_TRUE(op_to_count["Cast"] == 0); + ASSERT_TRUE(op_to_count["SimplifiedLayerNormalization"] == 1); + + for (const Node& node : graph.Nodes()) { + if (node.OpType() == "SimplifiedLayerNormalization") { + // 
LayerNormalization should have two inputs. + EXPECT_EQ(node.InputDefs().size(), 2u) + << "LayerNormalization number of inputs does not equal to 2. Got:" << node.InputDefs().size(); + // LayerNormalization input "scale" and "bias" should have the same dimension. + const TensorShapeProto* scale_shape = node.InputDefs()[1]->Shape(); + EXPECT_EQ(scale_shape->dim_size(), 1) + << "LayerNormalization scale should be 1D. Got: " << scale_shape->dim_size(); + } else if (node.OpType() == "Cast") { + continue; + } else { + EXPECT_TRUE(false) << "Unexpected node " << node.Name(); + } + } +} + +static void TestSkipLayerNormFusion(const std::basic_string& file_path, int add_count, int ln_count, + int skip_ln_count, int cast_count, logging::Logger* logger) { + std::shared_ptr p_model; + ASSERT_TRUE(Model::Load(file_path, p_model, nullptr, *logger).IsOK()); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Div"] == 0); + ASSERT_TRUE(op_to_count["Add"] == add_count); + ASSERT_TRUE(op_to_count["Sub"] == 0); + ASSERT_TRUE(op_to_count["ReduceMean"] == 0); + ASSERT_TRUE(op_to_count["Pow"] == 0); + ASSERT_TRUE(op_to_count["Sqrt"] == 0); + ASSERT_TRUE(op_to_count["LayerNormalization"] == ln_count); + ASSERT_TRUE(op_to_count["com.microsoft.SkipLayerNormalization"] == skip_ln_count); + ASSERT_TRUE(op_to_count["Cast"] == cast_count); +} + +TEST_F(GraphTransformationTests, SkipLayerNormFusionTest) { + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1.onnx", 0, 0, 1, 0, logger_.get()); + TestSkipLayerNormFusion(MODEL_FOLDER 
"fusion/skip_layer_norm_format2.onnx", 0, 0, 1, 0, logger_.get()); + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3.onnx", 0, 0, 1, 0, logger_.get()); + + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1_partial.onnx", 1, 0, 1, 0, logger_.get()); + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2_partial.onnx", 1, 0, 1, 0, logger_.get()); + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3_no_fusion.onnx", 1, 1, 0, 0, logger_.get()); + + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1_graph_output.onnx", 1, 0, 1, 0, logger_.get()); + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2_graph_output.onnx", 1, 0, 1, 0, logger_.get()); + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3_graph_output.onnx", 1, 1, 0, 0, logger_.get()); +} + +TEST_F(GraphTransformationTests, SkipLayerNormFusionWithCastTest) { + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1_with_cast.onnx", 0, 0, 1, 3, logger_.get()); + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2_with_cast.onnx", 0, 0, 1, 3, logger_.get()); + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3_with_cast.onnx", 0, 0, 1, 2, logger_.get()); + + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1_partial_with_cast.onnx", 1, 0, 1, 2, logger_.get()); + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2_partial_with_cast.onnx", 1, 0, 1, 2, logger_.get()); + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3_no_fusion_with_cast.onnx", 1, 1, 0, 0, logger_.get()); + + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1_graph_output_with_cast.onnx", 1, 0, 1, 2, logger_.get()); + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2_graph_output_with_cast.onnx", 1, 0, 1, 2, logger_.get()); + TestSkipLayerNormFusion(MODEL_FOLDER 
"fusion/skip_layer_norm_format3_graph_output_with_cast.onnx", 1, 1, 0, 0, logger_.get()); +} + +static void TestSkipLayerNormFusionInputOutputCheck(const std::basic_string& model_uri, bool with_cast, logging::Logger* logger) { + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger)); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger)); + + for (Node& node : graph.Nodes()) { + if (node.OpType() == "SkipLayerNormalization") { + // check inputs + std::vector& input_defs = node.MutableInputDefs(); + EXPECT_EQ(input_defs.size(), 5u) << "SkipLayerNormalization number of inputs does not equal to 5. Got:" << node.InputDefs().size(); + EXPECT_EQ(input_defs[0]->Name(), ((with_cast) ? "input.1_Float" : "input.1")); + EXPECT_EQ(input_defs[1]->Name(), ((with_cast) ? "6_Float" : "6")); + EXPECT_EQ(input_defs[2]->Name(), "1"); + EXPECT_EQ(input_defs[3]->Name(), "2"); + EXPECT_EQ(input_defs[4]->Name(), ((with_cast) ? "4_Float" : "4")); + + // check outputs + std::vector& output_defs = node.MutableOutputDefs(); +#ifdef ENABLE_TRAINING_CORE + EXPECT_EQ(node.OutputDefs().size(), 3u) << "SkipLayerNormalization number of outputs does not equal to 3. Got:" << node.OutputDefs().size(); +#else + EXPECT_EQ(node.OutputDefs().size(), 1u) << "SkipLayerNormalization number of outputs does not equal to 1. 
Got:" << node.OutputDefs().size(); +#endif + EXPECT_EQ(output_defs[0]->Name(), "19"); + } else if (node.OpType() == "Cast") { + EXPECT_TRUE(with_cast) << "Unexpected node: " << node.OpType() << "," << node.Name(); + } else { + EXPECT_EQ(node.OpType(), "MatMul") << "Unexpected node: " << node.OpType() << "," << node.Name(); + } + } +} + +TEST_F(GraphTransformationTests, SkipLayerNormFusion_Input_Output_Check) { + TestSkipLayerNormFusionInputOutputCheck(MODEL_FOLDER "fusion/skip_layer_norm_input_output_check.onnx", false, logger_.get()); + TestSkipLayerNormFusionInputOutputCheck(MODEL_FOLDER "fusion/skip_layer_norm_input_output_with_cast_check.onnx", true, logger_.get()); +} + +static void TestSkipLayerNormFusionNoBeta(const std::basic_string& model_uri, bool with_cast, logging::Logger* logger) { + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger)); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Add"] == 0); + ASSERT_TRUE(op_to_count["LayerNormalization"] == 0); + ASSERT_TRUE(op_to_count["com.microsoft.SkipLayerNormalization"] == 1); + ASSERT_TRUE(op_to_count["Cast"] == ((with_cast) ? 
2 : 0)); +} + +TEST_F(GraphTransformationTests, SkipLayerNormFusion_NoBeta) { + TestSkipLayerNormFusionNoBeta(MODEL_FOLDER "fusion/skip_layer_norm_no_beta.onnx", false, logger_.get()); + TestSkipLayerNormFusionNoBeta(MODEL_FOLDER "fusion/skip_layer_norm_no_beta_with_cast.onnx", true, logger_.get()); +} + +TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat1) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format1.onnx"; + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Gather"] == 0); + ASSERT_TRUE(op_to_count["Add"] == 0); + ASSERT_TRUE(op_to_count["ReduceSum"] == 1); + ASSERT_TRUE(op_to_count["com.microsoft.Attention"] == 1); + ASSERT_TRUE(op_to_count["com.microsoft.SkipLayerNormalization"] == 0); + ASSERT_TRUE(op_to_count["com.microsoft.EmbedLayerNormalization"] == 1); +} + +TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat2) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format2.onnx"; + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Shape"] == 0); + ASSERT_TRUE(op_to_count["Expand"] == 0); + 
ASSERT_TRUE(op_to_count["Gather"] == 0); + ASSERT_TRUE(op_to_count["Unsqueeze"] == 0); + ASSERT_TRUE(op_to_count["ConstantOfShape"] == 0); + ASSERT_TRUE(op_to_count["NonZero"] == 0); + ASSERT_TRUE(op_to_count["Transpose"] == 0); + ASSERT_TRUE(op_to_count["Squeeze"] == 0); + ASSERT_TRUE(op_to_count["Add"] == 0); + ASSERT_TRUE(op_to_count["ReduceSum"] == 1); + ASSERT_TRUE(op_to_count["com.microsoft.Attention"] == 1); + ASSERT_TRUE(op_to_count["com.microsoft.SkipLayerNormalization"] == 0); + ASSERT_TRUE(op_to_count["com.microsoft.EmbedLayerNormalization"] == 1); +} + +static void EmbedLayerNormFusionFormat3(const std::basic_string& file_path, logging::Logger* logger) { + std::shared_ptr p_model; + ASSERT_TRUE(Model::Load(file_path, p_model, nullptr, *logger).IsOK()); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger)); + + std::map op_to_count = CountOpsInGraph(graph); + EXPECT_EQ(op_to_count["Shape"], 0); + EXPECT_EQ(op_to_count["Expand"], 0); + EXPECT_EQ(op_to_count["Gather"], 0); + EXPECT_EQ(op_to_count["Unsqueeze"], 0); + EXPECT_EQ(op_to_count["LayerNormalization"], 0); + EXPECT_EQ(op_to_count["com.microsoft.SkipLayerNormalization"], 0); + EXPECT_EQ(op_to_count["ReduceSum"], 1); + EXPECT_EQ(op_to_count["MatMul"], 1); + EXPECT_EQ(op_to_count["Add"], 2); + EXPECT_EQ(op_to_count["Cast"], 3); + EXPECT_EQ(op_to_count["com.microsoft.Attention"], 1); + EXPECT_EQ(op_to_count["com.microsoft.EmbedLayerNormalization"], 1); +} + +TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat3) { + EmbedLayerNormFusionFormat3(MODEL_FOLDER "fusion/embed_layer_norm_format3.onnx", logger_.get()); +} + +TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat3_OpSet13) { + EmbedLayerNormFusionFormat3(MODEL_FOLDER 
"fusion/embed_layer_norm_format3_opset13.onnx", logger_.get()); +} + +static void EmbedLayerNormFusionFormat3NoCast(const std::basic_string& file_path, logging::Logger* logger) { + std::shared_ptr p_model; + ASSERT_TRUE(Model::Load(file_path, p_model, nullptr, *logger).IsOK()); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger)); + + std::map op_to_count = CountOpsInGraph(graph); + EXPECT_EQ(op_to_count["Shape"], 0); + EXPECT_EQ(op_to_count["Expand"], 0); + EXPECT_EQ(op_to_count["Gather"], 0); + EXPECT_EQ(op_to_count["Unsqueeze"], 0); + EXPECT_EQ(op_to_count["LayerNormalization"], 0); + EXPECT_EQ(op_to_count["com.microsoft.SkipLayerNormalization"], 0); + EXPECT_EQ(op_to_count["ReduceSum"], 1); + EXPECT_EQ(op_to_count["MatMul"], 1); + EXPECT_EQ(op_to_count["Add"], 2); + EXPECT_EQ(op_to_count["Cast"], 3); + EXPECT_EQ(op_to_count["com.microsoft.Attention"], 1); + EXPECT_EQ(op_to_count["com.microsoft.EmbedLayerNormalization"], 1); +} + +TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat3NoCast) { + EmbedLayerNormFusionFormat3NoCast(MODEL_FOLDER "fusion/embed_layer_norm_format3_no_cast.onnx", logger_.get()); +} + +TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat3NoCast_OpSet13) { + EmbedLayerNormFusionFormat3NoCast(MODEL_FOLDER "fusion/embed_layer_norm_format3_no_cast_opset13.onnx", logger_.get()); +} + +TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat4) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format4.onnx"; + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + 
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Shape"] == 0); + ASSERT_TRUE(op_to_count["Expand"] == 0); + ASSERT_TRUE(op_to_count["Gather"] == 0); + ASSERT_TRUE(op_to_count["Concat"] == 0); + ASSERT_TRUE(op_to_count["Unsqueeze"] == 0); + ASSERT_TRUE(op_to_count["ConstantOfShape"] == 0); + ASSERT_TRUE(op_to_count["NonZero"] == 0); + ASSERT_TRUE(op_to_count["Transpose"] == 0); + ASSERT_TRUE(op_to_count["Squeeze"] == 0); + ASSERT_TRUE(op_to_count["Add"] == 0); + ASSERT_TRUE(op_to_count["ReduceSum"] == 1); + ASSERT_TRUE(op_to_count["com.microsoft.Attention"] == 1); + ASSERT_TRUE(op_to_count["com.microsoft.SkipLayerNormalization"] == 0); + ASSERT_TRUE(op_to_count["com.microsoft.EmbedLayerNormalization"] == 1); +} + +static void EmbedLayerNormFusionFormat5(const std::basic_string& file_path, logging::Logger* logger) { + std::shared_ptr p_model; + ASSERT_TRUE(Model::Load(file_path, p_model, nullptr, *logger).IsOK()); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger)); + + std::map op_to_count = CountOpsInGraph(graph); + EXPECT_EQ(op_to_count["Gather"], 0); + EXPECT_EQ(op_to_count["LayerNormalization"], 0); + EXPECT_EQ(op_to_count["com.microsoft.SkipLayerNormalization"], 0); + EXPECT_EQ(op_to_count["ReduceSum"], 1); + EXPECT_EQ(op_to_count["MatMul"], 1); + EXPECT_EQ(op_to_count["Add"], 2); + EXPECT_EQ(op_to_count["Cast"], 3); + EXPECT_EQ(op_to_count["com.microsoft.Attention"], 1); + EXPECT_EQ(op_to_count["com.microsoft.EmbedLayerNormalization"], 1); + + // Validate the position 
embedding input. + for (const Node& node : graph.Nodes()) { + if (node.OpType() == "EmbedLayerNormalization") { + const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[3]->Name()); + ASSERT_TRUE(tensor_proto != nullptr); + EXPECT_EQ(tensor_proto->data_type(), ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + + auto initializer = std::make_unique(*tensor_proto, graph.ModelPath()); + EXPECT_EQ(initializer->size(), 12); + + std::vector expected_value = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 8.0, 7.0, 6.0}; + + const float* data = initializer->data(); + for (size_t i = 0; i < expected_value.size(); i++) { + EXPECT_EQ(data[i], static_cast(expected_value[i])); + } + } + } +} + +TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat5) { + EmbedLayerNormFusionFormat5(MODEL_FOLDER "fusion/embed_layer_norm_format5.onnx", logger_.get()); +} + +TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat5_OpSet13) { + EmbedLayerNormFusionFormat5(MODEL_FOLDER "fusion/embed_layer_norm_format5_opset13.onnx", logger_.get()); +} + +static void EmbedLayerNormFusionFormat6(const std::basic_string& file_path, logging::Logger* logger) { + std::shared_ptr p_model; + ASSERT_TRUE(Model::Load(file_path, p_model, nullptr, *logger).IsOK()); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger)); + + std::map op_to_count = CountOpsInGraph(graph); + EXPECT_EQ(op_to_count["Shape"], 0); + EXPECT_EQ(op_to_count["Expand"], 0); + EXPECT_EQ(op_to_count["Gather"], 0); + EXPECT_EQ(op_to_count["Unsqueeze"], 0); + EXPECT_EQ(op_to_count["Reshape"], 0); + EXPECT_EQ(op_to_count["Equal"], 0); + EXPECT_EQ(op_to_count["Where"], 0); + EXPECT_EQ(op_to_count["LayerNormalization"], 0); + 
EXPECT_EQ(op_to_count["com.microsoft.SkipLayerNormalization"], 0); + EXPECT_EQ(op_to_count["ReduceSum"], 1); + EXPECT_EQ(op_to_count["MatMul"], 1); + EXPECT_EQ(op_to_count["Add"], 2); + EXPECT_EQ(op_to_count["Cast"], 3); + EXPECT_EQ(op_to_count["com.microsoft.Attention"], 1); + EXPECT_EQ(op_to_count["com.microsoft.EmbedLayerNormalization"], 1); +} + +TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat6) { + EmbedLayerNormFusionFormat6(MODEL_FOLDER "fusion/embed_layer_norm_format6.onnx", logger_.get()); +} + +TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat6_OpSet13) { + EmbedLayerNormFusionFormat6(MODEL_FOLDER "fusion/embed_layer_norm_format6_opset13.onnx", logger_.get()); +} + +static void TestEmbedLayerNormFusionDistilBert(const std::basic_string& model_uri, + std::map& op_to_count, + logging::Logger* logger) { + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger)); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger)); + + op_to_count = CountOpsInGraph(graph); +} + +// DistilBert +TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat7) { + std::map op_to_count; + TestEmbedLayerNormFusionDistilBert(MODEL_FOLDER "fusion/embed_layer_norm_format7.onnx", op_to_count, logger_.get()); + EXPECT_EQ(op_to_count["com.microsoft.EmbedLayerNormalization"], 1); + EXPECT_EQ(op_to_count["com.microsoft.Attention"], 1); + EXPECT_EQ(op_to_count["Cast"], 2); + EXPECT_EQ(op_to_count["Shape"], 0); + EXPECT_EQ(op_to_count["Gather"], 0); + EXPECT_EQ(op_to_count["Unsqueeze"], 0); + EXPECT_EQ(op_to_count["ReduceSum"], 1); +} + +TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat7_OpSet13) { + std::map op_to_count; + 
TestEmbedLayerNormFusionDistilBert(MODEL_FOLDER "fusion/embed_layer_norm_format7_opset13.onnx", op_to_count, logger_.get()); + EXPECT_EQ(op_to_count["com.microsoft.EmbedLayerNormalization"], 1); + EXPECT_EQ(op_to_count["com.microsoft.Attention"], 1); + EXPECT_EQ(op_to_count["Cast"], 2); + EXPECT_EQ(op_to_count["Shape"], 0); + EXPECT_EQ(op_to_count["Gather"], 0); + EXPECT_EQ(op_to_count["Unsqueeze"], 0); + EXPECT_EQ(op_to_count["ReduceSum"], 1); +} + +TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat8) { + std::map op_to_count; + TestEmbedLayerNormFusionDistilBert(MODEL_FOLDER "fusion/embed_layer_norm_format8.onnx", op_to_count, logger_.get()); + EXPECT_EQ(op_to_count["com.microsoft.EmbedLayerNormalization"], 1); + EXPECT_EQ(op_to_count["com.microsoft.Attention"], 1); + EXPECT_EQ(op_to_count["Cast"], 2); + EXPECT_EQ(op_to_count["Shape"], 0); + EXPECT_EQ(op_to_count["Gather"], 0); + EXPECT_EQ(op_to_count["Unsqueeze"], 0); + EXPECT_EQ(op_to_count["ReduceSum"], 1); +} + +TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat8_OpSet13) { + std::map op_to_count; + TestEmbedLayerNormFusionDistilBert(MODEL_FOLDER "fusion/embed_layer_norm_format8_opset13.onnx", op_to_count, logger_.get()); + EXPECT_EQ(op_to_count["com.microsoft.EmbedLayerNormalization"], 1); + EXPECT_EQ(op_to_count["com.microsoft.Attention"], 1); + EXPECT_EQ(op_to_count["Cast"], 2); + EXPECT_EQ(op_to_count["Shape"], 0); + EXPECT_EQ(op_to_count["Gather"], 0); + EXPECT_EQ(op_to_count["Unsqueeze"], 0); + EXPECT_EQ(op_to_count["ReduceSum"], 1); +} + +TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat9) { + std::map op_to_count; + TestEmbedLayerNormFusionDistilBert(MODEL_FOLDER "fusion/embed_layer_norm_format9.onnx", op_to_count, logger_.get()); + EXPECT_EQ(op_to_count["com.microsoft.EmbedLayerNormalization"], 1); + EXPECT_EQ(op_to_count["com.microsoft.Attention"], 1); + EXPECT_EQ(op_to_count["Cast"], 2); + EXPECT_EQ(op_to_count["Shape"], 1); + EXPECT_EQ(op_to_count["Gather"], 2); + 
EXPECT_EQ(op_to_count["Unsqueeze"], 2); + EXPECT_EQ(op_to_count["ReduceSum"], 1); +} + +TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat9_OpSet13) { + std::map op_to_count; + TestEmbedLayerNormFusionDistilBert(MODEL_FOLDER "fusion/embed_layer_norm_format9_opset13.onnx", op_to_count, logger_.get()); + EXPECT_EQ(op_to_count["com.microsoft.EmbedLayerNormalization"], 1); + EXPECT_EQ(op_to_count["com.microsoft.Attention"], 1); + EXPECT_EQ(op_to_count["Cast"], 2); + EXPECT_EQ(op_to_count["Shape"], 1); + EXPECT_EQ(op_to_count["Gather"], 2); + EXPECT_EQ(op_to_count["Unsqueeze"], 2); + EXPECT_EQ(op_to_count["ReduceSum"], 1); +} + +static void EmbedLayerNormFusionFormatMultiple(const std::basic_string& file_path, logging::Logger* logger) { + std::shared_ptr p_model; + ASSERT_TRUE(Model::Load(file_path, p_model, nullptr, *logger).IsOK()); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger)); + + std::map op_to_count = CountOpsInGraph(graph); + EXPECT_EQ(op_to_count["Shape"], 0); + EXPECT_EQ(op_to_count["Expand"], 0); + EXPECT_EQ(op_to_count["Gather"], 0); + EXPECT_EQ(op_to_count["Unsqueeze"], 0); + EXPECT_EQ(op_to_count["LayerNormalization"], 0); + EXPECT_EQ(op_to_count["com.microsoft.SkipLayerNormalization"], 0); + EXPECT_EQ(op_to_count["ReduceSum"], 2); + EXPECT_EQ(op_to_count["MatMul"], 2); + EXPECT_EQ(op_to_count["Add"], 5); + EXPECT_EQ(op_to_count["Cast"], 6); + EXPECT_EQ(op_to_count["com.microsoft.Attention"], 2); + EXPECT_EQ(op_to_count["com.microsoft.EmbedLayerNormalization"], 2); +} + +TEST_F(GraphTransformationTests, EmbedLayerNormFusionMultiple) { + EmbedLayerNormFusionFormatMultiple(MODEL_FOLDER "fusion/embed_layer_norm_multiple.onnx", logger_.get()); +} + +TEST_F(GraphTransformationTests, 
EmbedLayerNormFusionMultiple_OpSet13) { + EmbedLayerNormFusionFormatMultiple(MODEL_FOLDER "fusion/embed_layer_norm_multiple_opset13.onnx", logger_.get()); +} + +} // namespace test +} // namespace onnxruntime