Support skiplayernorm fusion without beta in layernorm (#6617)

* support skiplayernorm fusion without beta in layernorm

* use place holder

* review comments
This commit is contained in:
Ye Wang 2021-02-10 17:50:10 -08:00 committed by GitHub
parent fd83e38dcf
commit a7b6fc08f2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 20 additions and 2 deletions

View file

@ -217,11 +217,13 @@ Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
continue;
}
NodeArg beta_place_holder("", nullptr);
// Get the inputs for the new SkipLayerNormalization node.
std::vector<NodeArg*> skip_layer_norm_input_defs{p_add1->MutableInputDefs()[0],
p_add1->MutableInputDefs()[1],
ln_node.MutableInputDefs()[1],
ln_node.MutableInputDefs()[2]};
ln_node.MutableInputDefs().size() == 2 ? &beta_place_holder : ln_node.MutableInputDefs()[2]};
if (matched_format == Format::Format1) {
skip_layer_norm_input_defs[0] = p_add2->MutableInputDefs()[0];
@ -241,7 +243,6 @@ Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
"fused SkipLayerNorm subgraphs ",
skip_layer_norm_input_defs,
ln_node.MutableOutputDefs(), {}, kMSDomain);
// Get attribute "epsilon" from "LayerNormalization" node if available. Else, default value
// will be used.
NodeAttributes ln_attrs = ln_node.GetAttributes();

View file

@ -2751,6 +2751,23 @@ TEST_F(GraphTransformationTests, SkipLayerNormFusion_Input_Output_Check) {
}
}
// Verifies that SkipLayerNormFusion still fires on the "no beta" test model
// (presumably a LayerNormalization node with only two inputs -- confirm against
// the .onnx fixture). After the Level2 transformers run, the graph must contain
// no Add and no LayerNormalization nodes and exactly one
// com.microsoft.SkipLayerNormalization node.
TEST_F(GraphTransformationTests, SkipLayerNormFusion_NoBeta) {
  auto model_uri = MODEL_FOLDER "fusion/skip_layer_norm_no_beta.onnx";
  std::shared_ptr<Model> p_model;
  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
  Graph& graph = p_model->MainGraph();

  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
  graph_transformation_mgr.Register(onnxruntime::make_unique<SkipLayerNormFusion>(), TransformerLevel::Level2);

  // Use ASSERT_STATUS_OK (as for Model::Load above) so a failing Status reports
  // its error message instead of a bare "false".
  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));

  // ASSERT_EQ prints the actual count on failure, unlike ASSERT_TRUE(a == b).
  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
  ASSERT_EQ(op_to_count["Add"], 0);
  ASSERT_EQ(op_to_count["LayerNormalization"], 0);
  ASSERT_EQ(op_to_count["com.microsoft.SkipLayerNormalization"], 1);
}
TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat1) {
auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format1.onnx";
std::shared_ptr<Model> p_model;