diff --git a/onnxruntime/core/framework/fuse_nodes_funcs.h b/onnxruntime/core/framework/fuse_nodes_funcs.h index d7cf2cc3cc..e0990c5dd0 100644 --- a/onnxruntime/core/framework/fuse_nodes_funcs.h +++ b/onnxruntime/core/framework/fuse_nodes_funcs.h @@ -7,7 +7,7 @@ namespace onnxruntime { class FuncManager { public: - FuncManager() : fused_funcs_(std::make_unique >()), lib_loader_(std::make_unique()) {} + FuncManager() : fused_funcs_(std::make_shared >()), lib_loader_(std::make_unique()) {} Status AddFuncInfo(const std::string& name, const std::string& dll_path); @@ -15,6 +15,10 @@ class FuncManager { Status GetFuncs(const std::string& name, ComputeFunc* compute, CreateFunctionStateFunc* create, DestroyFunctionStateFunc* release) const; + void SetFusedFuncs(const FuncManager& func_mgr) { + fused_funcs_ = func_mgr.fused_funcs_; + } + struct FuncInfo { std::string dso_path; ComputeFunc compute_func; @@ -27,7 +31,10 @@ class FuncManager { const std::string kCreateStateFuncSymbol = "Create_State_"; const std::string kReleaseStateFuncSymbol = "Release_State_"; - std::unique_ptr > fused_funcs_; + // note that subgraph session state shares fused_funcs with main graph + // because it's filled in by the time main graph is traversed, + // while subgraph session state is created later + std::shared_ptr > fused_funcs_; std::unique_ptr lib_loader_; ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(FuncManager); }; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 614cc822e6..4d7906d8b8 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -225,9 +225,9 @@ common::Status InferenceSession::Load(const std::basic_string& model_uri) { model_location_ = ToWideString(model_uri); auto loader = [this](std::shared_ptr& model) { #ifdef ENABLE_LANGUAGE_INTEROP_OPS - LoadInterOp(model_location_, interop_domains_, [&](const char* msg){LOGS(*session_logger_, WARNING) << msg;}); - for(const auto& domain: interop_domains_) { - AddCustomOpDomains({domain.get()}); + LoadInterOp(model_location_, interop_domains_, [&](const char* msg) { LOGS(*session_logger_, WARNING) << msg; }); + for (const auto& domain : interop_domains_) { + AddCustomOpDomains({domain.get()}); } #endif return onnxruntime::Model::Load(model_location_, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr); @@ -255,9 +255,9 @@ common::Status InferenceSession::Load(const std::wstring& model_uri) { common::Status InferenceSession::Load(const ModelProto& model_proto) { auto loader = [this, &model_proto](std::shared_ptr& model) { #ifdef ENABLE_LANGUAGE_INTEROP_OPS - LoadInterOp(model_proto, interop_domains_, [&](const char* msg){LOGS(*session_logger_, WARNING) << msg;}); - for(const auto& domain: interop_domains_) { - AddCustomOpDomains({domain.get()}); + LoadInterOp(model_proto, interop_domains_, [&](const char* msg) { LOGS(*session_logger_, WARNING) << msg; }); + for (const auto& domain : interop_domains_) { + AddCustomOpDomains({domain.get()}); } #endif return onnxruntime::Model::Load(model_proto, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr); @@ -269,9 +269,9 @@ common::Status InferenceSession::Load(const ModelProto& model_proto) { common::Status InferenceSession::Load(std::unique_ptr p_model_proto) { auto loader = [this, &p_model_proto](std::shared_ptr& model) { #ifdef ENABLE_LANGUAGE_INTEROP_OPS - LoadInterOp(*p_model_proto, interop_domains_, [&](const char* msg){LOGS(*session_logger_, WARNING) << msg;}); - for(const auto& domain: interop_domains_) { - AddCustomOpDomains({domain.get()}); + LoadInterOp(*p_model_proto, interop_domains_, [&](const char* msg) { LOGS(*session_logger_, WARNING) << msg; }); + for (const auto& domain : interop_domains_) { + AddCustomOpDomains({domain.get()}); } #endif return onnxruntime::Model::Load(std::move(p_model_proto), model, @@ -292,9 +292,9 @@ common::Status InferenceSession::Load(std::istream& model_istream) { "Failed to load model because protobuf parsing failed."); } #ifdef ENABLE_LANGUAGE_INTEROP_OPS - LoadInterOp(model_proto, interop_domains_, [&](const char* msg){LOGS(*session_logger_, WARNING) << msg;}); - for(const auto& domain: interop_domains_) { - AddCustomOpDomains({domain.get()}); + LoadInterOp(model_proto, interop_domains_, [&](const char* msg) { LOGS(*session_logger_, WARNING) << msg; }); + for (const auto& domain : interop_domains_) { + AddCustomOpDomains({domain.get()}); } #endif return onnxruntime::Model::Load(model_proto, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr); @@ -313,9 +313,9 @@ common::Status InferenceSession::Load(const void* model_data, int model_data_len "Failed to load model because protobuf parsing failed."); } #ifdef ENABLE_LANGUAGE_INTEROP_OPS - LoadInterOp(model_proto, interop_domains_, [&](const char* msg){LOGS(*session_logger_, WARNING) << msg;}); - for(const auto& domain: interop_domains_) { - AddCustomOpDomains({domain.get()}); + LoadInterOp(model_proto, interop_domains_, [&](const char* msg) { LOGS(*session_logger_, WARNING) << msg; }); + for (const auto& domain : interop_domains_) { + AddCustomOpDomains({domain.get()}); } #endif @@ -397,6 +397,9 @@ common::Status InferenceSession::CreateSubgraphSessionState(Graph& graph, Sessio // Pass threadpool to subgraph subgraph_session_state->SetThreadPool(session_state.GetThreadPool()); + // Pass fused function manager to subgraph + subgraph_session_state->GetMutableFuncMgr().SetFusedFuncs(session_state.GetFuncMgr()); + // recurse ORT_RETURN_IF_ERROR(CreateSubgraphSessionState(*subgraph, *subgraph_session_state));