Fix a bug where the fused func manager in subgraph session state is nullptr (#1251)

Description: This fixes the issue of the fused func manager being nullptr when running a fused function inside a subgraph session state.

Motivation and Context

The bug happens when running fused functions created by IExecutionProvider::Compile inside a subgraph (e.g. Scan), which causes a crash.
The problem is that FuncInfo is collected into the main graph's session state before the subgraph session state is created.
The fix is to share FuncInfo between the main graph and the subgraph.
This commit is contained in:
KeDengMS 2019-06-19 14:37:21 -07:00 committed by GitHub
parent 23838d9c2a
commit df68111b98
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 27 additions and 17 deletions

View file

@ -7,7 +7,7 @@ namespace onnxruntime {
class FuncManager {
public:
FuncManager() : fused_funcs_(std::make_unique<std::unordered_map<std::string, FuncInfo> >()), lib_loader_(std::make_unique<ExLibLoader>()) {}
FuncManager() : fused_funcs_(std::make_shared<std::unordered_map<std::string, FuncInfo> >()), lib_loader_(std::make_unique<ExLibLoader>()) {}
Status AddFuncInfo(const std::string& name, const std::string& dll_path);
@ -15,6 +15,10 @@ class FuncManager {
Status GetFuncs(const std::string& name, ComputeFunc* compute, CreateFunctionStateFunc* create, DestroyFunctionStateFunc* release) const;
// Points this manager's fused_funcs_ at the same map that `func_mgr` holds.
// With fused_funcs_ declared as a std::shared_ptr, the assignment shares the
// map between both managers rather than copying its entries.
void SetFusedFuncs(const FuncManager& func_mgr) {
fused_funcs_ = func_mgr.fused_funcs_;
}
struct FuncInfo {
std::string dso_path;
ComputeFunc compute_func;
@ -27,7 +31,10 @@ class FuncManager {
const std::string kCreateStateFuncSymbol = "Create_State_";
const std::string kReleaseStateFuncSymbol = "Release_State_";
std::unique_ptr<std::unordered_map<std::string, FuncInfo> > fused_funcs_;
// note that subgraph session state shares fused_funcs with main graph
// because it's filled in by the time main graph is traversed,
// while subgraph session state is created later
std::shared_ptr<std::unordered_map<std::string, FuncInfo> > fused_funcs_;
std::unique_ptr<ExLibLoader> lib_loader_;
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(FuncManager);
};

View file

@ -225,9 +225,9 @@ common::Status InferenceSession::Load(const std::basic_string<T>& model_uri) {
model_location_ = ToWideString(model_uri);
auto loader = [this](std::shared_ptr<onnxruntime::Model>& model) {
#ifdef ENABLE_LANGUAGE_INTEROP_OPS
LoadInterOp(model_location_, interop_domains_, [&](const char* msg){LOGS(*session_logger_, WARNING) << msg;});
for(const auto& domain: interop_domains_) {
AddCustomOpDomains({domain.get()});
LoadInterOp(model_location_, interop_domains_, [&](const char* msg) { LOGS(*session_logger_, WARNING) << msg; });
for (const auto& domain : interop_domains_) {
AddCustomOpDomains({domain.get()});
}
#endif
return onnxruntime::Model::Load(model_location_, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr);
@ -255,9 +255,9 @@ common::Status InferenceSession::Load(const std::wstring& model_uri) {
common::Status InferenceSession::Load(const ModelProto& model_proto) {
auto loader = [this, &model_proto](std::shared_ptr<onnxruntime::Model>& model) {
#ifdef ENABLE_LANGUAGE_INTEROP_OPS
LoadInterOp(model_proto, interop_domains_, [&](const char* msg){LOGS(*session_logger_, WARNING) << msg;});
for(const auto& domain: interop_domains_) {
AddCustomOpDomains({domain.get()});
LoadInterOp(model_proto, interop_domains_, [&](const char* msg) { LOGS(*session_logger_, WARNING) << msg; });
for (const auto& domain : interop_domains_) {
AddCustomOpDomains({domain.get()});
}
#endif
return onnxruntime::Model::Load(model_proto, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr);
@ -269,9 +269,9 @@ common::Status InferenceSession::Load(const ModelProto& model_proto) {
common::Status InferenceSession::Load(std::unique_ptr<ModelProto> p_model_proto) {
auto loader = [this, &p_model_proto](std::shared_ptr<onnxruntime::Model>& model) {
#ifdef ENABLE_LANGUAGE_INTEROP_OPS
LoadInterOp(*p_model_proto, interop_domains_, [&](const char* msg){LOGS(*session_logger_, WARNING) << msg;});
for(const auto& domain: interop_domains_) {
AddCustomOpDomains({domain.get()});
LoadInterOp(*p_model_proto, interop_domains_, [&](const char* msg) { LOGS(*session_logger_, WARNING) << msg; });
for (const auto& domain : interop_domains_) {
AddCustomOpDomains({domain.get()});
}
#endif
return onnxruntime::Model::Load(std::move(p_model_proto), model,
@ -292,9 +292,9 @@ common::Status InferenceSession::Load(std::istream& model_istream) {
"Failed to load model because protobuf parsing failed.");
}
#ifdef ENABLE_LANGUAGE_INTEROP_OPS
LoadInterOp(model_proto, interop_domains_, [&](const char* msg){LOGS(*session_logger_, WARNING) << msg;});
for(const auto& domain: interop_domains_) {
AddCustomOpDomains({domain.get()});
LoadInterOp(model_proto, interop_domains_, [&](const char* msg) { LOGS(*session_logger_, WARNING) << msg; });
for (const auto& domain : interop_domains_) {
AddCustomOpDomains({domain.get()});
}
#endif
return onnxruntime::Model::Load(model_proto, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr);
@ -313,9 +313,9 @@ common::Status InferenceSession::Load(const void* model_data, int model_data_len
"Failed to load model because protobuf parsing failed.");
}
#ifdef ENABLE_LANGUAGE_INTEROP_OPS
LoadInterOp(model_proto, interop_domains_, [&](const char* msg){LOGS(*session_logger_, WARNING) << msg;});
for(const auto& domain: interop_domains_) {
AddCustomOpDomains({domain.get()});
LoadInterOp(model_proto, interop_domains_, [&](const char* msg) { LOGS(*session_logger_, WARNING) << msg; });
for (const auto& domain : interop_domains_) {
AddCustomOpDomains({domain.get()});
}
#endif
@ -397,6 +397,9 @@ common::Status InferenceSession::CreateSubgraphSessionState(Graph& graph, Sessio
// Pass threadpool to subgraph
subgraph_session_state->SetThreadPool(session_state.GetThreadPool());
// Pass fused function manager to subgraph
subgraph_session_state->GetMutableFuncMgr().SetFusedFuncs(session_state.GetFuncMgr());
// recurse
ORT_RETURN_IF_ERROR(CreateSubgraphSessionState(*subgraph, *subgraph_session_state));