mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-17 21:10:43 +00:00
Fix a bug that fused func manager in subgraph session state is nullptr (#1251)
Description: This fixes nullptr of fused func manager issue when running fused function inside sub graph session state Motivation and Context The bug happens in running fused functions created IExecutionProvider::Compile inside sub graph, i.e. Scan, which causes crash. The problem is that FuncInfo is collected into main graph's session state, before sub graph session state is created. The fix is to share FuncInfo between main graph and sub graph.
This commit is contained in:
parent
23838d9c2a
commit
df68111b98
2 changed files with 27 additions and 17 deletions
|
|
@ -7,7 +7,7 @@ namespace onnxruntime {
|
|||
|
||||
class FuncManager {
|
||||
public:
|
||||
FuncManager() : fused_funcs_(std::make_unique<std::unordered_map<std::string, FuncInfo> >()), lib_loader_(std::make_unique<ExLibLoader>()) {}
|
||||
FuncManager() : fused_funcs_(std::make_shared<std::unordered_map<std::string, FuncInfo> >()), lib_loader_(std::make_unique<ExLibLoader>()) {}
|
||||
|
||||
Status AddFuncInfo(const std::string& name, const std::string& dll_path);
|
||||
|
||||
|
|
@ -15,6 +15,10 @@ class FuncManager {
|
|||
|
||||
Status GetFuncs(const std::string& name, ComputeFunc* compute, CreateFunctionStateFunc* create, DestroyFunctionStateFunc* release) const;
|
||||
|
||||
void SetFusedFuncs(const FuncManager& func_mgr) {
|
||||
fused_funcs_ = func_mgr.fused_funcs_;
|
||||
}
|
||||
|
||||
struct FuncInfo {
|
||||
std::string dso_path;
|
||||
ComputeFunc compute_func;
|
||||
|
|
@ -27,7 +31,10 @@ class FuncManager {
|
|||
const std::string kCreateStateFuncSymbol = "Create_State_";
|
||||
const std::string kReleaseStateFuncSymbol = "Release_State_";
|
||||
|
||||
std::unique_ptr<std::unordered_map<std::string, FuncInfo> > fused_funcs_;
|
||||
// note that subgraph session state shares fused_funcs with main graph
|
||||
// because it's filled in by the time main graph is traversed,
|
||||
// while subgraph session state is created later
|
||||
std::shared_ptr<std::unordered_map<std::string, FuncInfo> > fused_funcs_;
|
||||
std::unique_ptr<ExLibLoader> lib_loader_;
|
||||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(FuncManager);
|
||||
};
|
||||
|
|
|
|||
|
|
@ -225,9 +225,9 @@ common::Status InferenceSession::Load(const std::basic_string<T>& model_uri) {
|
|||
model_location_ = ToWideString(model_uri);
|
||||
auto loader = [this](std::shared_ptr<onnxruntime::Model>& model) {
|
||||
#ifdef ENABLE_LANGUAGE_INTEROP_OPS
|
||||
LoadInterOp(model_location_, interop_domains_, [&](const char* msg){LOGS(*session_logger_, WARNING) << msg;});
|
||||
for(const auto& domain: interop_domains_) {
|
||||
AddCustomOpDomains({domain.get()});
|
||||
LoadInterOp(model_location_, interop_domains_, [&](const char* msg) { LOGS(*session_logger_, WARNING) << msg; });
|
||||
for (const auto& domain : interop_domains_) {
|
||||
AddCustomOpDomains({domain.get()});
|
||||
}
|
||||
#endif
|
||||
return onnxruntime::Model::Load(model_location_, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr);
|
||||
|
|
@ -255,9 +255,9 @@ common::Status InferenceSession::Load(const std::wstring& model_uri) {
|
|||
common::Status InferenceSession::Load(const ModelProto& model_proto) {
|
||||
auto loader = [this, &model_proto](std::shared_ptr<onnxruntime::Model>& model) {
|
||||
#ifdef ENABLE_LANGUAGE_INTEROP_OPS
|
||||
LoadInterOp(model_proto, interop_domains_, [&](const char* msg){LOGS(*session_logger_, WARNING) << msg;});
|
||||
for(const auto& domain: interop_domains_) {
|
||||
AddCustomOpDomains({domain.get()});
|
||||
LoadInterOp(model_proto, interop_domains_, [&](const char* msg) { LOGS(*session_logger_, WARNING) << msg; });
|
||||
for (const auto& domain : interop_domains_) {
|
||||
AddCustomOpDomains({domain.get()});
|
||||
}
|
||||
#endif
|
||||
return onnxruntime::Model::Load(model_proto, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr);
|
||||
|
|
@ -269,9 +269,9 @@ common::Status InferenceSession::Load(const ModelProto& model_proto) {
|
|||
common::Status InferenceSession::Load(std::unique_ptr<ModelProto> p_model_proto) {
|
||||
auto loader = [this, &p_model_proto](std::shared_ptr<onnxruntime::Model>& model) {
|
||||
#ifdef ENABLE_LANGUAGE_INTEROP_OPS
|
||||
LoadInterOp(*p_model_proto, interop_domains_, [&](const char* msg){LOGS(*session_logger_, WARNING) << msg;});
|
||||
for(const auto& domain: interop_domains_) {
|
||||
AddCustomOpDomains({domain.get()});
|
||||
LoadInterOp(*p_model_proto, interop_domains_, [&](const char* msg) { LOGS(*session_logger_, WARNING) << msg; });
|
||||
for (const auto& domain : interop_domains_) {
|
||||
AddCustomOpDomains({domain.get()});
|
||||
}
|
||||
#endif
|
||||
return onnxruntime::Model::Load(std::move(p_model_proto), model,
|
||||
|
|
@ -292,9 +292,9 @@ common::Status InferenceSession::Load(std::istream& model_istream) {
|
|||
"Failed to load model because protobuf parsing failed.");
|
||||
}
|
||||
#ifdef ENABLE_LANGUAGE_INTEROP_OPS
|
||||
LoadInterOp(model_proto, interop_domains_, [&](const char* msg){LOGS(*session_logger_, WARNING) << msg;});
|
||||
for(const auto& domain: interop_domains_) {
|
||||
AddCustomOpDomains({domain.get()});
|
||||
LoadInterOp(model_proto, interop_domains_, [&](const char* msg) { LOGS(*session_logger_, WARNING) << msg; });
|
||||
for (const auto& domain : interop_domains_) {
|
||||
AddCustomOpDomains({domain.get()});
|
||||
}
|
||||
#endif
|
||||
return onnxruntime::Model::Load(model_proto, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr);
|
||||
|
|
@ -313,9 +313,9 @@ common::Status InferenceSession::Load(const void* model_data, int model_data_len
|
|||
"Failed to load model because protobuf parsing failed.");
|
||||
}
|
||||
#ifdef ENABLE_LANGUAGE_INTEROP_OPS
|
||||
LoadInterOp(model_proto, interop_domains_, [&](const char* msg){LOGS(*session_logger_, WARNING) << msg;});
|
||||
for(const auto& domain: interop_domains_) {
|
||||
AddCustomOpDomains({domain.get()});
|
||||
LoadInterOp(model_proto, interop_domains_, [&](const char* msg) { LOGS(*session_logger_, WARNING) << msg; });
|
||||
for (const auto& domain : interop_domains_) {
|
||||
AddCustomOpDomains({domain.get()});
|
||||
}
|
||||
#endif
|
||||
|
||||
|
|
@ -397,6 +397,9 @@ common::Status InferenceSession::CreateSubgraphSessionState(Graph& graph, Sessio
|
|||
// Pass threadpool to subgraph
|
||||
subgraph_session_state->SetThreadPool(session_state.GetThreadPool());
|
||||
|
||||
// Pass fused function manager to subgraph
|
||||
subgraph_session_state->GetMutableFuncMgr().SetFusedFuncs(session_state.GetFuncMgr());
|
||||
|
||||
// recurse
|
||||
ORT_RETURN_IF_ERROR(CreateSubgraphSessionState(*subgraph, *subgraph_session_state));
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue