diff --git a/onnxruntime/core/optimizer/transformer_memcpy.cc b/onnxruntime/core/optimizer/transformer_memcpy.cc index b52d7f41af..c6f57900b9 100644 --- a/onnxruntime/core/optimizer/transformer_memcpy.cc +++ b/onnxruntime/core/optimizer/transformer_memcpy.cc @@ -18,10 +18,10 @@ class TransformerMemcpyImpl { bool ModifyGraph(const KernelRegistryManager& schema_registries); private: - void ProcessDefs(onnxruntime::Node& node, const KernelRegistryManager& kernel_registries); + void ProcessDefs(onnxruntime::Node& node, const KernelRegistryManager& kernel_registries, InitializedTensorSet& initializers_consumed); void BuildDefsMapping(const onnxruntime::NodeArg* arg, const KernelRegistryManager& kernel_registries); void AddCopyNode(onnxruntime::NodeArg* arg, bool is_input); - void ProcessInitializers(const KernelRegistryManager& kernel_registries); + bool ProcessInitializers(const KernelRegistryManager& kernel_registries, const InitializedTensorSet& initializers_consumed); private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TransformerMemcpyImpl); @@ -52,6 +52,19 @@ class TransformerMemcpyImpl { std::string provider_; }; +/** Helper that returns a pointer to the corresponding TensorProto for a name if it is an initializer. +@param check_outer_scope If true and the graph is a subgraph, check parent graph/s for 'name' if not found in 'graph'. +*/ +static const onnx::TensorProto* GetInitializer(const Graph& graph, const std::string& name, bool check_outer_scope) { + const onnx::TensorProto* initializer = nullptr; + if (graph.GetInitializedTensor(name, initializer)) { + return initializer; + } else if (check_outer_scope && graph.IsSubgraph()) { + return GetInitializer(*graph.ParentGraph(), name, check_outer_scope); + } + return initializer; +} + // very simple GraphTransformer that uses TransformerMemcpyImpl for each graph // and mainly provides the subgraph recursion functionality common::Status MemcpyTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level) const { @@ -63,7 +76,8 @@ common::Status MemcpyTransformer::ApplyImpl(Graph& graph, bool& modified, int gr provider != onnxruntime::kTensorrtExecutionProvider && provider != onnxruntime::kOpenVINOExecutionProvider) { TransformerMemcpyImpl copy_impl(graph, provider); - modified = copy_impl.ModifyGraph(registry_manager_); + auto current_modified = copy_impl.ModifyGraph(registry_manager_); + modified = modified || current_modified; } } @@ -109,14 +123,16 @@ This transformer does not currently optimize copies between, e.g., two different bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_registries) { bool modified = false; + InitializedTensorSet initializers_consumed; // find defs that require copy for (auto& node : graph_.Nodes()) { - //don't need to do node placement here now, onnxruntime will do it according to registered kernels. - ProcessDefs(node, kernel_registries); + //as we process the defs, collect all the initializers consumed at the current graph level + ProcessDefs(node, kernel_registries, initializers_consumed); } // for initializers shared by different providers, create dups - ProcessInitializers(kernel_registries); + if (ProcessInitializers(kernel_registries, initializers_consumed)) + modified = true; for (auto arg : graph_.GetInputs()) BuildDefsMapping(arg, kernel_registries); @@ -150,21 +166,27 @@ bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_regi return modified; } -void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelRegistryManager& kernel_registries) { +void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelRegistryManager& kernel_registries, InitializedTensorSet& initializers_consumed) { if (node.GetExecutionProviderType() == provider_) { provider_nodes_.insert(&node); // note KernelCreateInfo might be nullptr for custom kernel const KernelCreateInfo* kci = nullptr; kernel_registries.SearchKernelRegistry(node, &kci); - auto status = onnxruntime::Node::ForEachWithIndex(node.InputDefs(), - [this, &kci](const onnxruntime::NodeArg& arg, size_t index) { - if (kci && kci->kernel_def->IsInputOnCpu(index)) - non_provider_input_defs_.insert(&arg); - else - provider_input_defs_.insert(&arg); - return Status::OK(); - }); + auto status = onnxruntime::Node::ForEachWithIndex( + node.InputDefs(), + [this, &kci, &initializers_consumed](const onnxruntime::NodeArg& arg, size_t index) { + // check if this NodeArg is an initializer defined in current outer graph level + const auto* initializer_tensor_proto = + GetInitializer(graph_, arg.Name(), true); + if (initializer_tensor_proto != nullptr) + initializers_consumed[arg.Name()] = initializer_tensor_proto; + if (kci && kci->kernel_def->IsInputOnCpu(index)) + non_provider_input_defs_.insert(&arg); + else + provider_input_defs_.insert(&arg); + return Status::OK(); + }); ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); @@ -274,9 +296,9 @@ static const onnxruntime::NodeArg* FindNodeArg(const NodeArgSetType& def_set, co // We duplicate any initializer that is used by both provider nodes and non-provider nodes // to ensure that provider nodes and non-provider nodes don't share initializers, as they // need to stay in different memory locations. -void TransformerMemcpyImpl::ProcessInitializers(const KernelRegistryManager& kernel_registries) { +bool TransformerMemcpyImpl::ProcessInitializers(const KernelRegistryManager& kernel_registries, const InitializedTensorSet& initializers_consumed) { std::map replacements; - for (const auto& pair : graph_.GetAllInitializedTensors()) { + for (const auto& pair : initializers_consumed) { const auto& name = pair.first; const onnxruntime::NodeArg* provider_def = FindNodeArg(provider_input_defs_, name); const onnxruntime::NodeArg* non_provider_def = FindNodeArg(non_provider_input_defs_, name); @@ -284,10 +306,15 @@ void TransformerMemcpyImpl::ProcessInitializers(const KernelRegistryManager& ker std::string new_def_name = graph_.GenerateNodeArgName(name); auto& new_def = graph_.GetOrCreateNodeArg(new_def_name, provider_def->TypeAsProto()); - const TensorProto* tensor_proto = nullptr; - bool found = graph_.GetInitializedTensor(name, tensor_proto); - ORT_ENFORCE(found, "Failed to get initialized tensor ", name); - + // We make a copy of the initializer that is to be consumed by the provider Node so that + // session state initializer can copy it over to the provider device during its operation + // TODO: The copy being made is possibly redundant if this occurs in a subgraph + // When multiple subgraphs consume the same initializer as an implicit input, + // multiple copies of the initializer will be made into the provider device + // This should not directly affect runtime performance as the copies occur during initialization + // but overuse of the provider device's memory is definitely inefficient + // In future, we need to "statefully" make the copy only once and use it in all subgraphs referencing the initializer + const TensorProto* tensor_proto = pair.second; TensorProto new_tensor_proto = *tensor_proto; *(new_tensor_proto.mutable_name()) = new_def_name; graph_.AddInitializedTensor(new_tensor_proto); @@ -322,6 +349,9 @@ void TransformerMemcpyImpl::ProcessInitializers(const KernelRegistryManager& ker p_node->ReplaceDefs(dup_replacements); } + + // This denotes a modification to the graph + return !replacements.empty(); } } // namespace onnxruntime diff --git a/onnxruntime/test/framework/memcpy_transformer_test.cc b/onnxruntime/test/framework/memcpy_transformer_test.cc index f8ef7b3562..8eea68976e 100644 --- a/onnxruntime/test/framework/memcpy_transformer_test.cc +++ b/onnxruntime/test/framework/memcpy_transformer_test.cc @@ -177,6 +177,112 @@ TEST(TransformerTest, MemcpyTransformerTestCudaFirst) { ExpectSame(node2, node4, 0); ExpectSame(node2, node4, 1); } +TEST(TransformerTest, TestCopyNodeInsertionInitializerInSubgraph) { + // In this test, we are going to create a subgraph consuming an implicit input + // which is an initializer in the outer scope, and this implicit input to the subgraph + // is consumed by nodes on multiple devices + TensorProto value_tensor; + value_tensor.add_dims(1); + value_tensor.add_float_data(1.f); + value_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + + TypeProto tensor_float_type; + tensor_float_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); + + TypeProto tensor_bool_type; + tensor_bool_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_BOOL); + + onnxruntime::NodeArg i1_def("I1", &tensor_bool_type), + o1_def("O1", &tensor_float_type), + o2_def("O2", &tensor_float_type); + + // main graph + // this will only contain one 'If' node + std::unordered_map domain_to_version; + domain_to_version[kOnnxDomain] = 7; + auto model = std::make_shared("test", + false, + ModelMetaData(), + IOnnxRuntimeOpSchemaRegistryList(), + domain_to_version); + onnxruntime::Graph& graph = model->MainGraph(); + + TensorProto parent_constant(value_tensor); + parent_constant.set_name("parent_constant"); + graph.AddInitializedTensor(parent_constant); + + // subgraph + // this will contain 2 'Add' nodes - one on CPU and one of GPU + // one of the inputs to the 'Add' nodes is an implicit input to the subgraph + // which is an initializer in the main graph + std::unordered_map subgraph_domain_to_version; + subgraph_domain_to_version[kOnnxDomain] = 7; + auto sub_model = std::make_shared("test_subgraph", + false, + ModelMetaData(), + IOnnxRuntimeOpSchemaRegistryList(), + subgraph_domain_to_version); + onnxruntime::Graph& subgraph = sub_model->MainGraph(); + + TensorProto local_constant(value_tensor); + local_constant.set_name("local_constant"); + subgraph.AddInitializedTensor(local_constant); + + subgraph.AddOuterScopeNodeArg("parent_constant"); + subgraph.AddNode("node1", "Add", "operator1", + ArgMap{&subgraph.GetOrCreateNodeArg("local_constant", &tensor_float_type), + &graph.GetOrCreateNodeArg("parent_constant", &tensor_float_type)}, + ArgMap{&o1_def}); + + subgraph.AddNode("node2", "Add", "operator2", + ArgMap{&subgraph.GetOrCreateNodeArg("local_constant", &tensor_float_type), + &graph.GetOrCreateNodeArg("parent_constant", &tensor_float_type)}, + ArgMap{&o2_def}); + + auto status = subgraph.Resolve(); + ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); + + // main graph continued + // create the 'If' node + auto& if_node = graph.AddNode("node3", "If", "cpu operator2", ArgMap{&i1_def}, ArgMap{&o1_def, &o2_def}); + if_node.AddAttribute("then_branch", {subgraph.ToGraphProto()}); + if_node.AddAttribute("else_branch", {subgraph.ToGraphProto()}); + + onnxruntime::Graph* subgraph_1 = if_node.GetMutableGraphAttribute("then_branch"); + for (auto& node : subgraph_1->Nodes()) { + if (node.Name() == "node2") { + // only this node is on GPU + node.SetExecutionProviderType(onnxruntime::kCudaExecutionProvider); + } else { + node.SetExecutionProviderType(onnxruntime::kCpuExecutionProvider); + } + } + + onnxruntime::Graph* subgraph_2 = if_node.GetMutableGraphAttribute("else_branch"); + for (auto& node : subgraph_2->Nodes()) { + node.SetExecutionProviderType(onnxruntime::kCpuExecutionProvider); + } + + status = graph.Resolve(); + ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); + + KernelRegistryManager kernel_registry_manager; + ExecutionProviders execution_providers; + execution_providers.Add(onnxruntime::kCudaExecutionProvider, + std::make_unique(CUDAExecutionProviderInfo())); + execution_providers.Add(onnxruntime::kCpuExecutionProvider, + std::make_unique(CPUExecutionProviderInfo())); + KernelRegistryManager test_registry_manager; + test_registry_manager.RegisterKernels(execution_providers); + + MemcpyTransformer transformer({onnxruntime::kCudaExecutionProvider}, test_registry_manager); + + bool modified = false; + status = transformer.Apply(graph, modified); + EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + EXPECT_TRUE(modified); +} + #endif } // namespace test