mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-05 04:17:53 +00:00
Let execution fall back to CPU EP if Compile of a partition on current EP fails (#6580)
* Let execution fall back to CPU EP if compile of a partition fails * Removed debugging logs * Addressed CR comments
This commit is contained in:
parent
f2ce3aae13
commit
68193e28de
4 changed files with 162 additions and 34 deletions
|
|
@ -921,11 +921,18 @@ class Graph {
|
|||
@returns Node with fused subgraph.
|
||||
@remarks As a new Graph instance for the fused nodes is not created, a GraphViewer can be constructed with the
|
||||
IndexedSubGraph information to provide a view of the subgraph. The original nodes are left in place
|
||||
while this is in use.
|
||||
while this is in use.
|
||||
Call FinalizeFuseSubGraph to remove them once the fused replacement node is fully created.
|
||||
*/
|
||||
Node& BeginFuseSubGraph(const IndexedSubGraph& sub_graph, const std::string& fused_node_name);
|
||||
|
||||
/**
|
||||
Used when BeginFuseSubGraph has been called but a later step failed, e.g. an EP's Compile failed on the sub_graph.
|
||||
We can call CancelFuseSubGraph to undo the changes of BeginFuseSubGraph
|
||||
@param fused_node The fused node and its function body to be removed from the graph
|
||||
*/
|
||||
void CancelFuseSubGraph(const Node& fused_node);
|
||||
|
||||
void FinalizeFuseSubGraph(const IndexedSubGraph& sub_graph, Node& fused_node);
|
||||
#endif
|
||||
|
||||
|
|
|
|||
|
|
@ -448,42 +448,43 @@ static Status PartitionOrtFormatModelImpl(Graph& graph, FuncManager& func_mgr,
|
|||
nodes_and_viewers.push_back(IExecutionProvider::FusedNodeAndGraph{fused_node, *viewers.back()});
|
||||
}
|
||||
|
||||
std::vector<NodeComputeInfo> node_compute_funcs;
|
||||
node_compute_funcs.reserve(nodes_and_viewers.size());
|
||||
|
||||
ORT_RETURN_IF_ERROR(current_ep.Compile(nodes_and_viewers, node_compute_funcs));
|
||||
|
||||
if (node_compute_funcs.size() != nodes_and_viewers.size()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, type, " did not return correct number of compiled functions");
|
||||
}
|
||||
|
||||
for (size_t j = 0, end = nodes_and_viewers.size(); j < end; j++) {
|
||||
// We will compile the fused nodes one by one, and fuse the subgraph if successful.
|
||||
// If a compilation fails we undo the fusion and leave the original nodes available for other EPs to take
|
||||
for (size_t j = 0, end = nodes_and_viewers.size(); j < end; ++j) {
|
||||
Node& node = nodes_and_viewers[j].fused_node;
|
||||
std::vector<NodeComputeInfo> single_node_compute_func;
|
||||
auto status = current_ep.Compile({nodes_and_viewers[j]}, single_node_compute_func);
|
||||
if (!status.IsOK()) {
|
||||
// There is compile error with the nodes_and_viewer[j], remove the fused_node and function from the graph
|
||||
LOGS_DEFAULT(ERROR) << "EP: " << current_ep.Type() << " has Compile error: " << status.ErrorMessage();
|
||||
graph.CancelFuseSubGraph(node);
|
||||
} else {
|
||||
ORT_RETURN_IF(single_node_compute_func.empty(), "single_node_compute_func should have 1 elements");
|
||||
ORT_RETURN_IF_ERROR(func_mgr.AddFuncInfo(node.Name(), std::move(single_node_compute_func[0])));
|
||||
|
||||
ORT_RETURN_IF_ERROR(func_mgr.AddFuncInfo(node.Name(), std::move(node_compute_funcs[j])));
|
||||
const auto& cur_capability = capabilities[j];
|
||||
const IndexedSubGraph& indexed_sub_graph = *cur_capability->sub_graph;
|
||||
const IndexedSubGraph::MetaDef& metadef = *indexed_sub_graph.GetMetaDef();
|
||||
|
||||
const auto& cur_capability = capabilities[j];
|
||||
const IndexedSubGraph& indexed_sub_graph = *cur_capability->sub_graph;
|
||||
const IndexedSubGraph::MetaDef& metadef = *indexed_sub_graph.GetMetaDef();
|
||||
KernelDefBuilder builder;
|
||||
BuildFusedKernelDef(builder, metadef, type);
|
||||
auto kernel_def = builder.Build();
|
||||
|
||||
KernelDefBuilder builder;
|
||||
BuildFusedKernelDef(builder, metadef, type);
|
||||
auto kernel_def = builder.Build();
|
||||
// save hash so SessionState can find the kernel. each kernel name should be unique
|
||||
if (compiled_kernel_hashes.insert({metadef.name, kernel_def->GetHash()}).second == false) {
|
||||
ORT_THROW("Existing entry in compiled kernel hashes for ", metadef.name,
|
||||
". Execution Provider must generate unique names across the entire model.");
|
||||
}
|
||||
|
||||
// save hash so SessionState can find the kernel. each kernel name should be unique
|
||||
if (compiled_kernel_hashes.insert({metadef.name, kernel_def->GetHash()}).second == false) {
|
||||
ORT_THROW("Existing entry in compiled kernel hashes for ", metadef.name,
|
||||
". Execution Provider must generate unique names across the entire model.");
|
||||
ORT_RETURN_IF_ERROR(fused_kernel_registry.Register(
|
||||
KernelCreateInfo(std::move(kernel_def), static_cast<KernelCreatePtrFn>(
|
||||
[](const OpKernelInfo& info) -> OpKernel* {
|
||||
return new FunctionKernel(info);
|
||||
}))));
|
||||
|
||||
// now that we're done compiling we can remove the original nodes from the Graph and wire in the new one
|
||||
graph.FinalizeFuseSubGraph(indexed_sub_graph, node);
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(fused_kernel_registry.Register(
|
||||
KernelCreateInfo(std::move(kernel_def), static_cast<KernelCreatePtrFn>(
|
||||
[](const OpKernelInfo& info) -> OpKernel* {
|
||||
return new FunctionKernel(info);
|
||||
}))));
|
||||
|
||||
// now that we're done compiling we can remove the original nodes from the Graph and wire in the new one
|
||||
graph.FinalizeFuseSubGraph(indexed_sub_graph, node);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
|||
|
|
@ -3392,6 +3392,31 @@ Node& Graph::BeginFuseSubGraph(const IndexedSubGraph& sub_graph, const std::stri
|
|||
return node;
|
||||
}
|
||||
|
||||
void Graph::CancelFuseSubGraph(const Node& fused_node) {
|
||||
auto node_idx = fused_node.Index();
|
||||
if (!GetNode(node_idx))
|
||||
return;
|
||||
|
||||
if (fused_node.NodeType() != Node::Type::Fused)
|
||||
return;
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
// Remove the function body from function container
|
||||
const auto* fused_node_func = fused_node.GetFunctionBody();
|
||||
auto it = std::find_if(
|
||||
function_container_.begin(), function_container_.end(),
|
||||
[fused_node_func](const std::unique_ptr<onnxruntime::Function>& func) {
|
||||
return func.get() == fused_node_func;
|
||||
});
|
||||
if (it != function_container_.end()) {
|
||||
function_container_.erase(it);
|
||||
}
|
||||
#endif
|
||||
|
||||
// Remove the fused_node
|
||||
RemoveNode(node_idx);
|
||||
}
|
||||
|
||||
void Graph::FinalizeFuseSubGraph(const IndexedSubGraph& sub_graph, Node& fused_node) {
|
||||
const auto* func_meta_def = sub_graph.GetMetaDef();
|
||||
ORT_ENFORCE(nullptr != func_meta_def);
|
||||
|
|
@ -3432,9 +3457,7 @@ void Graph::FinalizeFuseSubGraph(const IndexedSubGraph& sub_graph, Node& fused_n
|
|||
if (it != input_indexes.cend()) {
|
||||
AddEdge(producer_idx, new_node_idx, src_idx, it->second);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
} else {
|
||||
int dst_implicit_input_idx = dst_idx - (int)node->InputDefs().size();
|
||||
ORT_ENFORCE(dst_implicit_input_idx < (int)node->ImplicitInputDefs().size());
|
||||
auto it = input_indexes.find(node->ImplicitInputDefs()[dst_implicit_input_idx]->Name());
|
||||
|
|
|
|||
|
|
@ -254,5 +254,102 @@ TEST(InternalTestingEP, TestModelWithSubgraph) {
|
|||
feeds);
|
||||
}
|
||||
|
||||
// A custom InternalTestingEP extension
|
||||
// This is to testing execution fall back to CPU EP if Compile fails, for ORT format
|
||||
// This EP will take an additional compile_failure_ops
|
||||
// If in Compile() any nodes in the partition is also in compile_failure_ops
|
||||
// The Compile will fail
|
||||
class CompileFailureTestExecutionProvider : public InternalTestingExecutionProvider {
|
||||
public:
|
||||
CompileFailureTestExecutionProvider(const std::unordered_set<std::string>& supported_ops,
|
||||
const std::unordered_set<std::string>& compile_failure_ops);
|
||||
virtual ~CompileFailureTestExecutionProvider() = default;
|
||||
|
||||
Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes,
|
||||
std::vector<NodeComputeInfo>& node_compute_funcs) override;
|
||||
|
||||
private:
|
||||
std::unordered_set<std::string> compile_failure_ops_;
|
||||
};
|
||||
|
||||
CompileFailureTestExecutionProvider::CompileFailureTestExecutionProvider(
|
||||
const std::unordered_set<std::string>& supported_ops,
|
||||
const std::unordered_set<std::string>& compile_failure_ops)
|
||||
: InternalTestingExecutionProvider(supported_ops),
|
||||
compile_failure_ops_(compile_failure_ops) {}
|
||||
|
||||
Status CompileFailureTestExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes,
|
||||
std::vector<NodeComputeInfo>& node_compute_funcs) {
|
||||
for (const auto& fused_node_and_graph : fused_nodes) {
|
||||
// If any nodes in this partition is also in compile_failure_ops_, the Compile will fail
|
||||
const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph);
|
||||
for (const auto& node : graph_viewer.Nodes()) {
|
||||
if (compile_failure_ops_.find(node.OpType()) != compile_failure_ops_.end()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
|
||||
"CompileFailureTestExecutionProvider::Compile failed for node: ", node.Name());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return InternalTestingExecutionProvider::Compile(fused_nodes, node_compute_funcs);
|
||||
}
|
||||
|
||||
TEST(InternalTestingEP, TestOrtModelWithCompileFailure) {
  // The test model has 2 Conv nodes and 1 Gemm node, all disconnected, so an EP supporting both op types
  // takes 3 partitions. CompileFailureTestExecutionProvider fails Compile for the partition containing
  // the Gemm node. This test checks that model initialization still succeeds in that case and that the
  // Gemm node is not replaced by a fused node.
  const ORTCHAR_T* model_path = ORT_TSTR("testdata/mnist.ort");

  const std::unordered_set<std::string> ep_supported_ops{"Conv", "Gemm"};
  const std::unordered_set<std::string> ops_failing_compile{"Gemm"};

  // Baseline with the plain InternalTestingExecutionProvider:
  // all 3 partitions (2 Conv and 1 Gemm) are taken by the EP.
  {
    InferenceSessionWrapper session(SessionOptions(), GetEnvironment());
    ASSERT_STATUS_OK(session.RegisterExecutionProvider(
        onnxruntime::make_unique<InternalTestingExecutionProvider>(ep_supported_ops)));
    ASSERT_STATUS_OK(session.Load(model_path));
    ASSERT_STATUS_OK(session.Initialize());

    const int fused_node_count = CountAndValidateAssignedNodes(
        session.GetGraph(), ep_supported_ops, session.GetSessionState().GetFuncMgr());
    ASSERT_EQ(fused_node_count, 3);
  }

  // With CompileFailureTestExecutionProvider, which fails Compile on "Gemm":
  // only the 2 Conv partitions end up taken by the EP.
  {
    InferenceSessionWrapper session(SessionOptions(), GetEnvironment());
    ASSERT_STATUS_OK(session.RegisterExecutionProvider(
        onnxruntime::make_unique<CompileFailureTestExecutionProvider>(ep_supported_ops, ops_failing_compile)));
    ASSERT_STATUS_OK(session.Load(model_path));
    ASSERT_STATUS_OK(session.Initialize());

    const auto& graph = session.GetGraph();

    // The 2 Conv nodes should be replaced with fused nodes.
    const int fused_node_count = CountAndValidateAssignedNodes(
        graph, {"Conv"}, session.GetSessionState().GetFuncMgr());
    ASSERT_EQ(fused_node_count, 2);

    // The Gemm node must still be present in the graph, i.e. it was not replaced.
    int remaining_failure_op_nodes = 0;
    for (const auto& node : graph.Nodes()) {
      const bool is_failure_op = ops_failing_compile.count(node.OpType()) != 0;
      if (is_failure_op) {
        ++remaining_failure_op_nodes;
      }
    }
    ASSERT_EQ(remaining_failure_op_nodes, 1);

    // Execute the session. The last node is Gemm and its input 0 is all 0s,
    // so the result should be the bias initializer of the Gemm node.
    ExecuteMnist(session, true /* enable_custom_ep */);
  }
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue