From 94ce1209f9ce1f0efeb76f817fcb0efd63604889 Mon Sep 17 00:00:00 2001
From: Caroline Zhu
Date: Mon, 3 Jun 2024 14:41:39 -0700
Subject: [PATCH] Bug fix for gather fusion with on-device training (#20891)

### Description
Update the initializer added in GatherSliceToSplitFusion to take its name from GenerateNodeArgName rather than GenerateNodeName.

GenerateNodeName goes through all the nodes in the graph to see whether the given name is already used and generates a unique one if it is. GenerateNodeArgName performs the same check against all the node args in the graph. The initializer is a node arg, so GenerateNodeArgName is the check that detects a clash with an existing initializer.

### Motivation and Context
* On-device training goes through a generate-artifacts step, where optimizations are applied; when the training artifact is later loaded, additional optimizations are applied. In the first round of optimizations, a "splits" initializer is added for phi-3. In the second round, another "splits" initializer with different dimensions and data is added. Because the name came from GenerateNodeName, which only looks at node names, the first "splits" initializer was not found and the name was reused. This caused a type error reporting that the shape of "splits" does not match the TensorProto shape.
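
The name clash described above can be reproduced with a toy model. The sketch below is an illustration only: `ToyGraph`, `UniqueNodeName`, `UniqueNodeArgName`, `AddInitializer`, and the `_1` suffix scheme are hypothetical stand-ins, not ONNX Runtime's `Graph` API or its actual naming scheme. It assumes only what the description states, namely that node names and node-arg names are checked separately.

```cpp
#include <string>
#include <unordered_set>

// Toy stand-in for a graph. This is NOT onnxruntime::Graph; it only models
// two separate name sets: one for nodes, one for node args (initializers).
struct ToyGraph {
  std::unordered_set<std::string> node_names;
  std::unordered_set<std::string> node_arg_names;

  // Returns `base` if it is unused in `used`, otherwise base_1, base_2, ...
  // The suffix format is an assumption made for this sketch.
  static std::string MakeUnique(const std::unordered_set<std::string>& used,
                                const std::string& base) {
    if (used.count(base) == 0) return base;
    for (int i = 1;; ++i) {
      std::string candidate = base + "_" + std::to_string(i);
      if (used.count(candidate) == 0) return candidate;
    }
  }

  // Hypothetical analogue of GenerateNodeName: consults node names only.
  std::string UniqueNodeName(const std::string& base) const {
    return MakeUnique(node_names, base);
  }

  // Hypothetical analogue of GenerateNodeArgName: consults node-arg names.
  std::string UniqueNodeArgName(const std::string& base) const {
    return MakeUnique(node_arg_names, base);
  }

  // Registers an initializer name; returns false if it is already taken.
  bool AddInitializer(const std::string& name) {
    return node_arg_names.insert(name).second;
  }
};
```

In this model, a first pass that calls `AddInitializer(UniqueNodeName("splits"))` succeeds, and a second pass making the same call gets "splits" back again and collides with the existing initializer. Switching the second pass to `UniqueNodeArgName("splits")` yields a fresh name, which mirrors the one-line change in this patch.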
---
 onnxruntime/core/optimizer/gather_fusion.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/onnxruntime/core/optimizer/gather_fusion.cc b/onnxruntime/core/optimizer/gather_fusion.cc
index 1f2b31526c..2bde320786 100644
--- a/onnxruntime/core/optimizer/gather_fusion.cc
+++ b/onnxruntime/core/optimizer/gather_fusion.cc
@@ -268,7 +268,7 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra
     }
 
     ONNX_NAMESPACE::TensorProto split_initializer_proto;
-    split_initializer_proto.set_name(graph.GenerateNodeName("splits"));
+    split_initializer_proto.set_name(graph.GenerateNodeArgName("splits"));
     split_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
     split_initializer_proto.add_dims(static_cast<int64_t>(split_values.size()));
     split_initializer_proto.mutable_int64_data()->Add(split_values.begin(), split_values.end());