mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-22 22:01:08 +00:00
Fix issue in constant-propagation inside function subgraph (#16330)
### Description The SequenceMap function-op has a graph-attribute. ORT's constant-folding optimization may identify constant-expressions inside the subgraph and promote them to constants, stored as initializers in the main graph. When it does this, the optimization updates the subgraph to remove the corresponding nodes. When we expand a SequenceMap node by inlining its function-expansion, we need to use this updated subgraph. However, the existing code uses the original graph-attribute (GraphProto), instead of regenerating it from the modified subgraph. This results in producing a graph with duplicate definitions for the constant-folded variable, resulting in an error during graph-resolve. This PR fixes this issue (just a single line fix), and adds a test-case to cover this scenario. --------- Signed-off-by: Ganesan Ramalingam <grama@microsoft.com> Co-authored-by: Suryaprakash Shanmugam <suryaprakash.shanmugam@intel.com>
This commit is contained in:
parent
ea43671eb6
commit
4faee2e44c
4 changed files with 33 additions and 4 deletions
|
|
@ -585,7 +585,7 @@ bool Node::TryGetFunctionProto(ONNX_NAMESPACE::FunctionProto& onnx_function_prot
|
|||
// Check if this node has a schema defined function proto.
|
||||
if (op_->HasContextDependentFunction()) {
|
||||
NodeProto node_proto;
|
||||
ToProto(node_proto);
|
||||
ToProto(node_proto, true);
|
||||
std::vector<TypeProto> input_types;
|
||||
for (size_t i = 0, n = InputDefs().size(); i < n; i++) {
|
||||
auto p_node_arg = InputDefs().at(i);
|
||||
|
|
|
|||
|
|
@ -75,11 +75,15 @@ std::vector<std::unique_ptr<ComputeCapability>> GetCapability::Execute() {
|
|||
}
|
||||
|
||||
const auto& nodes = graph_viewer_.GetNodesInTopologicalOrder();
|
||||
|
||||
// Handle cases where lone, reoccuring Ops in smaller models cannot be supported in OpenVINO
|
||||
// If only a node of the same lone,unsupported type is present, then do not proceed with the subgraph
|
||||
const auto& node = graph_viewer_.GetNode(nodes[0]);
|
||||
if (data_ops_->IsOpSupportedOnlyInModel(node->OpType()))
|
||||
return result;
|
||||
|
||||
// Nodes that work well in models but not as a single node
|
||||
if (nodes.size() == 1) {
|
||||
const auto& node = graph_viewer_.GetNode(nodes[0]);
|
||||
if (data_ops_->IsOpSupportedOnlyInModel(node->OpType()))
|
||||
return result;
|
||||
// If reshape is not an intermediate node, shape needs to be an initializer
|
||||
if (data_ops_->SpecialConditionForClusterSizeOne(ng_required_initializers, node)) {
|
||||
return result;
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ namespace openvino_ep {
|
|||
|
||||
// Ops which are supported only in models(as intermediate nodes) and not in unit tests
|
||||
std::set<std::string> ops_supported_only_in_model = {
|
||||
"Add",
|
||||
"Cast",
|
||||
"Concat",
|
||||
"ConstantOfShape",
|
||||
|
|
|
|||
|
|
@ -506,5 +506,29 @@ TEST(FunctionTest, UnusedFunctionInputs) {
|
|||
Check(code, "x", {1.0, 2.0, 3.0}, "y", {1.0, 4.0, 9.0});
|
||||
}
|
||||
|
||||
// Test constant-folding inside a sub-graph is handled correctly
|
||||
// for functions that are inlined.
|
||||
TEST(FunctionTest, ConstantFoldingInSubGraph) {
|
||||
const char* code = R"(
|
||||
<ir_version: 8, opset_import: [ "" : 17 ]>
|
||||
agraph (float[N] X) => (float[M] Y) {
|
||||
seq1 = SequenceConstruct(X, X, X)
|
||||
seq2 = SequenceMap (seq1) <body =
|
||||
add1 (float[K] Z) => (float[K] W) {
|
||||
C1 = Constant <value = float {1.0}> ()
|
||||
C2 = Constant <value = float {1.0}> ()
|
||||
# C is a constant, which will be constant-folded into an initializer out of the sub-graph.
|
||||
C = Add (C1, C2)
|
||||
# After optimization, only following Add will be left in this sub-graph.
|
||||
W = Add (Z, C)
|
||||
}
|
||||
>
|
||||
Y = ConcatFromSequence <axis=0> (seq2)
|
||||
}
|
||||
)";
|
||||
|
||||
Check(code, "X", {1.0, 2.0, 3.0}, "Y", {3.0, 4.0, 5.0, 3.0, 4.0, 5.0, 3.0, 4.0, 5.0});
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue