Keep original name during fusion (#20097)

### Keep original name during fusion

This makes it easy to see which original node a fused node came from,
which is very useful when debugging execution-order issues between
different transformer layers.

For example:

- A node named
`/_original_module/model/layers.1/self_attn/MatMul/MatmulTransposeFusion//MatMulScaleFusion/`
has gone through two fusion passes in the `layers.1` transformer layer:
first `MatmulTransposeFusion`, then `MatMulScaleFusion`.

- A node named
`/_original_module/model/layers.2/post_attention_layernorm/Mul_1/SimplifiedLayerNormFusion/`
is a fused node produced by `SimplifiedLayerNormFusion`.


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
pengwa 2024-03-28 08:40:34 +08:00 committed by GitHub
parent a9d9b083e4
commit 55f63a48ca
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 13 additions and 13 deletions

View file

@ -273,7 +273,7 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra
split_initializer_proto.add_dims(static_cast<int64_t>(split_values.size()));
split_initializer_proto.mutable_int64_data()->Add(split_values.begin(), split_values.end());
NodeArg* split_initializer_arg = &graph_utils::AddInitializer(graph, split_initializer_proto);
Node& split_node = graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for Fused Gather nodes",
Node& split_node = graph.AddNode(nodes_to_fuse[0].get().Name() + "/GatherSliceToSplitFusion/", "Split", "Split for Fused Gather nodes",
{graph.GetNodeArg(node_arg->Name()), split_initializer_arg}, split_outputs);
split_node.AddAttribute("axis", axis);
split_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType());

View file

