Bug fix for gather fusion with on-device training (#20891)

### Description
Update the initializer that's added in GatherSliceToSplitFusion to use
the GenerateNodeArgName function, rather than the GenerateNodeName
function.

GenerateNodeName checks the given name against all the nodes in the
graph and generates a unique one if it is already used.
GenerateNodeArgName performs the same check against all the node args
in the graph instead. An initializer is named as a node arg, so that is
the check that applies here.
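
To illustrate the difference between the two lookups, here is a minimal, self-contained sketch. It is not the onnxruntime `Graph` class: `ToyGraph`, `MakeUnique`, `GenerateNodeNameSketch`, `GenerateNodeArgNameSketch`, `GraphAfterFirstRound`, and the `_1` suffix scheme are all hypothetical names and assumptions made for illustration only.

```cpp
#include <cassert>
#include <string>
#include <unordered_set>

// Toy stand-in for a graph: only the two name sets matter here.
// NOT the onnxruntime Graph class; every name below is hypothetical.
struct ToyGraph {
  std::unordered_set<std::string> node_names;      // names of operator nodes
  std::unordered_set<std::string> node_arg_names;  // names of inputs, outputs, initializers

  // Return base_name if it is unused in `used`; otherwise append a counter
  // until the result is unique (the suffix scheme is an assumption).
  static std::string MakeUnique(const std::unordered_set<std::string>& used,
                                const std::string& base_name) {
    if (used.count(base_name) == 0) return base_name;
    for (int i = 1;; ++i) {
      std::string candidate = base_name + "_" + std::to_string(i);
      if (used.count(candidate) == 0) return candidate;
    }
  }

  // Looks only at node names, so an initializer with the same name goes unnoticed.
  std::string GenerateNodeNameSketch(const std::string& base_name) const {
    return MakeUnique(node_names, base_name);
  }

  // Looks at node arg names, the set that initializer names belong to.
  std::string GenerateNodeArgNameSketch(const std::string& base_name) const {
    return MakeUnique(node_arg_names, base_name);
  }
};

// Models the graph at the start of the second optimization round:
// a "splits" initializer already exists from the first round.
inline ToyGraph GraphAfterFirstRound() {
  ToyGraph g;
  g.node_names = {"Gather_0", "Slice_0"};
  g.node_arg_names = {"input", "splits"};
  assert(g.node_names.count("splits") == 0);  // no node is called "splits"
  return g;
}
```

With `ToyGraph g = GraphAfterFirstRound();`, calling `g.GenerateNodeNameSketch("splits")` returns `"splits"` again, colliding with the existing initializer, while `g.GenerateNodeArgNameSketch("splits")` returns `"splits_1"`.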

### Motivation and Context
* On-device training goes through a generate-artifacts step, where
optimizations are applied; then, when the training artifact is loaded,
additional optimizations are applied. In the first round of
optimizations, a "splits" initializer is added for phi-3. In the
second round, another "splits" initializer with different dimensions
and data is added. Because we call GenerateNodeName, which only looks
at node names, the first "splits" initializer isn't found and the name
is reused, causing a type error that claims the shape of splits does
not match the TensorProto shape.
Caroline Zhu 2024-06-03 14:41:39 -07:00 committed by GitHub
parent 456ab09d17
commit 94ce1209f9

@@ -268,7 +268,7 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra
 }
 ONNX_NAMESPACE::TensorProto split_initializer_proto;
-split_initializer_proto.set_name(graph.GenerateNodeName("splits"));
+split_initializer_proto.set_name(graph.GenerateNodeArgName("splits"));
 split_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
 split_initializer_proto.add_dims(static_cast<int64_t>(split_values.size()));
 split_initializer_proto.mutable_int64_data()->Add(split_values.begin(), split_values.end());