mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Keep original name during fusion (#20097)
### Keep original name during fusion This could be helpful to know where the fused node coming from, I feel this is very useful when debugging the execution order issues between different transformer layers. For example: - A node named `/_original_module/model/layers.1/self_attn/MatMul/MatmulTransposeFusion//MatMulScaleFusion/` goes through two fusion paths in the 1st transformer layer - e.g. `MatmulTransposeFusion` and `MatMulScaleFusion`. - `/_original_module/model/layers.2/post_attention_layernorm/Mul_1/SimplifiedLayerNormFusion/` node is a fused node by `SimplifiedLayerNormFusion`. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
a9d9b083e4
commit
55f63a48ca
8 changed files with 13 additions and 13 deletions
|
|
@ -273,7 +273,7 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra
|
|||
split_initializer_proto.add_dims(static_cast<int64_t>(split_values.size()));
|
||||
split_initializer_proto.mutable_int64_data()->Add(split_values.begin(), split_values.end());
|
||||
NodeArg* split_initializer_arg = &graph_utils::AddInitializer(graph, split_initializer_proto);
|
||||
Node& split_node = graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for Fused Gather nodes",
|
||||
Node& split_node = graph.AddNode(nodes_to_fuse[0].get().Name() + "/GatherSliceToSplitFusion/", "Split", "Split for Fused Gather nodes",
|
||||
{graph.GetNodeArg(node_arg->Name()), split_initializer_arg}, split_outputs);
|
||||
split_node.AddAttribute("axis", axis);
|
||||
split_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType());
|
||||
|
|
|
|||
|
|
@ -75,7 +75,7 @@ Status GemmTransposeFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& m
|
|||
nodes_to_remove.push_back(output_node);
|
||||
}
|
||||
|
||||
Node& new_gemm_node = graph.AddNode(graph.GenerateNodeName(gemm_node.Name() + "_transformed"),
|
||||
Node& new_gemm_node = graph.AddNode(graph.GenerateNodeName(gemm_node.Name() + "/GemmTransposeFusion/"),
|
||||
gemm_node.OpType(),
|
||||
"Fused Gemm with Transpose",
|
||||
new_gemm_input_defs,
|
||||
|
|
|
|||
|
|
@ -455,7 +455,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
}
|
||||
|
||||
InlinedVector<NodeArg*> layer_norm_input_defs{x_input, scale, bias};
|
||||
Node& layer_norm_node = graph.AddNode(graph.GenerateNodeName("LayerNormalization"),
|
||||
Node& layer_norm_node = graph.AddNode(graph.GenerateNodeName(mul_node.Name() + "/LayerNormFusion/"),
|
||||
"LayerNormalization",
|
||||
"fused LayerNorm subgraphs ",
|
||||
layer_norm_input_defs,
|
||||
|
|
@ -705,7 +705,7 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr
|
|||
|
||||
InlinedVector<NodeArg*> layer_norm_input_defs{x_input, scale};
|
||||
Node& layer_norm_node =
|
||||
graph.AddNode(graph.GenerateNodeName("SimplifiedLayerNormalization"), "SimplifiedLayerNormalization",
|
||||
graph.AddNode(graph.GenerateNodeName(mul_node.Name() + "/SimplifiedLayerNormFusion/"), "SimplifiedLayerNormalization",
|
||||
"fused LayerNorm subgraphs ", layer_norm_input_defs, {}, {}, kOnnxDomain);
|
||||
|
||||
// Get constant "epsilon" from "Add" node if available. Else, default value will be used.
|
||||
|
|
|
|||
|
|
@ -245,7 +245,7 @@ Status ProcessNode(
|
|||
}
|
||||
|
||||
Node& matmul_scale_node = graph.AddNode(
|
||||
graph.GenerateNodeName(node.Name() + "_FusedMatMulAndScale"),
|
||||
graph.GenerateNodeName(node.Name() + "/MatMulScaleFusion/"),
|
||||
"FusedMatMul",
|
||||
"Fused MatMul and Scale",
|
||||
fused_node_inputs,
|
||||
|
|
|
|||
|
|
@ -154,14 +154,14 @@ static Node* ReorderCastAndTranspose(Graph& graph, Node* cast,
|
|||
const ONNX_NAMESPACE::TensorProto_DataType element_type =
|
||||
static_cast<ONNX_NAMESPACE::TensorProto_DataType>(cast_output->TypeAsProto()->tensor_type().elem_type());
|
||||
new_cast_output_type_proto.mutable_tensor_type()->set_elem_type(element_type);
|
||||
auto& new_cast_output = graph.GetOrCreateNodeArg(cast_output->Name() + "_transformed", &new_cast_output_type_proto);
|
||||
auto& new_cast_output = graph.GetOrCreateNodeArg(cast_output->Name() + "/MatmulTransposeFusion/", &new_cast_output_type_proto);
|
||||
|
||||
const std::array new_cast_input_defs{transpose_input};
|
||||
const std::array new_cast_output_defs{&new_cast_output};
|
||||
const std::array new_transpose_input_defs = {&new_cast_output};
|
||||
const std::array new_transpose_output_defs = {cast_output};
|
||||
|
||||
Node& new_cast = graph.AddNode(graph.GenerateNodeName(cast->Name() + "_transformed"),
|
||||
Node& new_cast = graph.AddNode(graph.GenerateNodeName(cast->Name() + "/MatmulTransposeFusion/"),
|
||||
cast->OpType(),
|
||||
"Created a new Cast node to interchange Cast and Transpose nodes",
|
||||
new_cast_input_defs,
|
||||
|
|
@ -385,7 +385,7 @@ Status MatmulTransposeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_
|
|||
const std::array input_defs{left_input, right_input};
|
||||
const std::array output_defs{node.MutableOutputDefs()[0]};
|
||||
|
||||
Node& matmul_node = graph.AddNode(graph.GenerateNodeName("MatMul_With_Transpose"),
|
||||
Node& matmul_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "/MatmulTransposeFusion/"),
|
||||
"FusedMatMul",
|
||||
"fused MatMul and Transpose ",
|
||||
input_defs,
|
||||
|
|
|
|||
|
|
@ -88,7 +88,7 @@ Status QuickGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
|
||||
NodeArg* quick_gelu_output_arg = mul_node.MutableOutputDefs()[0];
|
||||
Node& quick_gelu_node =
|
||||
graph.AddNode(graph.GenerateNodeName("QuickGelu"), "QuickGelu", "QuickGelu", std::array{quick_gelu_input_arg},
|
||||
graph.AddNode(graph.GenerateNodeName(mul_node.Name() + "/QuickGeluFusion/"), "QuickGelu", "QuickGelu", std::array{quick_gelu_input_arg},
|
||||
std::array{quick_gelu_output_arg}, {}, kMSDomain);
|
||||
quick_gelu_node.AddAttribute("alpha", alpha);
|
||||
quick_gelu_node.SetExecutionProviderType(node.GetExecutionProviderType());
|
||||
|
|
|
|||
|
|
@ -2724,7 +2724,7 @@ TEST_F(GraphTransformationTests, GemmTransposeFusion2OutputsFromTranspose) {
|
|||
auto gemm_node =
|
||||
std::find_if(
|
||||
graph.Nodes().cbegin(), graph.Nodes().cend(),
|
||||
[](const Node& node) { return node.Name() == "Gemm_transformed"; });
|
||||
[](const Node& node) { return node.Name() == "Gemm/GemmTransposeFusion/"; });
|
||||
|
||||
auto& node = *gemm_node;
|
||||
ASSERT_TRUE(node.OpType() == "Gemm");
|
||||
|
|
@ -2760,7 +2760,7 @@ TEST_F(GraphTransformationTests, GemmTransposeFusion2OutputsFromTransposeTo2Gemm
|
|||
auto gemm1_node =
|
||||
std::find_if(
|
||||
graph.Nodes().cbegin(), graph.Nodes().cend(),
|
||||
[](const Node& node) { return node.Name() == "Gemm1_transformed"; });
|
||||
[](const Node& node) { return node.Name() == "Gemm1/GemmTransposeFusion/"; });
|
||||
|
||||
auto& node1 = *gemm1_node;
|
||||
ASSERT_TRUE(node1.OpType() == "Gemm");
|
||||
|
|
@ -2773,7 +2773,7 @@ TEST_F(GraphTransformationTests, GemmTransposeFusion2OutputsFromTransposeTo2Gemm
|
|||
auto gemm2_node =
|
||||
std::find_if(
|
||||
graph.Nodes().cbegin(), graph.Nodes().cend(),
|
||||
[](const Node& node) { return node.Name() == "Gemm2_transformed"; });
|
||||
[](const Node& node) { return node.Name() == "Gemm2/GemmTransposeFusion/"; });
|
||||
|
||||
auto& node2 = *gemm2_node;
|
||||
ASSERT_TRUE(node2.OpType() == "Gemm");
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ Status ConcatReplacement::Apply(Graph& graph, Node& concat_node, RewriteRuleEffe
|
|||
|
||||
concat_outputs.push_back(&ip_shape_op);
|
||||
|
||||
Node& concat_training_node = graph.AddNode(graph.GenerateNodeName("ConcatTraining"),
|
||||
Node& concat_training_node = graph.AddNode(graph.GenerateNodeName(concat_node.Name() + "/ConcatReplacement/"),
|
||||
"ConcatTraining",
|
||||
"Concat with extra output",
|
||||
concat_inputs,
|
||||
|
|
|
|||
Loading…
Reference in a new issue