From df68111b9809d9ec2db58633ebba5fcf0ff8f47b Mon Sep 17 00:00:00 2001 From: KeDengMS Date: Wed, 19 Jun 2019 14:37:21 -0700 Subject: [PATCH] Fix a bug where the fused func manager in subgraph session state is nullptr (#1251) Description: This fixes the issue of the fused func manager being nullptr when running a fused function inside a subgraph session state. Motivation and Context The bug happens when running fused functions created by IExecutionProvider::Compile inside a subgraph (e.g. Scan), which causes a crash. The problem is that FuncInfo is collected into the main graph's session state before the subgraph session state is created. The fix is to share FuncInfo between the main graph and the subgraph. --- onnxruntime/core/framework/fuse_nodes_funcs.h | 11 +++++-- onnxruntime/core/session/inference_session.cc | 33 ++++++++++--------- 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/onnxruntime/core/framework/fuse_nodes_funcs.h b/onnxruntime/core/framework/fuse_nodes_funcs.h index d7cf2cc3cc..e0990c5dd0 100644 --- a/onnxruntime/core/framework/fuse_nodes_funcs.h +++ b/onnxruntime/core/framework/fuse_nodes_funcs.h @@ -7,7 +7,7 @@ namespace onnxruntime { class FuncManager { public: - FuncManager() : fused_funcs_(std::make_unique >()), lib_loader_(std::make_unique()) {} + FuncManager() : fused_funcs_(std::make_shared >()), lib_loader_(std::make_unique()) {} Status AddFuncInfo(const std::string& name, const std::string& dll_path); @@ -15,6 +15,10 @@ class FuncManager { Status GetFuncs(const std::string& name, ComputeFunc* compute, CreateFunctionStateFunc* create, DestroyFunctionStateFunc* release) const; + void SetFusedFuncs(const FuncManager& func_mgr) { + fused_funcs_ = func_mgr.fused_funcs_; + } + struct FuncInfo { std::string dso_path; ComputeFunc compute_func; @@ -27,7 +31,10 @@ class FuncManager { const std::string kCreateStateFuncSymbol = "Create_State_"; const std::string kReleaseStateFuncSymbol = "Release_State_"; - std::unique_ptr > fused_funcs_; + // note that subgraph session state 
shares fused_funcs with main graph + // because it's filled in by the time main graph is traversed, + // while subgraph session state is created later + std::shared_ptr > fused_funcs_; std::unique_ptr lib_loader_; ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(FuncManager); }; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 614cc822e6..4d7906d8b8 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -225,9 +225,9 @@ common::Status InferenceSession::Load(const std::basic_string& model_uri) { model_location_ = ToWideString(model_uri); auto loader = [this](std::shared_ptr& model) { #ifdef ENABLE_LANGUAGE_INTEROP_OPS - LoadInterOp(model_location_, interop_domains_, [&](const char* msg){LOGS(*session_logger_, WARNING) << msg;}); - for(const auto& domain: interop_domains_) { - AddCustomOpDomains({domain.get()}); + LoadInterOp(model_location_, interop_domains_, [&](const char* msg) { LOGS(*session_logger_, WARNING) << msg; }); + for (const auto& domain : interop_domains_) { + AddCustomOpDomains({domain.get()}); } #endif return onnxruntime::Model::Load(model_location_, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr); @@ -255,9 +255,9 @@ common::Status InferenceSession::Load(const std::wstring& model_uri) { common::Status InferenceSession::Load(const ModelProto& model_proto) { auto loader = [this, &model_proto](std::shared_ptr& model) { #ifdef ENABLE_LANGUAGE_INTEROP_OPS - LoadInterOp(model_proto, interop_domains_, [&](const char* msg){LOGS(*session_logger_, WARNING) << msg;}); - for(const auto& domain: interop_domains_) { - AddCustomOpDomains({domain.get()}); + LoadInterOp(model_proto, interop_domains_, [&](const char* msg) { LOGS(*session_logger_, WARNING) << msg; }); + for (const auto& domain : interop_domains_) { + AddCustomOpDomains({domain.get()}); } #endif return onnxruntime::Model::Load(model_proto, model, HasLocalSchema() ? 
&custom_schema_registries_ : nullptr); @@ -269,9 +269,9 @@ common::Status InferenceSession::Load(const ModelProto& model_proto) { common::Status InferenceSession::Load(std::unique_ptr p_model_proto) { auto loader = [this, &p_model_proto](std::shared_ptr& model) { #ifdef ENABLE_LANGUAGE_INTEROP_OPS - LoadInterOp(*p_model_proto, interop_domains_, [&](const char* msg){LOGS(*session_logger_, WARNING) << msg;}); - for(const auto& domain: interop_domains_) { - AddCustomOpDomains({domain.get()}); + LoadInterOp(*p_model_proto, interop_domains_, [&](const char* msg) { LOGS(*session_logger_, WARNING) << msg; }); + for (const auto& domain : interop_domains_) { + AddCustomOpDomains({domain.get()}); } #endif return onnxruntime::Model::Load(std::move(p_model_proto), model, @@ -292,9 +292,9 @@ common::Status InferenceSession::Load(std::istream& model_istream) { "Failed to load model because protobuf parsing failed."); } #ifdef ENABLE_LANGUAGE_INTEROP_OPS - LoadInterOp(model_proto, interop_domains_, [&](const char* msg){LOGS(*session_logger_, WARNING) << msg;}); - for(const auto& domain: interop_domains_) { - AddCustomOpDomains({domain.get()}); + LoadInterOp(model_proto, interop_domains_, [&](const char* msg) { LOGS(*session_logger_, WARNING) << msg; }); + for (const auto& domain : interop_domains_) { + AddCustomOpDomains({domain.get()}); } #endif return onnxruntime::Model::Load(model_proto, model, HasLocalSchema() ? 
&custom_schema_registries_ : nullptr); @@ -313,9 +313,9 @@ common::Status InferenceSession::Load(const void* model_data, int model_data_len "Failed to load model because protobuf parsing failed."); } #ifdef ENABLE_LANGUAGE_INTEROP_OPS - LoadInterOp(model_proto, interop_domains_, [&](const char* msg){LOGS(*session_logger_, WARNING) << msg;}); - for(const auto& domain: interop_domains_) { - AddCustomOpDomains({domain.get()}); + LoadInterOp(model_proto, interop_domains_, [&](const char* msg) { LOGS(*session_logger_, WARNING) << msg; }); + for (const auto& domain : interop_domains_) { + AddCustomOpDomains({domain.get()}); } #endif @@ -397,6 +397,9 @@ common::Status InferenceSession::CreateSubgraphSessionState(Graph& graph, Sessio // Pass threadpool to subgraph subgraph_session_state->SetThreadPool(session_state.GetThreadPool()); + // Pass fused function manager to subgraph + subgraph_session_state->GetMutableFuncMgr().SetFusedFuncs(session_state.GetFuncMgr()); + // recurse ORT_RETURN_IF_ERROR(CreateSubgraphSessionState(*subgraph, *subgraph_session_state));