mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
Support skiplayernorm fusion without beta in layernorm (#6617)
* support skiplayernorm fusion without beta in layernorm * use place holder * review comments
This commit is contained in:
parent
fd83e38dcf
commit
a7b6fc08f2
3 changed files with 20 additions and 2 deletions
|
|
@ -217,11 +217,13 @@ Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
|
|||
continue;
|
||||
}
|
||||
|
||||
NodeArg beta_place_holder("", nullptr);
|
||||
|
||||
// Get the inputs for the new SkipLayerNormalization node.
|
||||
std::vector<NodeArg*> skip_layer_norm_input_defs{p_add1->MutableInputDefs()[0],
|
||||
p_add1->MutableInputDefs()[1],
|
||||
ln_node.MutableInputDefs()[1],
|
||||
ln_node.MutableInputDefs()[2]};
|
||||
ln_node.MutableInputDefs().size() == 2 ? &beta_place_holder : ln_node.MutableInputDefs()[2]};
|
||||
|
||||
if (matched_format == Format::Format1) {
|
||||
skip_layer_norm_input_defs[0] = p_add2->MutableInputDefs()[0];
|
||||
|
|
@ -241,7 +243,6 @@ Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
|
|||
"fused SkipLayerNorm subgraphs ",
|
||||
skip_layer_norm_input_defs,
|
||||
ln_node.MutableOutputDefs(), {}, kMSDomain);
|
||||
|
||||
// Get attribute "epsilon" from "LayerNormalization" node if available. Else, default value
|
||||
// will be used.
|
||||
NodeAttributes ln_attrs = ln_node.GetAttributes();
|
||||
|
|
|
|||
|
|
@ -2751,6 +2751,23 @@ TEST_F(GraphTransformationTests, SkipLayerNormFusion_Input_Output_Check) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, SkipLayerNormFusion_NoBeta) {
|
||||
auto model_uri = MODEL_FOLDER "fusion/skip_layer_norm_no_beta.onnx";
|
||||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
|
||||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
graph_transformation_mgr.Register(onnxruntime::make_unique<SkipLayerNormFusion>(), TransformerLevel::Level2);
|
||||
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_);
|
||||
ASSERT_TRUE(ret.IsOK());
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_TRUE(op_to_count["Add"] == 0);
|
||||
ASSERT_TRUE(op_to_count["LayerNormalization"] == 0);
|
||||
ASSERT_TRUE(op_to_count["com.microsoft.SkipLayerNormalization"] == 1);
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat1) {
|
||||
auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format1.onnx";
|
||||
std::shared_ptr<Model> p_model;
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/transform/fusion/skip_layer_norm_no_beta.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/skip_layer_norm_no_beta.onnx
vendored
Normal file
Binary file not shown.
Loading…
Reference in a new issue