diff --git a/onnxruntime/core/graph/function.cc b/onnxruntime/core/graph/function.cc index f4f5c363b4..785bc6d938 100644 --- a/onnxruntime/core/graph/function.cc +++ b/onnxruntime/core/graph/function.cc @@ -10,18 +10,18 @@ namespace onnxruntime { // Auto inferred and generate an opschema for stand-alone functions // TODO: revisit to see if we can eliminate typeconstraint step -void IOTypeConstraintHelper(const ONNX_NAMESPACE::FunctionProto* onnx_func_proto_, +void IOTypeConstraintHelper(const ONNX_NAMESPACE::FunctionProto& onnx_func_proto_, std::unique_ptr& op_schema_, const std::unordered_map& input_name_idx_map, const std::unordered_map& output_name_idx_map) { - std::vector> input_types_list(onnx_func_proto_->input_size()); - std::vector> output_types_list(onnx_func_proto_->output_size()); + std::vector> input_types_list(onnx_func_proto_.input_size()); + std::vector> output_types_list(onnx_func_proto_.output_size()); std::unordered_map> type_constraint_map; std::unordered_map attribute_type_map; auto schema_registry = ONNX_NAMESPACE::OpSchemaRegistry::Instance(); - for (auto& node : onnx_func_proto_->node()) { + for (auto& node : onnx_func_proto_.node()) { const auto node_op_schema = - schema_registry->GetSchema(node.op_type(), static_cast(onnx_func_proto_->since_version()), node.domain()); + schema_registry->GetSchema(node.op_type(), static_cast(onnx_func_proto_.since_version()), node.domain()); for (int i = 0; i < node.input_size(); ++i) { auto& in_name = node.input().Get(i); auto iter = input_name_idx_map.find(in_name); @@ -77,22 +77,83 @@ void IOTypeConstraintHelper(const ONNX_NAMESPACE::FunctionProto* onnx_func_proto op_schema_->TypeConstraint(tc.first, tc.second, ""); } - for (auto& attribute_name : onnx_func_proto_->attribute()) { + for (auto& attribute_name : onnx_func_proto_.attribute()) { if (attribute_type_map.count(attribute_name)) op_schema_->Attr(attribute_name, "", attribute_type_map[attribute_name], false); } } +// This method updates the names of inputs/outputs of nodes in subgraphs +// within nodes in an op that has a FunctionBody. +// Subgraphs within an op with a FunctionBody could be referencing inputs/outputs in the OpSchema +// and we need to replace these names with the corresponding input/output names from the actual model graph + +// The arguments to this method are : +// (1) The 'subgraph' from a node containing it (ONNX::GraphProto) +// (2) The parent 'graph' - main model graph (OnnxRuntime::Graph) +// (3) The node with a function body (ONNX::NodeProto) +// (4) A map containing the input name from the op schema to the corresponding index +// E.g. For Range-11, {"start" : 0, "limit": 1, "delta": 2} +// (5) A map containing the output name from the op schema to the corresponding index +// E.g. For Range-11, {"output" : 0} +static void update_subgraphs_within_function_body(ONNX_NAMESPACE::GraphProto& subgraph_proto, + const Graph& parent_graph, + const ONNX_NAMESPACE::NodeProto& function_node_in_parent_graph, + const std::unordered_map& input_name_idx_map, + const std::unordered_map& output_name_idx_map) { + // Iterate through all the nodes in the subgraph + for (auto subgraph_node = subgraph_proto.mutable_node()->begin(); + subgraph_node != subgraph_proto.mutable_node()->end(); ++subgraph_node) { + // Iterate through all the inputs of the current node + for (int idx = 0; idx < (*subgraph_node).input_size(); ++idx) { + const std::string& tensor_name = (*subgraph_node).input().Get(idx); + auto iter = input_name_idx_map.find(tensor_name); + // If an input pertaining to the name in the op schema is found, + // replace it with the corresponding input to the node with function body from the actual model graph + if (iter != input_name_idx_map.end()) { + const auto parent_graph_input_to_function_node = function_node_in_parent_graph.input().Get(iter->second); + (*subgraph_node).set_input(idx, parent_graph_input_to_function_node); + } + } + // Iterate through all the output of the current node + for (int idx = 0; idx < (*subgraph_node).output_size(); ++idx) { + const std::string& tensor_name = (*subgraph_node).output().Get(idx); + auto iter = output_name_idx_map.find(tensor_name); + if (iter != output_name_idx_map.end()) { + // If an input pertaining to the name in the op schema is found, + // replace it with the corresponding output to the node with function body from the actual model graph + const auto& parent_graph_output_to_function_node = function_node_in_parent_graph.output().Get(iter->second); + (*subgraph_node).set_output(idx, parent_graph_output_to_function_node); + } + } + + for (auto subgraph_node_attr = (*subgraph_node).mutable_attribute()->begin(); + subgraph_node_attr != (*subgraph_node).mutable_attribute()->end(); ++subgraph_node_attr) { + if ((*subgraph_node_attr).has_f()) { + ORT_THROW( + "A node with a function body within a subgraph within another function body " + "is currently not supported in ORT"); + } + // Recurse into any subgraphs in the current subgraph being processed + if ((*subgraph_node_attr).has_g()) { + update_subgraphs_within_function_body(*(*subgraph_node_attr).mutable_g(), + parent_graph, function_node_in_parent_graph, + input_name_idx_map, output_name_idx_map); + } + } + } +} + FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph, std::unique_ptr customized_func) - : parent_graph_(&graph), onnx_func_proto_{nullptr} { + : parent_graph_(&graph) { customized_func_body_ = std::move(customized_func); // Construct body. body_ = onnxruntime::make_unique("fused_function_subgraph", false, onnxruntime::ModelMetaData(), - IOnnxRuntimeOpSchemaRegistryList({graph.GetSchemaRegistry()}), - graph.DomainToVersionMap()); - auto& sub_graph = body_->MainGraph(); + IOnnxRuntimeOpSchemaRegistryList({graph.GetSchemaRegistry()}), + graph.DomainToVersionMap()); + auto& function_body_graph = body_->MainGraph(); auto meta_def = customized_func_body_->GetMetaDef(); op_schema_ = onnxruntime::make_unique(); @@ -101,30 +162,30 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph, op_schema_->SetDoc(meta_def->doc_string); op_schema_->SinceVersion(meta_def->since_version); int i = 0; - std::vector sub_graph_inputs; - sub_graph_inputs.resize(meta_def->inputs.size()); + std::vector function_body_graph_inputs; + function_body_graph_inputs.resize(meta_def->inputs.size()); for (auto& input : meta_def->inputs) { auto input_arg = parent_graph_->GetNodeArg(input); - auto& sub_graph_input_arg = sub_graph.GetOrCreateNodeArg(input_arg->Name(), input_arg->TypeAsProto()); - sub_graph_inputs[i] = &sub_graph_input_arg; + auto& function_body_graph_input_arg = function_body_graph.GetOrCreateNodeArg(input_arg->Name(), input_arg->TypeAsProto()); + function_body_graph_inputs[i] = &function_body_graph_input_arg; ORT_ENFORCE(input_arg->Type() != nullptr); op_schema_->Input(i, input, "", *input_arg->Type()); ++i; } i = 0; - std::vector sub_graph_outputs; - sub_graph_outputs.resize(meta_def->outputs.size()); + std::vector function_body_graph_outputs; + function_body_graph_outputs.resize(meta_def->outputs.size()); for (auto& output : meta_def->outputs) { auto output_arg = parent_graph_->GetNodeArg(output); - auto& sub_graph_output_arg = sub_graph.GetOrCreateNodeArg(output_arg->Name(), output_arg->TypeAsProto()); - sub_graph_outputs[i] = &sub_graph_output_arg; + auto& function_body_graph_output_arg = function_body_graph.GetOrCreateNodeArg(output_arg->Name(), output_arg->TypeAsProto()); + function_body_graph_outputs[i] = &function_body_graph_output_arg; op_schema_->Output(i, output, "", *output_arg->Type()); ++i; } op_schema_->Finalize(); - sub_graph.SetInputs(sub_graph_inputs); - sub_graph.SetOutputs(sub_graph_outputs); + function_body_graph.SetInputs(function_body_graph_inputs); + function_body_graph.SetOutputs(function_body_graph_outputs); //Add node and node args //TODO: for better performance, we could try to transfer the nodes in parent graph to sub-graph directly, //instead of create new nodes. @@ -133,46 +194,51 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph, std::vector inputs; std::vector outputs; for (auto input : node->InputDefs()) { - auto& n_input = sub_graph.GetOrCreateNodeArg(input->Name(), input->TypeAsProto()); + auto& n_input = function_body_graph.GetOrCreateNodeArg(input->Name(), input->TypeAsProto()); inputs.push_back(&n_input); } for (auto output : node->OutputDefs()) { - auto& n_output = sub_graph.GetOrCreateNodeArg(output->Name(), output->TypeAsProto()); + auto& n_output = function_body_graph.GetOrCreateNodeArg(output->Name(), output->TypeAsProto()); outputs.push_back(&n_output); } - sub_graph.AddNode(node->Name(), node->OpType(), node->Description(), inputs, outputs, &node->GetAttributes(), node->Domain()); + function_body_graph.AddNode(node->Name(), node->OpType(), node->Description(), inputs, outputs, &node->GetAttributes(), node->Domain()); } for (const auto& input : meta_def->inputs) { const ONNX_NAMESPACE::TensorProto* initializer = nullptr; if (graph.GetInitializedTensor(input, initializer)) { - sub_graph.AddInitializedTensor(*initializer); + function_body_graph.AddInitializedTensor(*initializer); } } //TODO: if we reuse the nodes in parent graph, maybe we don't need to resolve it. - auto status = sub_graph.Resolve(); + auto status = function_body_graph.Resolve(); ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); } FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph, const onnxruntime::NodeIndex& node_index, - const ONNX_NAMESPACE::FunctionProto* onnx_func_proto) + const ONNX_NAMESPACE::FunctionProto& onnx_func_proto) : parent_graph_(&graph) { + // Make a copy of the FunctionProto. + // All FunctionBody ops with the same op type seem to share the same FunctionProto struct within a model. + // Hence, we make a copy prior to generating the graph representation of the function, + // as we might make some modifications to the FunctionProto along the way onnx_func_proto_ = onnx_func_proto; + auto node_in_parent_graph = parent_graph_->GetNode(node_index); op_schema_ = onnxruntime::make_unique(); - op_schema_->SetName(onnx_func_proto_->name()); - op_schema_->SetDomain(onnx_func_proto_->node().Get(0).domain()); - op_schema_->SetDoc(onnx_func_proto_->doc_string()); - op_schema_->SinceVersion(static_cast(onnx_func_proto_->since_version())); + op_schema_->SetName(onnx_func_proto_.name()); + op_schema_->SetDomain(onnx_func_proto_.node().Get(0).domain()); + op_schema_->SetDoc(onnx_func_proto_.doc_string()); + op_schema_->SinceVersion(static_cast(onnx_func_proto_.since_version())); std::unordered_map input_name_idx_map; std::unordered_map output_name_idx_map; - for (int i = 0; i < onnx_func_proto_->input_size(); ++i) { - input_name_idx_map[onnx_func_proto_->input().Get(i)] = i; + for (int i = 0; i < onnx_func_proto_.input_size(); ++i) { + input_name_idx_map[onnx_func_proto_.input().Get(i)] = i; } - for (int i = 0; i < onnx_func_proto_->output_size(); ++i) { - output_name_idx_map[onnx_func_proto_->output().Get(i)] = i; + for (int i = 0; i < onnx_func_proto_.output_size(); ++i) { + output_name_idx_map[onnx_func_proto_.output().Get(i)] = i; } auto cached_op_schema = node_in_parent_graph->Op(); @@ -219,75 +285,87 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph, //construct body std::unordered_map domain_to_version; //TODO: set correct domain and version - domain_to_version[onnxruntime::kOnnxDomain] = static_cast(onnx_func_proto_->since_version()); - body_ = onnxruntime::make_unique(onnx_func_proto_->name(), false, onnxruntime::ModelMetaData(), - IOnnxRuntimeOpSchemaRegistryList(), domain_to_version); - auto& sub_graph = body_->MainGraph(); + domain_to_version[onnxruntime::kOnnxDomain] = static_cast(onnx_func_proto_.since_version()); + body_ = onnxruntime::make_unique(onnx_func_proto_.name(), false, onnxruntime::ModelMetaData(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version); + auto& function_body_graph = body_->MainGraph(); // Add node and node args into subgraph // The subgraph preserved the input/output tensor names // in the parent graph for later inlining purpose const auto& attr_map = node_in_parent_graph->GetAttributes(); - for (auto& node : onnx_func_proto_->node()) { + + ONNX_NAMESPACE::NodeProto function_op_node_proto; // NodeProto pertaining to the op with a FunctionBody + node_in_parent_graph->ToProto(function_op_node_proto); + + // iterate over each node in the FunctionProto and fix inputs/outputs + for (auto node = onnx_func_proto_.mutable_node()->begin(); node != onnx_func_proto_.mutable_node()->end(); ++node) { std::vector inputs; std::vector outputs; - std::string uniq_identifier = node.name(); - if (!utils::HasName(node)) { + std::string uniq_identifier = (*node).name(); + if (!utils::HasName(*node)) { std::stringstream ss; - ss << static_cast(&node); + ss << static_cast(&(*node)); uniq_identifier = ss.str(); } - for (int idx = 0; idx < node.input_size(); ++idx) { - std::string tensor_name = node.input().Get(idx); + for (int idx = 0; idx < (*node).input_size(); ++idx) { + std::string tensor_name = (*node).input().Get(idx); auto iter = input_name_idx_map.find(tensor_name); if (iter != input_name_idx_map.end()) { // Preserving NodeArg and input/output names - ONNX_NAMESPACE::NodeProto temp_node_proto; - node_in_parent_graph->ToProto(temp_node_proto); - const onnxruntime::NodeArg* node_arg = parent_graph_->GetNodeArg(temp_node_proto.input().Get(input_name_idx_map[tensor_name])); - auto& n_input = sub_graph.GetOrCreateNodeArg( - temp_node_proto.input().Get(iter->second), node_arg->TypeAsProto()); + const onnxruntime::NodeArg* node_arg = parent_graph_->GetNodeArg(function_op_node_proto.input() + .Get(iter->second)); + auto& n_input = function_body_graph.GetOrCreateNodeArg( + function_op_node_proto.input().Get(iter->second), node_arg->TypeAsProto()); inputs.push_back(&n_input); } else { - auto& n_input = sub_graph.GetOrCreateNodeArg( + auto& n_input = function_body_graph.GetOrCreateNodeArg( tensor_name + "_" + std::to_string(node_index), nullptr); inputs.push_back(&n_input); } } - for (int idx = 0; idx < node.output_size(); ++idx) { - std::string tensor_name = node.output().Get(idx); + for (int idx = 0; idx < (*node).output_size(); ++idx) { + std::string tensor_name = (*node).output().Get(idx); auto iter = output_name_idx_map.find(tensor_name); if (iter != output_name_idx_map.end()) { // Preserving NodeArg and input/output names - ONNX_NAMESPACE::NodeProto temp_node_proto; - node_in_parent_graph->ToProto(temp_node_proto); - const onnxruntime::NodeArg* node_arg = parent_graph_->GetNodeArg(temp_node_proto.output().Get(output_name_idx_map[tensor_name])); - auto& n_output = sub_graph.GetOrCreateNodeArg( - temp_node_proto.output().Get(iter->second), node_arg->TypeAsProto()); + const onnxruntime::NodeArg* node_arg = parent_graph_->GetNodeArg(function_op_node_proto.output() + .Get(iter->second)); + auto& n_output = function_body_graph.GetOrCreateNodeArg( + function_op_node_proto.output().Get(iter->second), node_arg->TypeAsProto()); outputs.push_back(&n_output); } else { - auto& n_output = sub_graph.GetOrCreateNodeArg( + auto& n_output = function_body_graph.GetOrCreateNodeArg( tensor_name + "_" + std::to_string(node_index), nullptr); outputs.push_back(&n_output); } } onnxruntime::NodeAttributes new_attr_map; - for (auto& attr : node.attribute()) { - if (!attr.ref_attr_name().empty()) { - auto entry = attr_map.find(attr.ref_attr_name()); + for (auto node_attr = (*node).mutable_attribute()->begin(); + node_attr != (*node).mutable_attribute()->end(); ++node_attr) { + // If this node contains subgraphs, the node inputs/outputs within them needs to be fixed as well + if ((*node_attr).has_g()) { + update_subgraphs_within_function_body(*(*node_attr).mutable_g(), + *parent_graph_, function_op_node_proto, + input_name_idx_map, output_name_idx_map); + } + + if (!(*node_attr).ref_attr_name().empty()) { + auto entry = attr_map.find((*node_attr).ref_attr_name()); if (entry != attr_map.cend()) { - new_attr_map[attr.name()] = entry->second; + new_attr_map[(*node_attr).name()] = entry->second; } } else { - new_attr_map[attr.name()] = attr; + new_attr_map[(*node_attr).name()] = *node_attr; } } - sub_graph.AddNode(uniq_identifier + "_" + std::to_string(node_index), node.op_type(), node.doc_string(), inputs, outputs, &new_attr_map, node.domain()); + function_body_graph.AddNode(uniq_identifier + "_" + std::to_string(node_index), (*node).op_type(), (*node).doc_string(), inputs, outputs, &new_attr_map, (*node).domain()); } - auto status = sub_graph.Resolve(); + + auto status = function_body_graph.Resolve(); ORT_ENFORCE(status.IsOK(), "Resolve subgraph failed:", status.ErrorMessage()); -} +} // namespace onnxruntime FunctionImpl::~FunctionImpl() = default; @@ -304,7 +382,7 @@ const IndexedSubGraph& FunctionImpl::GetIndexedSubGraph() const { } const ONNX_NAMESPACE::FunctionProto* FunctionImpl::GetFuncProto() const { - return onnx_func_proto_; + return &onnx_func_proto_; } std::unique_ptr MakeFunction(const onnxruntime::Graph& graph, diff --git a/onnxruntime/core/graph/function_impl.h b/onnxruntime/core/graph/function_impl.h index 3ac1541bdc..900371170d 100644 --- a/onnxruntime/core/graph/function_impl.h +++ b/onnxruntime/core/graph/function_impl.h @@ -20,7 +20,7 @@ class FunctionImpl final : public Function { FunctionImpl(const onnxruntime::Graph& graph, const onnxruntime::NodeIndex& node_index, - const ONNX_NAMESPACE::FunctionProto* onnx_func); + const ONNX_NAMESPACE::FunctionProto& onnx_func); ~FunctionImpl() override; @@ -37,7 +37,7 @@ class FunctionImpl final : public Function { std::unique_ptr customized_func_body_; std::unique_ptr op_schema_; std::unique_ptr body_; - const ONNX_NAMESPACE::FunctionProto* onnx_func_proto_; + ONNX_NAMESPACE::FunctionProto onnx_func_proto_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 9f8c4550d5..e4005dc11d 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -669,8 +669,8 @@ Graph::Graph(GraphProto* graph_proto, const std::unordered_map const gsl::not_null tensor{graph_proto_->add_initializer()}; const AttributeProto& constant_attribute = node.attribute(0); // TODO: Add support for parsing 'sparse_value' attribute from a 'Constant' node - // Discussion surrounding handling the SparseTensorProto must be had. - // An easy way is to implement a method that converts a SparseTensorproto into a TensorProto + // Discussion surrounding handling the SparseTensorProto must be had. + // An easy way is to implement a method that converts a SparseTensorproto into a TensorProto // to use the same downstream flow, but that is going to impact peak memory usage and probably a smarter way is required. ORT_ENFORCE(constant_attribute.has_t(), "Only 'value' attribute is supported within a 'Constant' node in ORT"); *tensor = constant_attribute.t(); @@ -1722,7 +1722,7 @@ Status Graph::VerifyNodeAndOpMatch() { auto iter = model_functions_.find(node.OpType()); if (iter != model_functions_.end()) { const ONNX_NAMESPACE::FunctionProto* model_function_proto = iter->second; - auto model_func_ptr = onnxruntime::make_unique(*this, node.Index(), model_function_proto); + auto model_func_ptr = onnxruntime::make_unique(*this, node.Index(), *model_function_proto); function_container_.emplace_back(std::move(model_func_ptr)); node.SetFunctionBody(*function_container_.back()); } @@ -1743,7 +1743,7 @@ Status Graph::VerifyNodeAndOpMatch() { if (node.op_ && node.op_->HasFunction()) { auto onnx_function_proto = node.op_->GetFunction(); - auto func_ptr = onnxruntime::make_unique(*this, node.Index(), onnx_function_proto); + auto func_ptr = onnxruntime::make_unique(*this, node.Index(), *onnx_function_proto); function_container_.emplace_back(std::move(func_ptr)); node.SetFunctionBody(*function_container_.back()); } diff --git a/onnxruntime/test/ir/onnx_model_test.cc b/onnxruntime/test/ir/onnx_model_test.cc index 292b42a043..5c4dce0452 100644 --- a/onnxruntime/test/ir/onnx_model_test.cc +++ b/onnxruntime/test/ir/onnx_model_test.cc @@ -162,5 +162,20 @@ TEST(ONNXModelsTest, TestIRv4NonInputInitializers) { ASSERT_TRUE(Model::Load("testdata/subgraph_implicit_input_from_initializer.onnx", model).IsOK()); EXPECT_TRUE(model->MainGraph().Resolve().IsOK()); } + +// test a model that has an op with a FunctionBody and one of the nodes within the FunctionBody has a subgraph in it. +// The test model has is an opset-11 op with a 'Range' node. +// 'Range' has a FunctionBody and has a 'Loop' node with a subgraph. +// Graph::Resolve to succeed when processing the subgraph pertaining to the overall FunctionBody. +TEST(ONNXModelsTest, TestModelsWithAnOpContainingAFunctionBody) { + std::shared_ptr model; + + auto status = Model::Load("testdata/model_containing_op_with_function_body.onnx", model); + EXPECT_TRUE(status.IsOK()) << status; + + status = model->MainGraph().Resolve(); + EXPECT_TRUE(status.IsOK()) << status; +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/testdata/model_containing_op_with_function_body.onnx b/onnxruntime/test/testdata/model_containing_op_with_function_body.onnx new file mode 100644 index 0000000000..b1a73839b9 Binary files /dev/null and b/onnxruntime/test/testdata/model_containing_op_with_function_body.onnx differ