diff --git a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc index 0706040b7d..0419a54bd0 100755 --- a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc +++ b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc @@ -72,6 +72,21 @@ TEST_F(GraphTransformationTests, LayerNormFusionTest) { } } +TEST_F(GraphTransformationTests, TwoLayerNormShareSameInput) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/layer_norm_shared_input.onnx"; + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count.size() == 1); + ASSERT_TRUE(op_to_count["LayerNormalization"] == 2); +} + TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/layer_norm_with_cast.onnx"; std::shared_ptr p_model; diff --git a/onnxruntime/test/testdata/transform/fusion/layer_norm_shared_input.onnx b/onnxruntime/test/testdata/transform/fusion/layer_norm_shared_input.onnx new file mode 100644 index 0000000000..ae13881d68 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/layer_norm_shared_input.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/layer_norm_shared_input.py b/onnxruntime/test/testdata/transform/fusion/layer_norm_shared_input.py new file mode 100644 index 0000000000..0b838ca88e --- /dev/null +++ b/onnxruntime/test/testdata/transform/fusion/layer_norm_shared_input.py @@ -0,0 +1,71 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +import onnx +from onnx import OperatorSetIdProto, TensorProto, helper + + +# in gpt_j_residual, there will be 2 LN share the same input +def GenerateModel(model_name): # noqa: N802 + nodes = [ + # LN1 subgraph + helper.make_node("ReduceMean", ["A"], ["LN1/rd1_out"], "LN1/reduce", axes=[-1]), + helper.make_node("Sub", ["A", "LN1/rd1_out"], ["LN1/sub1_out"], "LN1/sub"), + helper.make_node("Pow", ["LN1/sub1_out", "LN1/pow_in_2"], ["LN1/pow_out"], "LN1/pow"), + helper.make_node("ReduceMean", ["LN1/pow_out"], ["LN1/rd2_out"], "LN1/reduce2", axes=[-1]), + helper.make_node("Add", ["LN1/rd2_out", "LN1/const_0"], ["LN1/add1_out"], "LN1/add"), + helper.make_node("Sqrt", ["LN1/add1_out"], ["LN1/sqrt_out"], "LN1/sqrt"), + helper.make_node("Div", ["LN1/sub1_out", "LN1/sqrt_out"], ["LN1/div_out"], "LN1/div"), + helper.make_node("Mul", ["LN1/gamma", "LN1/div_out"], ["LN1/mul_out"], "LN1/mul"), + helper.make_node("Add", ["LN1/beta", "LN1/mul_out"], ["LN1/C"], "LN1/add2"), + # LN2 subgraph + helper.make_node("ReduceMean", ["A"], ["LN2/rd1_out"], "LN2/reduce", axes=[-1]), + helper.make_node("Sub", ["A", "LN2/rd1_out"], ["LN2/sub1_out"], "LN2/sub"), + helper.make_node("Pow", ["LN2/sub1_out", "LN2/pow_in_2"], ["LN2/pow_out"], "LN2/pow"), + helper.make_node("ReduceMean", ["LN2/pow_out"], ["LN2/rd2_out"], "LN2/reduce2", axes=[-1]), + helper.make_node("Add", ["LN2/rd2_out", "LN2/const_0"], ["LN2/add1_out"], "LN2/add"), + helper.make_node("Sqrt", ["LN2/add1_out"], ["LN2/sqrt_out"], "LN2/sqrt"), + helper.make_node("Div", ["LN2/sub1_out", "LN2/sqrt_out"], ["LN2/div_out"], "LN2/div"), + helper.make_node("Mul", ["LN2/gamma", "LN2/div_out"], ["LN2/mul_out"], "LN2/mul"), + helper.make_node("Add", ["LN2/beta", "LN2/mul_out"], ["LN2/C"], "LN2/add2"), + ] + + initializers = [ + # LN1 initializers + helper.make_tensor("LN1/pow_in_2", TensorProto.FLOAT, [], [2]), + helper.make_tensor("LN1/const_0", TensorProto.FLOAT, [], [0]), + helper.make_tensor("LN1/gamma", TensorProto.FLOAT, [4], [1, 2, 3, 4]), + helper.make_tensor("LN1/beta", TensorProto.FLOAT, [4], [1, 2, 3, 4]), + # LN2 initializers + helper.make_tensor("LN2/pow_in_2", TensorProto.FLOAT, [], [2]), + helper.make_tensor("LN2/const_0", TensorProto.FLOAT, [], [0]), + helper.make_tensor("LN2/gamma", TensorProto.FLOAT, [4], [1, 2, 3, 4]), + helper.make_tensor("LN2/beta", TensorProto.FLOAT, [4], [1, 2, 3, 4]), + ] + + graph = helper.make_graph( + nodes, + "2LayerNormShareSameInput", # name + [ # inputs + helper.make_tensor_value_info("A", TensorProto.FLOAT, [16, 32, 4]), + ], + [ # outputs + helper.make_tensor_value_info("LN1/C", TensorProto.FLOAT, [16, 32, 4]), + helper.make_tensor_value_info("LN2/C", TensorProto.FLOAT, [16, 32, 4]), + ], + initializers, + ) + + onnxdomain = OperatorSetIdProto() + onnxdomain.version = 12 + # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. + onnxdomain.domain = "" + msdomain = OperatorSetIdProto() + msdomain.version = 1 + msdomain.domain = "com.microsoft" + opsets = [onnxdomain, msdomain] + + model = helper.make_model(graph, opset_imports=opsets) + onnx.save(model, model_name) + + +GenerateModel("layer_norm_shared_input.onnx") diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index 6cf850c57e..92816fe6e3 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -109,6 +109,8 @@ std::vector> GeneratePreTrainingTransformers( // CSE. For example, if A and B nodes both do Add operation with a same value but different initializers, by // default, CSE will not merge them, because the different initializers are represented by different NodeArg. transformers.emplace_back(std::make_unique(compatible_eps)); + // LayerNormFusion must be applied before CommonSubexpressionElimination as the latter will break the pattern when 2 LayerNormFusion share the same input. + transformers.emplace_back(std::make_unique(compatible_eps)); // Remove duplicate nodes. Must be applied before any recompute transformations. if (config.gelu_recompute || config.attn_dropout_recompute || config.transformer_layer_recompute) { transformers.emplace_back(std::make_unique(compatible_eps)); @@ -117,7 +119,6 @@ std::vector> GeneratePreTrainingTransformers( } transformers.emplace_back(std::make_unique(compatible_eps)); - transformers.emplace_back(std::make_unique(compatible_eps)); #if defined(USE_CUDA) || defined(USE_ROCM) transformers.emplace_back(std::make_unique(compatible_eps, true /* skip_device_check*/));