diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 8ce8e4ff2f..e2d57cfda4 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -921,11 +921,18 @@ class Graph { @returns Node with fused subgraph. @remarks As a new Graph instance for the fused nodes is not created, a GraphViewer can be constructed with the IndexedSubGraph information to provide a view of the subgraph. The original nodes are left in place - while this is in use. + while this is in use. Call FinalizeFuseSubGraph to remove them once the fused replacement node is fully created. */ Node& BeginFuseSubGraph(const IndexedSubGraph& sub_graph, const std::string& fused_node_name); + /** + If we have BeginFuseSubGraph, but somehow hit errors, such as Compile of an EP failed on thesub_graph. + We can call CancelFuseSubGraph to undo the changes of BeginFuseSubGraph + @param fused_node The fused node and it's function body to be removed from the graph + */ + void CancelFuseSubGraph(const Node& fused_node); + void FinalizeFuseSubGraph(const IndexedSubGraph& sub_graph, Node& fused_node); #endif diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index 9abb3cb6f3..9a17ffe285 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -448,42 +448,43 @@ static Status PartitionOrtFormatModelImpl(Graph& graph, FuncManager& func_mgr, nodes_and_viewers.push_back(IExecutionProvider::FusedNodeAndGraph{fused_node, *viewers.back()}); } - std::vector node_compute_funcs; - node_compute_funcs.reserve(nodes_and_viewers.size()); - - ORT_RETURN_IF_ERROR(current_ep.Compile(nodes_and_viewers, node_compute_funcs)); - - if (node_compute_funcs.size() != nodes_and_viewers.size()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, type, " did not return correct number of compiled functions"); - } - - for (size_t j = 0, end = nodes_and_viewers.size(); j < end; j++) { + // We will compile the fused nodes one by one, and fuse the subgraph if successful. + // If a compilation fails we undo the fusion and leave the original nodes available for other EPs to take + for (size_t j = 0, end = nodes_and_viewers.size(); j < end; ++j) { Node& node = nodes_and_viewers[j].fused_node; + std::vector single_node_compute_func; + auto status = current_ep.Compile({nodes_and_viewers[j]}, single_node_compute_func); + if (!status.IsOK()) { + // There is compile error with the nodes_and_viewer[j], remove the fused_node and function from the graph + LOGS_DEFAULT(ERROR) << "EP: " << current_ep.Type() << " has Compile error: " << status.ErrorMessage(); + graph.CancelFuseSubGraph(node); + } else { + ORT_RETURN_IF(single_node_compute_func.empty(), "single_node_compute_func should have 1 elements"); + ORT_RETURN_IF_ERROR(func_mgr.AddFuncInfo(node.Name(), std::move(single_node_compute_func[0]))); - ORT_RETURN_IF_ERROR(func_mgr.AddFuncInfo(node.Name(), std::move(node_compute_funcs[j]))); + const auto& cur_capability = capabilities[j]; + const IndexedSubGraph& indexed_sub_graph = *cur_capability->sub_graph; + const IndexedSubGraph::MetaDef& metadef = *indexed_sub_graph.GetMetaDef(); - const auto& cur_capability = capabilities[j]; - const IndexedSubGraph& indexed_sub_graph = *cur_capability->sub_graph; - const IndexedSubGraph::MetaDef& metadef = *indexed_sub_graph.GetMetaDef(); + KernelDefBuilder builder; + BuildFusedKernelDef(builder, metadef, type); + auto kernel_def = builder.Build(); - KernelDefBuilder builder; - BuildFusedKernelDef(builder, metadef, type); - auto kernel_def = builder.Build(); + // save hash so SessionState can find the kernel. each kernel name should be unique + if (compiled_kernel_hashes.insert({metadef.name, kernel_def->GetHash()}).second == false) { + ORT_THROW("Existing entry in compiled kernel hashes for ", metadef.name, + ". Execution Provider must generate unique names across the entire model."); + } - // save hash so SessionState can find the kernel. each kernel name should be unique - if (compiled_kernel_hashes.insert({metadef.name, kernel_def->GetHash()}).second == false) { - ORT_THROW("Existing entry in compiled kernel hashes for ", metadef.name, - ". Execution Provider must generate unique names across the entire model."); + ORT_RETURN_IF_ERROR(fused_kernel_registry.Register( + KernelCreateInfo(std::move(kernel_def), static_cast( + [](const OpKernelInfo& info) -> OpKernel* { + return new FunctionKernel(info); + })))); + + // now that we're done compiling we can remove the original nodes from the Graph and wire in the new one + graph.FinalizeFuseSubGraph(indexed_sub_graph, node); } - - ORT_RETURN_IF_ERROR(fused_kernel_registry.Register( - KernelCreateInfo(std::move(kernel_def), static_cast( - [](const OpKernelInfo& info) -> OpKernel* { - return new FunctionKernel(info); - })))); - - // now that we're done compiling we can remove the original nodes from the Graph and wire in the new one - graph.FinalizeFuseSubGraph(indexed_sub_graph, node); } return Status::OK(); diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index e9b5225fa5..8f72f43c0e 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -3392,6 +3392,31 @@ Node& Graph::BeginFuseSubGraph(const IndexedSubGraph& sub_graph, const std::stri return node; } +void Graph::CancelFuseSubGraph(const Node& fused_node) { + auto node_idx = fused_node.Index(); + if (!GetNode(node_idx)) + return; + + if (fused_node.NodeType() != Node::Type::Fused) + return; + +#if !defined(ORT_MINIMAL_BUILD) + // Remove the function body from function container + const auto* fused_node_func = fused_node.GetFunctionBody(); + auto it = std::find_if( + function_container_.begin(), function_container_.end(), + [fused_node_func](const std::unique_ptr& func) { + return func.get() == fused_node_func; + }); + if (it != function_container_.end()) { + function_container_.erase(it); + } +#endif + + // Remove the fused_node + RemoveNode(node_idx); +} + void Graph::FinalizeFuseSubGraph(const IndexedSubGraph& sub_graph, Node& fused_node) { const auto* func_meta_def = sub_graph.GetMetaDef(); ORT_ENFORCE(nullptr != func_meta_def); @@ -3432,9 +3457,7 @@ void Graph::FinalizeFuseSubGraph(const IndexedSubGraph& sub_graph, Node& fused_n if (it != input_indexes.cend()) { AddEdge(producer_idx, new_node_idx, src_idx, it->second); } - } - else - { + } else { int dst_implicit_input_idx = dst_idx - (int)node->InputDefs().size(); ORT_ENFORCE(dst_implicit_input_idx < (int)node->ImplicitInputDefs().size()); auto it = input_indexes.find(node->ImplicitInputDefs()[dst_implicit_input_idx]->Name()); diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc b/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc index d752d57959..aba2ed93c3 100644 --- a/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc +++ b/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc @@ -254,5 +254,102 @@ TEST(InternalTestingEP, TestModelWithSubgraph) { feeds); } +// A custom InternalTestingEP extension +// This is to testing execution fall back to CPU EP if Compile fails, for ORT format +// This EP will take an additional compile_failure_ops +// If in Compile() any nodes in the partition is also in compile_failure_ops +// The Compile will fail +class CompileFailureTestExecutionProvider : public InternalTestingExecutionProvider { + public: + CompileFailureTestExecutionProvider(const std::unordered_set& supported_ops, + const std::unordered_set& compile_failure_ops); + virtual ~CompileFailureTestExecutionProvider() = default; + + Status Compile(const std::vector& fused_nodes, + std::vector& node_compute_funcs) override; + + private: + std::unordered_set compile_failure_ops_; +}; + +CompileFailureTestExecutionProvider::CompileFailureTestExecutionProvider( + const std::unordered_set& supported_ops, + const std::unordered_set& compile_failure_ops) + : InternalTestingExecutionProvider(supported_ops), + compile_failure_ops_(compile_failure_ops) {} + +Status CompileFailureTestExecutionProvider::Compile(const std::vector& fused_nodes, + std::vector& node_compute_funcs) { + for (const auto& fused_node_and_graph : fused_nodes) { + // If any nodes in this partition is also in compile_failure_ops_, the Compile will fail + const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph); + for (const auto& node : graph_viewer.Nodes()) { + if (compile_failure_ops_.find(node.OpType()) != compile_failure_ops_.end()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "CompileFailureTestExecutionProvider::Compile failed for node: ", node.Name()); + } + } + } + + return InternalTestingExecutionProvider::Compile(fused_nodes, node_compute_funcs); +} + +TEST(InternalTestingEP, TestOrtModelWithCompileFailure) { + // In the test file, there are 2 Conv and 1 Gemm nodes, all disconnected + // So we should have 3 partitions be taken by InternalTestingExecutionProvider/CompileFailureTestExecutionProvider + // But CompileFailureTestExecutionProvider will fail the Compile for partition contains "Gemm" node + // This is to test the model initialization won't fail and Gemm node will not be replaced by the fused_node + const ORTCHAR_T* ort_model_path = ORT_TSTR("testdata/mnist.ort"); + + const std::unordered_set& supported_ops{"Conv", "Gemm"}; + const std::unordered_set& compile_failure_ops{"Gemm"}; + + // Use InternalTestingExecutionProvider + // We should have 3 partitions taken by the EP + // 2 Conv and 1 Gemm + { + InferenceSessionWrapper session(SessionOptions(), GetEnvironment()); + ASSERT_STATUS_OK(session.RegisterExecutionProvider( + onnxruntime::make_unique(supported_ops))); + ASSERT_STATUS_OK(session.Load(ort_model_path)); + ASSERT_STATUS_OK(session.Initialize()); + + int num_replaced_nodes = CountAndValidateAssignedNodes( + session.GetGraph(), supported_ops, session.GetSessionState().GetFuncMgr()); + + ASSERT_EQ(num_replaced_nodes, 3); + } + + // Use CompileFailureTestExecutionProvider which will fail Compile on "Gemm" + // We should have 2 partitions taken by the EP + // 2 Conv + { + InferenceSessionWrapper session(SessionOptions(), GetEnvironment()); + ASSERT_STATUS_OK(session.RegisterExecutionProvider( + onnxruntime::make_unique(supported_ops, compile_failure_ops))); + ASSERT_STATUS_OK(session.Load(ort_model_path)); + ASSERT_STATUS_OK(session.Initialize()); + + // 2 Conv nodes shoule be replaced with fused nodes + const auto& graph = session.GetGraph(); + int num_replaced_nodes = CountAndValidateAssignedNodes( + session.GetGraph(), {"Conv"}, session.GetSessionState().GetFuncMgr()); + + ASSERT_EQ(num_replaced_nodes, 2); + + // The Gemm node should still not have been replaced + int count_compile_failure_nodes = 0; + for (const auto& node : graph.Nodes()) { + if (compile_failure_ops.find(node.OpType()) != compile_failure_ops.end()) + count_compile_failure_nodes++; + } + ASSERT_EQ(count_compile_failure_nodes, 1); + + // Execute the session, since the last node is Gemm, and its input 0 is all 0s + // So the result should be the bias initializer of the Gemm node + ExecuteMnist(session, true /* enable_custom_ep */); + } +} + } // namespace test } // namespace onnxruntime