mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Handle nodes with subgraphs in ORT function handling implementation (#2053)
* Initial commit * Update * Update * Nits * More updates * to be reverted * Update * Update * More changes * Updates * Update Function * Nits * Fix build break * Comment
This commit is contained in:
parent
2d4d0abd36
commit
2ba705ed99
5 changed files with 166 additions and 73 deletions
|
|
@ -10,18 +10,18 @@
|
|||
namespace onnxruntime {
|
||||
// Auto inferred and generate an opschema for stand-alone functions
|
||||
// TODO: revisit to see if we can eliminate typeconstraint step
|
||||
void IOTypeConstraintHelper(const ONNX_NAMESPACE::FunctionProto* onnx_func_proto_,
|
||||
void IOTypeConstraintHelper(const ONNX_NAMESPACE::FunctionProto& onnx_func_proto_,
|
||||
std::unique_ptr<ONNX_NAMESPACE::OpSchema>& op_schema_,
|
||||
const std::unordered_map<std::string, int>& input_name_idx_map,
|
||||
const std::unordered_map<std::string, int>& output_name_idx_map) {
|
||||
std::vector<std::pair<std::string, std::string>> input_types_list(onnx_func_proto_->input_size());
|
||||
std::vector<std::pair<std::string, std::string>> output_types_list(onnx_func_proto_->output_size());
|
||||
std::vector<std::pair<std::string, std::string>> input_types_list(onnx_func_proto_.input_size());
|
||||
std::vector<std::pair<std::string, std::string>> output_types_list(onnx_func_proto_.output_size());
|
||||
std::unordered_map<std::string, std::vector<std::string>> type_constraint_map;
|
||||
std::unordered_map<std::string, ONNX_NAMESPACE::AttributeProto_AttributeType> attribute_type_map;
|
||||
auto schema_registry = ONNX_NAMESPACE::OpSchemaRegistry::Instance();
|
||||
for (auto& node : onnx_func_proto_->node()) {
|
||||
for (auto& node : onnx_func_proto_.node()) {
|
||||
const auto node_op_schema =
|
||||
schema_registry->GetSchema(node.op_type(), static_cast<int>(onnx_func_proto_->since_version()), node.domain());
|
||||
schema_registry->GetSchema(node.op_type(), static_cast<int>(onnx_func_proto_.since_version()), node.domain());
|
||||
for (int i = 0; i < node.input_size(); ++i) {
|
||||
auto& in_name = node.input().Get(i);
|
||||
auto iter = input_name_idx_map.find(in_name);
|
||||
|
|
@ -77,22 +77,83 @@ void IOTypeConstraintHelper(const ONNX_NAMESPACE::FunctionProto* onnx_func_proto
|
|||
op_schema_->TypeConstraint(tc.first, tc.second, "");
|
||||
}
|
||||
|
||||
for (auto& attribute_name : onnx_func_proto_->attribute()) {
|
||||
for (auto& attribute_name : onnx_func_proto_.attribute()) {
|
||||
if (attribute_type_map.count(attribute_name))
|
||||
op_schema_->Attr(attribute_name, "", attribute_type_map[attribute_name], false);
|
||||
}
|
||||
}
|
||||
|
||||
// This method updates the names of inputs/outputs of nodes in subgraphs
|
||||
// within nodes in an op that has a FunctionBody.
|
||||
// Subgraphs within an op with a FunctionBody could be referencing inputs/outputs in the OpSchema
|
||||
// and we need to replace these names with the corresponding input/output names from the actual model graph
|
||||
|
||||
// The arguments to this method are :
|
||||
// (1) The 'subgraph' from a node containing it (ONNX::GraphProto)
|
||||
// (2) The parent 'graph' - main model graph (OnnxRuntime::Graph)
|
||||
// (3) The node with a function body (ONNX::NodeProto)
|
||||
// (4) A map containing the input name from the op schema to the corresponding index
|
||||
// E.g. For Range-11, {"start" : 0, "limit": 1, "delta": 2}
|
||||
// (5) A map containing the output name from the op schema to the corresponding index
|
||||
// E.g. For Range-11, {"output" : 0}
|
||||
static void update_subgraphs_within_function_body(ONNX_NAMESPACE::GraphProto& subgraph_proto,
|
||||
const Graph& parent_graph,
|
||||
const ONNX_NAMESPACE::NodeProto& function_node_in_parent_graph,
|
||||
const std::unordered_map<std::string, int>& input_name_idx_map,
|
||||
const std::unordered_map<std::string, int>& output_name_idx_map) {
|
||||
// Iterate through all the nodes in the subgraph
|
||||
for (auto subgraph_node = subgraph_proto.mutable_node()->begin();
|
||||
subgraph_node != subgraph_proto.mutable_node()->end(); ++subgraph_node) {
|
||||
// Iterate through all the inputs of the current node
|
||||
for (int idx = 0; idx < (*subgraph_node).input_size(); ++idx) {
|
||||
const std::string& tensor_name = (*subgraph_node).input().Get(idx);
|
||||
auto iter = input_name_idx_map.find(tensor_name);
|
||||
// If an input pertaining to the name in the op schema is found,
|
||||
// replace it with the corresponding input to the node with function body from the actual model graph
|
||||
if (iter != input_name_idx_map.end()) {
|
||||
const auto parent_graph_input_to_function_node = function_node_in_parent_graph.input().Get(iter->second);
|
||||
(*subgraph_node).set_input(idx, parent_graph_input_to_function_node);
|
||||
}
|
||||
}
|
||||
// Iterate through all the output of the current node
|
||||
for (int idx = 0; idx < (*subgraph_node).output_size(); ++idx) {
|
||||
const std::string& tensor_name = (*subgraph_node).output().Get(idx);
|
||||
auto iter = output_name_idx_map.find(tensor_name);
|
||||
if (iter != output_name_idx_map.end()) {
|
||||
// If an input pertaining to the name in the op schema is found,
|
||||
// replace it with the corresponding output to the node with function body from the actual model graph
|
||||
const auto& parent_graph_output_to_function_node = function_node_in_parent_graph.output().Get(iter->second);
|
||||
(*subgraph_node).set_output(idx, parent_graph_output_to_function_node);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto subgraph_node_attr = (*subgraph_node).mutable_attribute()->begin();
|
||||
subgraph_node_attr != (*subgraph_node).mutable_attribute()->end(); ++subgraph_node_attr) {
|
||||
if ((*subgraph_node_attr).has_f()) {
|
||||
ORT_THROW(
|
||||
"A node with a function body within a subgraph within another function body "
|
||||
"is currently not supported in ORT");
|
||||
}
|
||||
// Recurse into any subgraphs in the current subgraph being processed
|
||||
if ((*subgraph_node_attr).has_g()) {
|
||||
update_subgraphs_within_function_body(*(*subgraph_node_attr).mutable_g(),
|
||||
parent_graph, function_node_in_parent_graph,
|
||||
input_name_idx_map, output_name_idx_map);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph,
|
||||
std::unique_ptr<IndexedSubGraph> customized_func)
|
||||
: parent_graph_(&graph), onnx_func_proto_{nullptr} {
|
||||
: parent_graph_(&graph) {
|
||||
customized_func_body_ = std::move(customized_func);
|
||||
|
||||
// Construct body.
|
||||
body_ = onnxruntime::make_unique<onnxruntime::Model>("fused_function_subgraph", false, onnxruntime::ModelMetaData(),
|
||||
IOnnxRuntimeOpSchemaRegistryList({graph.GetSchemaRegistry()}),
|
||||
graph.DomainToVersionMap());
|
||||
auto& sub_graph = body_->MainGraph();
|
||||
IOnnxRuntimeOpSchemaRegistryList({graph.GetSchemaRegistry()}),
|
||||
graph.DomainToVersionMap());
|
||||
auto& function_body_graph = body_->MainGraph();
|
||||
|
||||
auto meta_def = customized_func_body_->GetMetaDef();
|
||||
op_schema_ = onnxruntime::make_unique<ONNX_NAMESPACE::OpSchema>();
|
||||
|
|
@ -101,30 +162,30 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph,
|
|||
op_schema_->SetDoc(meta_def->doc_string);
|
||||
op_schema_->SinceVersion(meta_def->since_version);
|
||||
int i = 0;
|
||||
std::vector<const NodeArg*> sub_graph_inputs;
|
||||
sub_graph_inputs.resize(meta_def->inputs.size());
|
||||
std::vector<const NodeArg*> function_body_graph_inputs;
|
||||
function_body_graph_inputs.resize(meta_def->inputs.size());
|
||||
for (auto& input : meta_def->inputs) {
|
||||
auto input_arg = parent_graph_->GetNodeArg(input);
|
||||
auto& sub_graph_input_arg = sub_graph.GetOrCreateNodeArg(input_arg->Name(), input_arg->TypeAsProto());
|
||||
sub_graph_inputs[i] = &sub_graph_input_arg;
|
||||
auto& function_body_graph_input_arg = function_body_graph.GetOrCreateNodeArg(input_arg->Name(), input_arg->TypeAsProto());
|
||||
function_body_graph_inputs[i] = &function_body_graph_input_arg;
|
||||
ORT_ENFORCE(input_arg->Type() != nullptr);
|
||||
op_schema_->Input(i, input, "", *input_arg->Type());
|
||||
++i;
|
||||
}
|
||||
i = 0;
|
||||
std::vector<const NodeArg*> sub_graph_outputs;
|
||||
sub_graph_outputs.resize(meta_def->outputs.size());
|
||||
std::vector<const NodeArg*> function_body_graph_outputs;
|
||||
function_body_graph_outputs.resize(meta_def->outputs.size());
|
||||
for (auto& output : meta_def->outputs) {
|
||||
auto output_arg = parent_graph_->GetNodeArg(output);
|
||||
auto& sub_graph_output_arg = sub_graph.GetOrCreateNodeArg(output_arg->Name(), output_arg->TypeAsProto());
|
||||
sub_graph_outputs[i] = &sub_graph_output_arg;
|
||||
auto& function_body_graph_output_arg = function_body_graph.GetOrCreateNodeArg(output_arg->Name(), output_arg->TypeAsProto());
|
||||
function_body_graph_outputs[i] = &function_body_graph_output_arg;
|
||||
op_schema_->Output(i, output, "", *output_arg->Type());
|
||||
++i;
|
||||
}
|
||||
op_schema_->Finalize();
|
||||
|
||||
sub_graph.SetInputs(sub_graph_inputs);
|
||||
sub_graph.SetOutputs(sub_graph_outputs);
|
||||
function_body_graph.SetInputs(function_body_graph_inputs);
|
||||
function_body_graph.SetOutputs(function_body_graph_outputs);
|
||||
//Add node and node args
|
||||
//TODO: for better performance, we could try to transfer the nodes in parent graph to sub-graph directly,
|
||||
//instead of create new nodes.
|
||||
|
|
@ -133,46 +194,51 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph,
|
|||
std::vector<onnxruntime::NodeArg*> inputs;
|
||||
std::vector<onnxruntime::NodeArg*> outputs;
|
||||
for (auto input : node->InputDefs()) {
|
||||
auto& n_input = sub_graph.GetOrCreateNodeArg(input->Name(), input->TypeAsProto());
|
||||
auto& n_input = function_body_graph.GetOrCreateNodeArg(input->Name(), input->TypeAsProto());
|
||||
inputs.push_back(&n_input);
|
||||
}
|
||||
for (auto output : node->OutputDefs()) {
|
||||
auto& n_output = sub_graph.GetOrCreateNodeArg(output->Name(), output->TypeAsProto());
|
||||
auto& n_output = function_body_graph.GetOrCreateNodeArg(output->Name(), output->TypeAsProto());
|
||||
outputs.push_back(&n_output);
|
||||
}
|
||||
sub_graph.AddNode(node->Name(), node->OpType(), node->Description(), inputs, outputs, &node->GetAttributes(), node->Domain());
|
||||
function_body_graph.AddNode(node->Name(), node->OpType(), node->Description(), inputs, outputs, &node->GetAttributes(), node->Domain());
|
||||
}
|
||||
|
||||
for (const auto& input : meta_def->inputs) {
|
||||
const ONNX_NAMESPACE::TensorProto* initializer = nullptr;
|
||||
if (graph.GetInitializedTensor(input, initializer)) {
|
||||
sub_graph.AddInitializedTensor(*initializer);
|
||||
function_body_graph.AddInitializedTensor(*initializer);
|
||||
}
|
||||
}
|
||||
|
||||
//TODO: if we reuse the nodes in parent graph, maybe we don't need to resolve it.
|
||||
auto status = sub_graph.Resolve();
|
||||
auto status = function_body_graph.Resolve();
|
||||
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
|
||||
}
|
||||
|
||||
FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph,
|
||||
const onnxruntime::NodeIndex& node_index,
|
||||
const ONNX_NAMESPACE::FunctionProto* onnx_func_proto)
|
||||
const ONNX_NAMESPACE::FunctionProto& onnx_func_proto)
|
||||
: parent_graph_(&graph) {
|
||||
// Make a copy of the FunctionProto.
|
||||
// All FunctionBody ops with the same op type seem to share the same FunctionProto struct within a model.
|
||||
// Hence, we make a copy prior to generating the graph representation of the function,
|
||||
// as we might make some modifications to the FunctionProto along the way
|
||||
onnx_func_proto_ = onnx_func_proto;
|
||||
|
||||
auto node_in_parent_graph = parent_graph_->GetNode(node_index);
|
||||
op_schema_ = onnxruntime::make_unique<ONNX_NAMESPACE::OpSchema>();
|
||||
op_schema_->SetName(onnx_func_proto_->name());
|
||||
op_schema_->SetDomain(onnx_func_proto_->node().Get(0).domain());
|
||||
op_schema_->SetDoc(onnx_func_proto_->doc_string());
|
||||
op_schema_->SinceVersion(static_cast<ONNX_NAMESPACE::OperatorSetVersion>(onnx_func_proto_->since_version()));
|
||||
op_schema_->SetName(onnx_func_proto_.name());
|
||||
op_schema_->SetDomain(onnx_func_proto_.node().Get(0).domain());
|
||||
op_schema_->SetDoc(onnx_func_proto_.doc_string());
|
||||
op_schema_->SinceVersion(static_cast<ONNX_NAMESPACE::OperatorSetVersion>(onnx_func_proto_.since_version()));
|
||||
std::unordered_map<std::string, int> input_name_idx_map;
|
||||
std::unordered_map<std::string, int> output_name_idx_map;
|
||||
for (int i = 0; i < onnx_func_proto_->input_size(); ++i) {
|
||||
input_name_idx_map[onnx_func_proto_->input().Get(i)] = i;
|
||||
for (int i = 0; i < onnx_func_proto_.input_size(); ++i) {
|
||||
input_name_idx_map[onnx_func_proto_.input().Get(i)] = i;
|
||||
}
|
||||
for (int i = 0; i < onnx_func_proto_->output_size(); ++i) {
|
||||
output_name_idx_map[onnx_func_proto_->output().Get(i)] = i;
|
||||
for (int i = 0; i < onnx_func_proto_.output_size(); ++i) {
|
||||
output_name_idx_map[onnx_func_proto_.output().Get(i)] = i;
|
||||
}
|
||||
|
||||
auto cached_op_schema = node_in_parent_graph->Op();
|
||||
|
|
@ -219,75 +285,87 @@ FunctionImpl::FunctionImpl(const onnxruntime::Graph& graph,
|
|||
//construct body
|
||||
std::unordered_map<std::string, int> domain_to_version;
|
||||
//TODO: set correct domain and version
|
||||
domain_to_version[onnxruntime::kOnnxDomain] = static_cast<int>(onnx_func_proto_->since_version());
|
||||
body_ = onnxruntime::make_unique<onnxruntime::Model>(onnx_func_proto_->name(), false, onnxruntime::ModelMetaData(),
|
||||
IOnnxRuntimeOpSchemaRegistryList(), domain_to_version);
|
||||
auto& sub_graph = body_->MainGraph();
|
||||
domain_to_version[onnxruntime::kOnnxDomain] = static_cast<int>(onnx_func_proto_.since_version());
|
||||
body_ = onnxruntime::make_unique<onnxruntime::Model>(onnx_func_proto_.name(), false, onnxruntime::ModelMetaData(),
|
||||
IOnnxRuntimeOpSchemaRegistryList(), domain_to_version);
|
||||
auto& function_body_graph = body_->MainGraph();
|
||||
// Add node and node args into subgraph
|
||||
// The subgraph preserved the input/output tensor names
|
||||
// in the parent graph for later inlining purpose
|
||||
const auto& attr_map = node_in_parent_graph->GetAttributes();
|
||||
for (auto& node : onnx_func_proto_->node()) {
|
||||
|
||||
ONNX_NAMESPACE::NodeProto function_op_node_proto; // NodeProto pertaining to the op with a FunctionBody
|
||||
node_in_parent_graph->ToProto(function_op_node_proto);
|
||||
|
||||
// iterate over each node in the FunctionProto and fix inputs/outputs
|
||||
for (auto node = onnx_func_proto_.mutable_node()->begin(); node != onnx_func_proto_.mutable_node()->end(); ++node) {
|
||||
std::vector<onnxruntime::NodeArg*> inputs;
|
||||
std::vector<onnxruntime::NodeArg*> outputs;
|
||||
std::string uniq_identifier = node.name();
|
||||
if (!utils::HasName(node)) {
|
||||
std::string uniq_identifier = (*node).name();
|
||||
if (!utils::HasName(*node)) {
|
||||
std::stringstream ss;
|
||||
ss << static_cast<const void*>(&node);
|
||||
ss << static_cast<const void*>(&(*node));
|
||||
uniq_identifier = ss.str();
|
||||
}
|
||||
|
||||
for (int idx = 0; idx < node.input_size(); ++idx) {
|
||||
std::string tensor_name = node.input().Get(idx);
|
||||
for (int idx = 0; idx < (*node).input_size(); ++idx) {
|
||||
std::string tensor_name = (*node).input().Get(idx);
|
||||
auto iter = input_name_idx_map.find(tensor_name);
|
||||
if (iter != input_name_idx_map.end()) {
|
||||
// Preserving NodeArg and input/output names
|
||||
ONNX_NAMESPACE::NodeProto temp_node_proto;
|
||||
node_in_parent_graph->ToProto(temp_node_proto);
|
||||
const onnxruntime::NodeArg* node_arg = parent_graph_->GetNodeArg(temp_node_proto.input().Get(input_name_idx_map[tensor_name]));
|
||||
auto& n_input = sub_graph.GetOrCreateNodeArg(
|
||||
temp_node_proto.input().Get(iter->second), node_arg->TypeAsProto());
|
||||
const onnxruntime::NodeArg* node_arg = parent_graph_->GetNodeArg(function_op_node_proto.input()
|
||||
.Get(iter->second));
|
||||
auto& n_input = function_body_graph.GetOrCreateNodeArg(
|
||||
function_op_node_proto.input().Get(iter->second), node_arg->TypeAsProto());
|
||||
inputs.push_back(&n_input);
|
||||
} else {
|
||||
auto& n_input = sub_graph.GetOrCreateNodeArg(
|
||||
auto& n_input = function_body_graph.GetOrCreateNodeArg(
|
||||
tensor_name + "_" + std::to_string(node_index), nullptr);
|
||||
inputs.push_back(&n_input);
|
||||
}
|
||||
}
|
||||
for (int idx = 0; idx < node.output_size(); ++idx) {
|
||||
std::string tensor_name = node.output().Get(idx);
|
||||
for (int idx = 0; idx < (*node).output_size(); ++idx) {
|
||||
std::string tensor_name = (*node).output().Get(idx);
|
||||
auto iter = output_name_idx_map.find(tensor_name);
|
||||
if (iter != output_name_idx_map.end()) {
|
||||
// Preserving NodeArg and input/output names
|
||||
ONNX_NAMESPACE::NodeProto temp_node_proto;
|
||||
node_in_parent_graph->ToProto(temp_node_proto);
|
||||
const onnxruntime::NodeArg* node_arg = parent_graph_->GetNodeArg(temp_node_proto.output().Get(output_name_idx_map[tensor_name]));
|
||||
auto& n_output = sub_graph.GetOrCreateNodeArg(
|
||||
temp_node_proto.output().Get(iter->second), node_arg->TypeAsProto());
|
||||
const onnxruntime::NodeArg* node_arg = parent_graph_->GetNodeArg(function_op_node_proto.output()
|
||||
.Get(iter->second));
|
||||
auto& n_output = function_body_graph.GetOrCreateNodeArg(
|
||||
function_op_node_proto.output().Get(iter->second), node_arg->TypeAsProto());
|
||||
outputs.push_back(&n_output);
|
||||
} else {
|
||||
auto& n_output = sub_graph.GetOrCreateNodeArg(
|
||||
auto& n_output = function_body_graph.GetOrCreateNodeArg(
|
||||
tensor_name + "_" + std::to_string(node_index), nullptr);
|
||||
outputs.push_back(&n_output);
|
||||
}
|
||||
}
|
||||
|
||||
onnxruntime::NodeAttributes new_attr_map;
|
||||
for (auto& attr : node.attribute()) {
|
||||
if (!attr.ref_attr_name().empty()) {
|
||||
auto entry = attr_map.find(attr.ref_attr_name());
|
||||
for (auto node_attr = (*node).mutable_attribute()->begin();
|
||||
node_attr != (*node).mutable_attribute()->end(); ++node_attr) {
|
||||
// If this node contains subgraphs, the node inputs/outputs within them needs to be fixed as well
|
||||
if ((*node_attr).has_g()) {
|
||||
update_subgraphs_within_function_body(*(*node_attr).mutable_g(),
|
||||
*parent_graph_, function_op_node_proto,
|
||||
input_name_idx_map, output_name_idx_map);
|
||||
}
|
||||
|
||||
if (!(*node_attr).ref_attr_name().empty()) {
|
||||
auto entry = attr_map.find((*node_attr).ref_attr_name());
|
||||
if (entry != attr_map.cend()) {
|
||||
new_attr_map[attr.name()] = entry->second;
|
||||
new_attr_map[(*node_attr).name()] = entry->second;
|
||||
}
|
||||
} else {
|
||||
new_attr_map[attr.name()] = attr;
|
||||
new_attr_map[(*node_attr).name()] = *node_attr;
|
||||
}
|
||||
}
|
||||
sub_graph.AddNode(uniq_identifier + "_" + std::to_string(node_index), node.op_type(), node.doc_string(), inputs, outputs, &new_attr_map, node.domain());
|
||||
function_body_graph.AddNode(uniq_identifier + "_" + std::to_string(node_index), (*node).op_type(), (*node).doc_string(), inputs, outputs, &new_attr_map, (*node).domain());
|
||||
}
|
||||
auto status = sub_graph.Resolve();
|
||||
|
||||
auto status = function_body_graph.Resolve();
|
||||
ORT_ENFORCE(status.IsOK(), "Resolve subgraph failed:", status.ErrorMessage());
|
||||
}
|
||||
} // namespace onnxruntime
|
||||
|
||||
FunctionImpl::~FunctionImpl() = default;
|
||||
|
||||
|
|
@ -304,7 +382,7 @@ const IndexedSubGraph& FunctionImpl::GetIndexedSubGraph() const {
|
|||
}
|
||||
|
||||
const ONNX_NAMESPACE::FunctionProto* FunctionImpl::GetFuncProto() const {
|
||||
return onnx_func_proto_;
|
||||
return &onnx_func_proto_;
|
||||
}
|
||||
|
||||
std::unique_ptr<Function> MakeFunction(const onnxruntime::Graph& graph,
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ class FunctionImpl final : public Function {
|
|||
|
||||
FunctionImpl(const onnxruntime::Graph& graph,
|
||||
const onnxruntime::NodeIndex& node_index,
|
||||
const ONNX_NAMESPACE::FunctionProto* onnx_func);
|
||||
const ONNX_NAMESPACE::FunctionProto& onnx_func);
|
||||
|
||||
~FunctionImpl() override;
|
||||
|
||||
|
|
@ -37,7 +37,7 @@ class FunctionImpl final : public Function {
|
|||
std::unique_ptr<IndexedSubGraph> customized_func_body_;
|
||||
std::unique_ptr<ONNX_NAMESPACE::OpSchema> op_schema_;
|
||||
std::unique_ptr<onnxruntime::Model> body_;
|
||||
const ONNX_NAMESPACE::FunctionProto* onnx_func_proto_;
|
||||
ONNX_NAMESPACE::FunctionProto onnx_func_proto_;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -669,8 +669,8 @@ Graph::Graph(GraphProto* graph_proto, const std::unordered_map<std::string, int>
|
|||
const gsl::not_null<TensorProto*> tensor{graph_proto_->add_initializer()};
|
||||
const AttributeProto& constant_attribute = node.attribute(0);
|
||||
// TODO: Add support for parsing 'sparse_value' attribute from a 'Constant' node
|
||||
// Discussion surrounding handling the SparseTensorProto must be had.
|
||||
// An easy way is to implement a method that converts a SparseTensorproto into a TensorProto
|
||||
// Discussion surrounding handling the SparseTensorProto must be had.
|
||||
// An easy way is to implement a method that converts a SparseTensorproto into a TensorProto
|
||||
// to use the same downstream flow, but that is going to impact peak memory usage and probably a smarter way is required.
|
||||
ORT_ENFORCE(constant_attribute.has_t(), "Only 'value' attribute is supported within a 'Constant' node in ORT");
|
||||
*tensor = constant_attribute.t();
|
||||
|
|
@ -1722,7 +1722,7 @@ Status Graph::VerifyNodeAndOpMatch() {
|
|||
auto iter = model_functions_.find(node.OpType());
|
||||
if (iter != model_functions_.end()) {
|
||||
const ONNX_NAMESPACE::FunctionProto* model_function_proto = iter->second;
|
||||
auto model_func_ptr = onnxruntime::make_unique<onnxruntime::FunctionImpl>(*this, node.Index(), model_function_proto);
|
||||
auto model_func_ptr = onnxruntime::make_unique<onnxruntime::FunctionImpl>(*this, node.Index(), *model_function_proto);
|
||||
function_container_.emplace_back(std::move(model_func_ptr));
|
||||
node.SetFunctionBody(*function_container_.back());
|
||||
}
|
||||
|
|
@ -1743,7 +1743,7 @@ Status Graph::VerifyNodeAndOpMatch() {
|
|||
|
||||
if (node.op_ && node.op_->HasFunction()) {
|
||||
auto onnx_function_proto = node.op_->GetFunction();
|
||||
auto func_ptr = onnxruntime::make_unique<onnxruntime::FunctionImpl>(*this, node.Index(), onnx_function_proto);
|
||||
auto func_ptr = onnxruntime::make_unique<onnxruntime::FunctionImpl>(*this, node.Index(), *onnx_function_proto);
|
||||
function_container_.emplace_back(std::move(func_ptr));
|
||||
node.SetFunctionBody(*function_container_.back());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -162,5 +162,20 @@ TEST(ONNXModelsTest, TestIRv4NonInputInitializers) {
|
|||
ASSERT_TRUE(Model::Load("testdata/subgraph_implicit_input_from_initializer.onnx", model).IsOK());
|
||||
EXPECT_TRUE(model->MainGraph().Resolve().IsOK());
|
||||
}
|
||||
|
||||
// test a model that has an op with a FunctionBody and one of the nodes within the FunctionBody has a subgraph in it.
|
||||
// The test model has is an opset-11 op with a 'Range' node.
|
||||
// 'Range' has a FunctionBody and has a 'Loop' node with a subgraph.
|
||||
// Graph::Resolve to succeed when processing the subgraph pertaining to the overall FunctionBody.
|
||||
TEST(ONNXModelsTest, TestModelsWithAnOpContainingAFunctionBody) {
|
||||
std::shared_ptr<Model> model;
|
||||
|
||||
auto status = Model::Load("testdata/model_containing_op_with_function_body.onnx", model);
|
||||
EXPECT_TRUE(status.IsOK()) << status;
|
||||
|
||||
status = model->MainGraph().Resolve();
|
||||
EXPECT_TRUE(status.IsOK()) << status;
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/model_containing_op_with_function_body.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/model_containing_op_with_function_body.onnx
vendored
Normal file
Binary file not shown.
Loading…
Reference in a new issue