diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 6eb39db064..e726bc0d7c 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -121,7 +121,7 @@ class Node { /** Gets the Node's Node::Type. */ Node::Type NodeType() const noexcept { return node_type_; } - + /** Gets the opset version that the Node's operator was first defined in. @returns Opset version. If -1 the Node's operator has not been set. @remarks Prefer over Op()->SinceVersion() as Op() is disabled in a minimal build @@ -1029,13 +1029,12 @@ class Graph { /** Returns true if the name is for a value that is coming from outer scope */ bool IsOuterScopeValue(const std::string& name) const { -#if !defined(ORT_MINIMAL_BUILD) - return resolve_context_.outer_scope_node_args.find(name) != resolve_context_.outer_scope_node_args.cend(); -#else - // we shouldn't have code that calls this in a minimal build - ORT_UNUSED_PARAMETER(name); - ORT_THROW("Internal error. 
Outer scope value lookup is not currently supported in a minimal build."); -#endif + if (!parent_node_) return false; + const auto& implicit_input_defs = parent_node_->ImplicitInputDefs(); + return std::any_of(implicit_input_defs.cbegin(), implicit_input_defs.cend(), + [&name](const NodeArg* implicit_input) { + return implicit_input->Name() == name; + }); } #if !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 2069987faa..c5fe24cd56 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -234,36 +234,39 @@ void SessionState::CleanInitializedTensorsFromGraph() { graph_.CleanAllInitializedTensors(); } -Status SessionState::PrepackInitializedConstantTensors() { - // calculate the use count of each value - std::unordered_map node_arg_use_count; - for (const auto& node : GetGraphViewer().Nodes()) { - node.ForEachDef([&](const onnxruntime::NodeArg& node_arg, bool is_input) { - if (is_input) { - node_arg_use_count[node_arg.Name()]++; - } - }); - } - +Status SessionState::PrepackConstantInitializedTensors(std::unordered_map& constant_initializers_use_count) { for (auto& node : GetGraphViewer().Nodes()) { auto kernel = GetMutableKernel(node.Index()); int input_idx = 0; for (auto& input_def : node.InputDefs()) { if (input_def->Exists()) { const std::string& input_name = input_def->Name(); - int ort_value_idx; - ORT_RETURN_IF_ERROR(ort_value_name_idx_map_.GetIdx(input_name, ort_value_idx)); - if (constant_initialized_tensors_.count(ort_value_idx) && - constant_initialized_tensors_[ort_value_idx].IsTensor()) { - bool is_packed = false; - const Tensor& const_initialized_tensor = constant_initialized_tensors_[ort_value_idx].Get(); - ORT_RETURN_IF_ERROR(kernel->PrePack(const_initialized_tensor, input_idx, is_packed)); - if (is_packed && node_arg_use_count.count(input_name) && --node_arg_use_count[input_name] == 0) { - // release the constant 
intialized tensor - initialized_tensors_.erase(ort_value_idx); - constant_initialized_tensors_.erase(ort_value_idx); + SessionState* st = this; + // a subgraph can use a value from an outer scope, + // so check whether the current node uses a constant initialized tensor from the current graph or from an outer graph + do { + int ort_value_idx; + if (st->GetOrtValueNameIdxMap().GetIdx(input_name, ort_value_idx).IsOK()) { + std::unordered_map& constant_initialized_tensors = st->constant_initialized_tensors_; + if (constant_initialized_tensors.count(ort_value_idx)) { + bool is_packed = false; + const Tensor& const_initialized_tensor = constant_initialized_tensors[ort_value_idx].Get(); + ORT_RETURN_IF_ERROR(kernel->PrePack(const_initialized_tensor, input_idx, is_packed)); + if (is_packed && constant_initializers_use_count.count(input_name) && --constant_initializers_use_count[input_name] == 0) { + // release the constant initialized tensor + st->initialized_tensors_.erase(ort_value_idx); + constant_initialized_tensors.erase(ort_value_idx); + } + } + // stop searching in 2 cases: + // 1. value is not from OuterScope + // 2. 
value is from OuterScope and the current OuterScope has the value + if (st != this || !st->graph_.IsOuterScopeValue(input_name)) { + break; + } + } - } + st = st->Parent(); + } while (st); } input_idx++; } @@ -567,10 +570,13 @@ void SessionState::AddSubgraphSessionState(onnxruntime::NodeIndex index, const s ORT_ENFORCE(existing_entries.find(attribute_name) == existing_entries.cend(), "Entry exists in node ", index, " for attribute ", attribute_name); } -#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT + session_state->parent_ = this; + +#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT GenerateGraphId(); #endif + subgraph_session_states_[index].insert(std::make_pair(attribute_name, std::move(session_state))); } @@ -776,6 +782,27 @@ Status SessionState::LoadFromOrtFormat(const fbs::SessionState& fbs_session_stat } #endif +// Calculate the use count of each constant initialized tensor, including uses in subgraphs. +// Note: this function does not handle the following case: +// The main graph has a constant initializer called X, and the subgraph also has a constant initializer called X, which overrides the X from the main graph. +// For a case like this, the current implementation calculates a use count of 2, but the two initializers could contain completely different values, so each should have a use count of 1. +// This is a very rare case. If it happens and X is prepacked, the consequence is that X is not released, so the memory it occupies is not reclaimed. This is acceptable. 
+static void ComputeConstantInitializerUseCount(const Graph& graph, std::unordered_map& constant_initializers_use_count) { + for (const auto& node : graph.Nodes()) { + for (const auto* arg : node.InputDefs()) { + if (arg->Exists() && graph.GetConstantInitializer(arg->Name(), true /*check_outer_scope*/)) { + constant_initializers_use_count[arg->Name()]++; + } + } + + if (node.ContainsSubgraph()) { + for (const gsl::not_null& subgraph : node.GetSubgraphs()) { + ComputeConstantInitializerUseCount(*subgraph, constant_initializers_use_count); + } + } + } +} + Status SessionState::FinalizeSessionState(const std::basic_string& graph_location, KernelRegistryManager& kernel_registry_manager, const SessionOptions& session_options, @@ -807,15 +834,18 @@ Status SessionState::FinalizeSessionState(const std::basic_string constant_initializers_use_count; + ComputeConstantInitializerUseCount(graph_, constant_initializers_use_count); return FinalizeSessionStateImpl(graph_location, kernel_registry_manager, nullptr, session_options, - remove_initializers); + remove_initializers, constant_initializers_use_count); } Status SessionState::FinalizeSessionStateImpl(const std::basic_string& graph_location, KernelRegistryManager& kernel_registry_manager, _In_opt_ const Node* parent_node, const SessionOptions& session_options, - bool remove_initializers) { + bool remove_initializers, + std::unordered_map& constant_initializers_use_count) { CreateGraphInfo(); // ignore any outer scope args we don't know about. this can happen if a node contains multiple subgraphs. 
@@ -868,7 +898,7 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string& constant_initializers_use_count); SessionState* GetMutableSubgraphSessionState(onnxruntime::NodeIndex index, const std::string& attribute_name); @@ -315,7 +319,8 @@ class SessionState { KernelRegistryManager& kernel_registry_manager, _In_opt_ const Node* parent_node, const SessionOptions& session_options, - bool remove_initializers); + bool remove_initializers, + std::unordered_map& constant_initializers_use_count); #ifdef ENABLE_TRAINING Status GeneratePatternGroupCache( @@ -421,9 +426,9 @@ class SessionState { std::map, std::unordered_set> to_be_executed_nodes_; #endif -#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT SessionState* parent_ = nullptr; //Assign each graph in each session an unique id. +#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT int graph_id_ = 0; int next_graph_id_ = 1; diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index 58758ee667..0d16253421 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -189,20 +189,7 @@ class PrePackingTestOpKernel : public OpKernel { } }; -class SessionStatePrepackingTest : public testing::TestWithParam {}; -TEST_P(SessionStatePrepackingTest, PrePackingTest) { - OrtThreadPoolParams to; - auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to, concurrency::ThreadPoolType::INTRA_OP); - ONNX_OPERATOR_SCHEMA(PrePackingTest) - .SetDoc("Faking Node for PrePacking") - .Input(0, "Input_0", "input 0", "tensor(float)") - .Input(1, "Input_1", "input 1", "tensor(float)") - .Output(0, "output_0", "docstr for output_0.", "tensor(float)"); - - onnxruntime::Model model("graph_1", false, DefaultLoggingManager().DefaultLogger()); - // construct graph - auto& graph = model.MainGraph(); - +static void CreateSimpleGraph(Graph& graph) { // node creation and placement TypeProto type; 
type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); @@ -218,8 +205,7 @@ TEST_P(SessionStatePrepackingTest, PrePackingTest) { onnxruntime::NodeArg output_arg("node_0_output_0", &type); outputs.push_back(&output_arg); - onnxruntime::Node& node = graph.AddNode("node_0", "PrePackingTest", "node 0", inputs, outputs); - node.SetExecutionProviderType(kCpuExecutionProvider); + graph.AddNode("node_0", "PrePackingTest", "node 0", inputs, outputs); // add an initializer ONNX_NAMESPACE::TensorProto tensor; @@ -231,6 +217,123 @@ TEST_P(SessionStatePrepackingTest, PrePackingTest) { auto status = graph.Resolve(); ASSERT_TRUE(status.IsOK()); +} + +static const ONNX_NAMESPACE::GraphProto CreateSubgraph(bool then_branch) { + Model model(then_branch ? "If_then" : "If_else", false, DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + + std::vector inputs; + std::vector outputs; + + const std::string suffix = then_branch ? "0" : "1"; + + // graph input has to have type and rank even though it's an outer scope value. 
+ TypeProto type_float; + type_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); + type_float.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); + + // outer scope values + auto& if_shared = graph.GetOrCreateNodeArg("if_shared", &type_float); + auto& if_input = graph.GetOrCreateNodeArg("if_input_" + suffix, &type_float); + + // add so that we don't end up with it being considered a graph input + graph.AddOuterScopeNodeArg("if_shared"); + graph.AddOuterScopeNodeArg("if_input_" + suffix); + + auto& if_out = graph.GetOrCreateNodeArg("if_output_" + suffix, &type_float); + + inputs = {&if_shared, &if_input}; + outputs = {&if_out}; + + graph.AddNode("if_node_" + suffix, "PrePackingTest", "if node " + suffix, inputs, outputs); + + auto status = graph.Resolve(); + EXPECT_EQ(status, Status::OK()); + + auto& proto = graph.ToGraphProto(); + + return proto; +} + +static void CreateGraphWithSubgraph(Graph& graph) { + TypeProto type_float; + type_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); + type_float.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); + + { + std::vector inputs; + onnxruntime::NodeArg input_0_arg("if_input_0", &type_float); + onnxruntime::NodeArg input_1_arg("if_input_1", &type_float); + inputs.push_back(&input_0_arg); + inputs.push_back(&input_1_arg); + + std::vector outputs; + onnxruntime::NodeArg output_arg("node_0_output_0", &type_float); + outputs.push_back(&output_arg); + + graph.AddNode("node_0", "PrePackingTest", "node 0", inputs, outputs); + } + + { + TypeProto type_bool; + type_bool.mutable_tensor_type()->set_elem_type(TensorProto_DataType_BOOL); + type_bool.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); + + onnxruntime::NodeArg bool_arg("bool_arg", &type_bool); + + std::vector outputs; + onnxruntime::NodeArg output_arg("output_arg", &type_float); + outputs.push_back(&output_arg); + + auto& if_node = graph.AddNode("if", "If", "If node", 
{&bool_arg}, outputs); + + auto then_proto = CreateSubgraph(true); + auto else_proto = CreateSubgraph(false); + if_node.AddAttribute("then_branch", then_proto); + if_node.AddAttribute("else_branch", else_proto); + } + + // add an initializer + ONNX_NAMESPACE::TensorProto tensor; + tensor.add_dims(1); + tensor.add_float_data(1.0f); + tensor.set_data_type(TensorProto_DataType_FLOAT); + tensor.set_name("if_shared"); + graph.AddInitializedTensor(tensor); + + auto status = graph.Resolve(); + ASSERT_TRUE(status.IsOK()); +} + +static void PlaceAllNodesToCPUEP(Graph& graph) { + for (auto& node : graph.Nodes()) { + node.SetExecutionProviderType(kCpuExecutionProvider); + if (node.ContainsSubgraph()) { + for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) { + Graph* subgraph = entry.second; + PlaceAllNodesToCPUEP(*subgraph); + } + } + } +} + +struct PrepackingTestParam { + bool test_subgraph; + bool test_prepacking; +}; + +class SessionStatePrepackingTest : public testing::TestWithParam {}; +TEST_P(SessionStatePrepackingTest, PrePackingTest) { + PrepackingTestParam test_param = GetParam(); + + OrtThreadPoolParams to; + auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to, concurrency::ThreadPoolType::INTRA_OP); + ONNX_OPERATOR_SCHEMA(PrePackingTest) + .SetDoc("Faking Node for PrePacking") + .Input(0, "Input_0", "input 0", "tensor(float)") + .Input(1, "Input_1", "input 1", "tensor(float)") + .Output(0, "output_0", "docstr for output_0.", "tensor(float)"); ExecutionProviders execution_providers; auto cpu_execution_provider = onnxruntime::make_unique(CPUExecutionProviderInfo(false)); @@ -238,7 +341,21 @@ TEST_P(SessionStatePrepackingTest, PrePackingTest) { DataTransferManager dtm; profiling::Profiler profiler; - SessionState session_state(graph, + + std::unordered_map domain_to_version; + domain_to_version[kOnnxDomain] = 11; + Model model("graph_main", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + domain_to_version, 
std::vector(), + DefaultLoggingManager().DefaultLogger()); + + // onnxruntime::Model model("graph_main", false, DefaultLoggingManager().DefaultLogger()); + if (test_param.test_subgraph) { + CreateGraphWithSubgraph(model.MainGraph()); + } else { + CreateSimpleGraph(model.MainGraph()); + } + + SessionState session_state(model.MainGraph(), execution_providers, true, /*enable_mem_pattern*/ tp.get(), @@ -248,7 +365,7 @@ TEST_P(SessionStatePrepackingTest, PrePackingTest) { profiler); KernelRegistryManager kernel_registry_manager; - status = kernel_registry_manager.RegisterKernels(execution_providers); + Status status = kernel_registry_manager.RegisterKernels(execution_providers); ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); std::shared_ptr kernel_registry = std::make_shared(); auto kernel_def = KernelDefBuilder().SetName("PrePackingTest").Provider(kCpuExecutionProvider).SinceVersion(1).Build(); @@ -257,19 +374,25 @@ TEST_P(SessionStatePrepackingTest, PrePackingTest) { [](const OpKernelInfo& info) -> OpKernel* { return new PrePackingTestOpKernel(info); }))); kernel_registry_manager.RegisterKernelRegistry(kernel_registry); + PlaceAllNodesToCPUEP(model.MainGraph()); + SessionOptions sess_options; - bool use_prepacking = GetParam(); - sess_options.session_configurations[kOrtSessionOptionsConfigDisablePrepacking] = use_prepacking ? "0" : "1"; + sess_options.session_configurations[kOrtSessionOptionsConfigDisablePrepacking] = test_param.test_prepacking ? "0" : "1"; ASSERT_STATUS_OK(session_state.FinalizeSessionState(std::basic_string(), kernel_registry_manager, sess_options)); const auto& const_initialized_tensors = session_state.GetConstantInitializedTensors(); // check prepacking - ASSERT_EQ(const_initialized_tensors.size(), size_t(use_prepacking ? 0 : 1)); + ASSERT_EQ(const_initialized_tensors.size(), size_t(test_param.test_prepacking ? 
0 : 1)); } -INSTANTIATE_TEST_SUITE_P(SessionStateTests, SessionStatePrepackingTest, testing::Values(true, false)); +INSTANTIATE_TEST_SUITE_P(SessionStateTests, + SessionStatePrepackingTest, + testing::Values(PrepackingTestParam{false, false}, + PrepackingTestParam{false, true}, + PrepackingTestParam{true, false}, + PrepackingTestParam{true, true})); } // namespace test } // namespace onnxruntime