@ -75,7 +75,7 @@ Status GemmTransposeFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& m
nodes_to_remove.push_back(output_node);
}
Node& new_gemm_node = graph.AddNode(graph.GenerateNodeName(gemm_node.Name() + "_transformed"),
Node& new_gemm_node = graph.AddNode(graph.GenerateNodeName(gemm_node.Name() + "/GemmTransposeFusion/"),
gemm_node.OpType(),
"Fused Gemm with Transpose",
new_gemm_input_defs,

View file

@ -455,7 +455,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
}
InlinedVector<NodeArg*> layer_norm_input_defs{x_input, scale, bias};
Node& layer_norm_node = graph.AddNode(graph.GenerateNodeName("LayerNormalization"),
Node& layer_norm_node = graph.AddNode(graph.GenerateNodeName(mul_node.Name() + "/LayerNormFusion/"),
"LayerNormalization",
"fused LayerNorm subgraphs ",
layer_norm_input_defs,
@ -705,7 +705,7 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr
InlinedVector<NodeArg*> layer_norm_input_defs{x_input, scale};
Node& layer_norm_node =
graph.AddNode(graph.GenerateNodeName("SimplifiedLayerNormalization"), "SimplifiedLayerNormalization",
graph.AddNode(graph.GenerateNodeName(mul_node.Name() + "/SimplifiedLayerNormFusion/"), "SimplifiedLayerNormalization",
"fused LayerNorm subgraphs ", layer_norm_input_defs, {}, {}, kOnnxDomain);
// Get constant "epsilon" from "Add" node if available. Else, default value will be used.

View file

@ -245,7 +245,7 @@ Status ProcessNode(
}
Node& matmul_scale_node = graph.AddNode(
graph.GenerateNodeName(node.Name() + "_FusedMatMulAndScale"),
graph.GenerateNodeName(node.Name() + "/MatMulScaleFusion/"),
"FusedMatMul",
"Fused MatMul and Scale",
fused_node_inputs,

View file

@ -154,14 +154,14 @@ static Node* ReorderCastAndTranspose(Graph& graph, Node* cast,
const ONNX_NAMESPACE::TensorProto_DataType element_type =
static_cast<ONNX_NAMESPACE::TensorProto_DataType>(cast_output->TypeAsProto()->tensor_type().elem_type());
new_cast_output_type_proto.mutable_tensor_type()->set_elem_type(element_type);
auto& new_cast_output = graph.GetOrCreateNodeArg(cast_output->Name() + "_transformed", &new_cast_output_type_proto);
auto& new_cast_output = graph.GetOrCreateNodeArg(cast_output->Name() + "/MatmulTransposeFusion/", &new_cast_output_type_proto);
const std::array new_cast_input_defs{transpose_input};
const std::array new_cast_output_defs{&new_cast_output};
const std::array new_transpose_input_defs = {&new_cast_output};
const std::array new_transpose_output_defs = {cast_output};
Node& new_cast = graph.AddNode(graph.GenerateNodeName(cast->Name() + "_transformed"),
Node& new_cast = graph.AddNode(graph.GenerateNodeName(cast->Name() + "/MatmulTransposeFusion/"),
cast->OpType(),
"Created a new Cast node to interchange Cast and Transpose nodes",
new_cast_input_defs,
@ -385,7 +385,7 @@ Status MatmulTransposeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_
const std::array input_defs{left_input, right_input};
const std::array output_defs{node.MutableOutputDefs()[0]};
Node& matmul_node = graph.AddNode(graph.GenerateNodeName("MatMul_With_Transpose"),
Node& matmul_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "/MatmulTransposeFusion/"),
"FusedMatMul",
"fused MatMul and Transpose ",
input_defs,

View file

@ -88,7 +88,7 @@ Status QuickGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
NodeArg* quick_gelu_output_arg = mul_node.MutableOutputDefs()[0];
Node& quick_gelu_node =
graph.AddNode(graph.GenerateNodeName("QuickGelu"), "QuickGelu", "QuickGelu", std::array{quick_gelu_input_arg},
graph.AddNode(graph.GenerateNodeName(mul_node.Name() + "/QuickGeluFusion/"), "QuickGelu", "QuickGelu", std::array{quick_gelu_input_arg},
std::array{quick_gelu_output_arg}, {}, kMSDomain);
quick_gelu_node.AddAttribute("alpha", alpha);
quick_gelu_node.SetExecutionProviderType(node.GetExecutionProviderType());

View file

@ -2724,7 +2724,7 @@ TEST_F(GraphTransformationTests, GemmTransposeFusion2OutputsFromTranspose) {
auto gemm_node =
std::find_if(
graph.Nodes().cbegin(), graph.Nodes().cend(),
[](const Node& node) { return node.Name() == "Gemm_transformed"; });
[](const Node& node) { return node.Name() == "Gemm/GemmTransposeFusion/"; });
auto& node = *gemm_node;
ASSERT_TRUE(node.OpType() == "Gemm");
@ -2760,7 +2760,7 @@ TEST_F(GraphTransformationTests, GemmTransposeFusion2OutputsFromTransposeTo2Gemm
auto gemm1_node =
std::find_if(
graph.Nodes().cbegin(), graph.Nodes().cend(),
[](const Node& node) { return node.Name() == "Gemm1_transformed"; });
[](const Node& node) { return node.Name() == "Gemm1/GemmTransposeFusion/"; });
auto& node1 = *gemm1_node;
ASSERT_TRUE(node1.OpType() == "Gemm");
@ -2773,7 +2773,7 @@ TEST_F(GraphTransformationTests, GemmTransposeFusion2OutputsFromTransposeTo2Gemm
auto gemm2_node =
std::find_if(
graph.Nodes().cbegin(), graph.Nodes().cend(),
[](const Node& node) { return node.Name() == "Gemm2_transformed"; });
[](const Node& node) { return node.Name() == "Gemm2/GemmTransposeFusion/"; });
auto& node2 = *gemm2_node;
ASSERT_TRUE(node2.OpType() == "Gemm");

View file

@ -23,7 +23,7 @@ Status ConcatReplacement::Apply(Graph& graph, Node& concat_node, RewriteRuleEffe
concat_outputs.push_back(&ip_shape_op);
Node& concat_training_node = graph.AddNode(graph.GenerateNodeName("ConcatTraining"),
Node& concat_training_node = graph.AddNode(graph.GenerateNodeName(concat_node.Name() + "/ConcatReplacement/"),
"ConcatTraining",
"Concat with extra output",
concat_inputs,