From d56fc7ebf5377abc96db728eafaffd8bf79a3b81 Mon Sep 17 00:00:00 2001 From: Abhishek Jindal Date: Thu, 21 Sep 2023 14:16:41 -0700 Subject: [PATCH] Layer norm fusion deepspeed stage3 changes (#17614) ### Description Layer norm fusion changes required for deepspeed stage 3, also includes test case. ### Motivation and Context It helps fusing layer norm for Deepspeed Stage 3. Added a test case scenario which ensures that the fusion is working properly for the scenario. --- .../core/optimizer/layer_norm_fusion.cc | 46 +++++----- .../graph_transform_test_layernorm.cc | 34 ++++++++ .../fusion/layer_norm_fusion_scale_bias.onnx | Bin 0 -> 854 bytes .../fusion/layer_norm_fusion_scale_bias.py | 81 ++++++++++++++++++ 4 files changed, 138 insertions(+), 23 deletions(-) create mode 100644 onnxruntime/test/testdata/transform/fusion/layer_norm_fusion_scale_bias.onnx create mode 100644 onnxruntime/test/testdata/transform/fusion/layer_norm_fusion_scale_bias.py diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.cc b/onnxruntime/core/optimizer/layer_norm_fusion.cc index bf36f11521..159e3b23d1 100644 --- a/onnxruntime/core/optimizer/layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/layer_norm_fusion.cc @@ -414,20 +414,20 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, NodeArg* scale = nullptr; NodeArg* bias = nullptr; for (size_t i = 0; i < mul_node.MutableInputDefs().size(); i++) { - if (graph_utils::NodeArgIsConstant(graph, *(mul_node.MutableInputDefs()[i])) || - graph_utils::IsGraphInput(graph, mul_node.MutableInputDefs()[i])) { - if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast(axes_values.size())) { - scale = mul_node.MutableInputDefs()[i]; - } + if (mul_node.MutableInputDefs()[i]->Shape() == nullptr) { + continue; + } + if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast(axes_values.size())) { + scale = mul_node.MutableInputDefs()[i]; } } for (size_t i = 0; i < last_add_node.MutableInputDefs().size(); i++) { - if (graph_utils::NodeArgIsConstant(graph, *(last_add_node.MutableInputDefs()[i])) || - graph_utils::IsGraphInput(graph, last_add_node.MutableInputDefs()[i])) { - if (last_add_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast(axes_values.size())) { - bias = last_add_node.MutableInputDefs()[i]; - } + if (last_add_node.MutableInputDefs()[i]->Shape() == nullptr) { + continue; + } + if (last_add_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast(axes_values.size())) { + bias = last_add_node.MutableInputDefs()[i]; } } if (scale == nullptr || bias == nullptr) { @@ -667,20 +667,20 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr // because SkipLayerNorm kernel, for example, has dependency on single dim size NodeArg* scale = nullptr; for (size_t i = 0; i < mul_node.MutableInputDefs().size(); i++) { - if (graph_utils::NodeArgIsConstant(graph, *(mul_node.MutableInputDefs()[i])) || - graph_utils::IsGraphInput(graph, mul_node.MutableInputDefs()[i])) { -#ifdef ENABLE_TRAINING_CORE - if (axes_values.empty() || - mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast(axes_values.size())) { - scale = mul_node.MutableInputDefs()[i]; - } -#else - // Scale must be 1d. - if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == 1) { - scale = mul_node.MutableInputDefs()[i]; - } -#endif + if (mul_node.MutableInputDefs()[i]->Shape() == nullptr) { + continue; } +#ifdef ENABLE_TRAINING_CORE + if (axes_values.empty() || + mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast(axes_values.size())) { + scale = mul_node.MutableInputDefs()[i]; + } +#else + // Scale must be 1d. + if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == 1) { + scale = mul_node.MutableInputDefs()[i]; + } +#endif } if (scale == nullptr) { diff --git a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc index 1f671e9009..a55238396c 100755 --- a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc +++ b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc @@ -429,6 +429,40 @@ TEST_F(GraphTransformationTests, SimplifiedLayerNormFusionTest) { } } +// It tests the scenario when scale or bias are not Graph Inputs and not initialized in Graph +// To test this added a Identity node after Scale and Bias terms to ensure LayerNormFusion works properly +TEST_F(GraphTransformationTests, LayerNormScaleBiasTest) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/layer_norm_fusion_scale_bias.onnx"; + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["ReduceMean"], 0); + ASSERT_EQ(op_to_count["Sub"], 0); + ASSERT_EQ(op_to_count["Cast"], 0); + ASSERT_EQ(op_to_count["Pow"], 0); + ASSERT_EQ(op_to_count["Add"], 0); + ASSERT_EQ(op_to_count["Sqrt"], 0); + ASSERT_EQ(op_to_count["Div"], 0); + ASSERT_EQ(op_to_count["Mul"], 0); + ASSERT_EQ(op_to_count["LayerNormalization"], 1); + + for (const Node& node : graph.Nodes()) { + if (node.OpType() == "LayerNormalization") { + // LayerNormalization should have three inputs. + EXPECT_EQ(node.InputDefs().size(), 3u) << "LayerNormalization number of inputs does not equal to 3. Got:" << node.InputDefs().size(); + // LayerNormalization input "scale" and "bias" should have the same dimension. + const TensorShapeProto* scale_shape = node.InputDefs()[1]->Shape(); + EXPECT_EQ(scale_shape->dim_size(), 1) << "LayerNormalization scale should be 1D. Got: " << scale_shape->dim_size(); + } + } +} + // If EP is non-GPU EP or unknown, the sub-graph will be not fused because CPU impl for SimplifiedLayerNormalization // doesn't support input and scale having different data types. TEST_F(GraphTransformationTests, SimplifiedLayerNormWithCastsFusionTest) { diff --git a/onnxruntime/test/testdata/transform/fusion/layer_norm_fusion_scale_bias.onnx b/onnxruntime/test/testdata/transform/fusion/layer_norm_fusion_scale_bias.onnx new file mode 100644 index 0000000000000000000000000000000000000000..ec0f9a97815b888701198d94c92e3f61c581d6dc GIT binary patch literal 854 zcmbVLO;6h}7_J-BxG#kTYZE93gnT22m1Yu$ooG5~nzW*c-nc|*?V^aLfqd|B%VCEd z_!0b!9ry|SC$N(kk+Bn&96fr!@;r}i(*63k1K$A+X(!=+oM*O~2%gWxfWb)##v)ic z9{~q9B0YN23*95r`2gfxhzlM@>6Q$%VOtJ@dJr|!d|FO4Bw)rQpTZvWWxz-=(HP>i32O%=i^w!!hU}H52Z>Cg;9~+&<_rur`aAl7<+#{``weNx=D_oR1Y^ z#*lN^ftN5P>1C2t1qv}dk>9s!bQLvucvY#9fEnMyE9gV-EQq4O4@;YCBkDS8M){&@ zkboKEd;zSzP?b?MvK3XgqUwV7nl}8kiFTXek@Vf^LOYAAqdEXhvhLB8 zJ7tgC=m2%NeOM_K(1s9uUCR>7EX-~h`N1m$Ln*TIxg<{C$gn@X&P!+h9YMQ4gIkdt z$4TT+3o+bkwT`@(TjOl1*yF?9p4U84XNzD99K0cy*C0`6R*RdWK*euV{6StN>vU7S p0tyxZ+JiQ+