From 39fb68b7612c0f1bf8a2dc888be77f60eee9786a Mon Sep 17 00:00:00 2001 From: Tao Qin <48697806+TaoQinMS@users.noreply.github.com> Date: Mon, 25 Mar 2019 14:09:33 -0700 Subject: [PATCH] Refactor InferenceSession class (#654) * Refactor InferenceSession interface * Make some member and func private * more protected members * more protected * reorder class members * reordering * reordering The InferenceSession was implemented in the pImpl idiom, which hides the actual implementation. There are requirements to expose the implementation to other new classes, so this change is to pave the way. The main changes are: abandon the pImpl idiom of InferenceSession --- onnxruntime/core/session/inference_session.cc | 1814 ++++++++--------- onnxruntime/core/session/inference_session.h | 144 +- 2 files changed, 959 insertions(+), 999 deletions(-) diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 935343ba83..45609d7d64 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -34,15 +34,12 @@ #include "core/framework/sequential_executor.h" #include "core/framework/op_kernel_context_internal.h" #include "core/framework/parallel_executor.h" -#include "core/framework/path_lib.h" -#include "core/framework/session_state.h" #include "core/framework/session_state_initializer.h" #include "core/framework/tensorprotoutils.h" #include "core/framework/tensor_type_and_shape.h" #include "core/framework/utils.h" #include "core/optimizer/transformer_memcpy.h" #include "core/optimizer/graph_transformer.h" -#include "core/optimizer/graph_transformer_mgr.h" #include "core/optimizer/insert_cast_transformer.h" #include "core/optimizer/transformer_memcpy.h" #include "core/providers/cpu/cpu_execution_provider.h" @@ -161,971 +158,629 @@ struct CustomOpKernel : OpKernel { void* op_kernel_; }; -class InferenceSession::Impl { - public: - Impl(const SessionOptions& session_options, logging::LoggingManager* logging_manager) - : session_options_{session_options}, - graph_transformation_mgr_{session_options_.max_num_graph_transformation_steps}, - logging_manager_{logging_manager}, - session_state_{execution_providers_}, - insert_cast_transformer_{"CastFloat16Transformer"} { - ORT_ENFORCE(Environment::IsInitialized(), - "Environment must be initialized before creating an InferenceSession."); +InferenceSession::InferenceSession(const SessionOptions& session_options, logging::LoggingManager* logging_manager) + : session_state_{execution_providers_}, + session_options_{session_options}, + graph_transformation_mgr_{session_options_.max_num_graph_transformation_steps}, + logging_manager_{logging_manager}, + insert_cast_transformer_{"CastFloat16Transformer"} { + ORT_ENFORCE(Environment::IsInitialized(), + "Environment must be initialized before creating an InferenceSession."); - InitLogger(logging_manager); + InitLogger(logging_manager); - // currently the threadpool is used by the parallel executor only and hence - // there is no point creating it when only sequential execution is enabled. - if (!session_options.enable_sequential_execution) { - int pool_size = session_options_.session_thread_pool_size == 0 - ? std::thread::hardware_concurrency() / 2 - : session_options_.session_thread_pool_size; + // currently the threadpool is used by the parallel executor only and hence + // there is no point creating it when only sequential execution is enabled. + if (!session_options.enable_sequential_execution) { + int pool_size = session_options_.session_thread_pool_size == 0 + ? std::thread::hardware_concurrency() / 2 + : session_options_.session_thread_pool_size; #ifdef USE_EIGEN_THREADPOOL - thread_pool_ = std::make_unique(pool_size); + thread_pool_ = std::make_unique(pool_size); #else - thread_pool_ = std::make_unique(pool_size); + thread_pool_ = std::make_unique(pool_size); #endif - } - - session_state_.SetThreadPool(thread_pool_.get()); - session_profiler_.Initialize(session_logger_); - session_state_.SetProfiler(session_profiler_); - if (session_options.enable_profiling) { - StartProfiling(session_options.profile_file_prefix); - } } - common::Status RegisterExecutionProvider(std::unique_ptr p_exec_provider) { - if (p_exec_provider == nullptr) { - return Status(common::ONNXRUNTIME, common::FAIL, "Received nullptr for exec provider"); - } - - std::string provider_type = p_exec_provider->Type(); - VLOGS(*session_logger_, 1) << "Adding execution provider of type: " << provider_type; - execution_providers_.Add(provider_type, std::move(p_exec_provider)); - - return Status::OK(); + session_state_.SetThreadPool(thread_pool_.get()); + session_profiler_.Initialize(session_logger_); + session_state_.SetProfiler(session_profiler_); + if (session_options.enable_profiling) { + StartProfiling(session_options.profile_file_prefix); } - - common::Status RegisterGraphTransformer(std::unique_ptr p_graph_transformer, - const std::vector& providers, - TransformerLevel level) { - if (p_graph_transformer == nullptr) { - return Status(common::ONNXRUNTIME, common::FAIL, "Received nullptr for graph transformer"); - } - return graph_transformation_mgr_.Register(std::move(p_graph_transformer), level, providers); - } - - common::Status AddCustomTransformerList(const std::vector& transformers_to_enable) { - std::copy(transformers_to_enable.begin(), transformers_to_enable.end(), - std::back_inserter(transformers_to_enable_)); - - return Status::OK(); - } - - common::Status AddCustomOpDomains(const std::vector& op_domains) { - auto custom_registry = std::make_shared(); - - for (auto& domain : op_domains) { - SchemasContainer schemas_container; - - schemas_container.domain = domain->domain_; - schemas_container.baseline_opset_version = 1; - schemas_container.opset_version = 1000; - - for (auto& op : domain->custom_ops_) { - ONNX_NAMESPACE::OpSchema schema(op->GetName(op), "unknown", 0); - - auto input_count = op->GetInputTypeCount(op); - for (size_t i = 0; i < input_count; i++) { - auto type = op->GetInputType(op, i); - - schema.Input(i, "A", "Description", - DataTypeImpl::ToString(onnxruntime::DataTypeImpl::TensorTypeFromONNXEnum(type))); - } - - auto output_count = op->GetOutputTypeCount(op); - for (size_t i = 0; i < output_count; i++) { - auto type = op->GetOutputType(op, i); - - schema.Output(i, "A", "Description", - DataTypeImpl::ToString(onnxruntime::DataTypeImpl::TensorTypeFromONNXEnum(type))); - } - - schema.SinceVersion(1); - schema.AllowUncheckedAttributes(); - - schemas_container.schemas_list.push_back(schema); - - KernelDefBuilder def_builder; - def_builder.SetName(op->GetName(op)) - .SetDomain(onnxruntime::kOnnxDomain) - .SinceVersion(1) - .Provider(onnxruntime::kCpuExecutionProvider); - KernelCreateFn kernel_create_fn = [&op](const OpKernelInfo& info) -> OpKernel* { return new CustomOpKernel(info, *op); }; - KernelCreateInfo create_info(def_builder.Build(), kernel_create_fn); - - custom_registry->RegisterCustomKernel(create_info); - } - - ORT_RETURN_IF_ERROR(custom_registry->RegisterOpSet(schemas_container.schemas_list, - schemas_container.domain, - schemas_container.baseline_opset_version, - schemas_container.opset_version)); - } - RegisterCustomRegistry(custom_registry); - return Status::OK(); - } - - common::Status RegisterCustomRegistry(std::shared_ptr& custom_registry) { - if (custom_registry == nullptr) { - return Status(common::ONNXRUNTIME, common::FAIL, "Received nullptr for custom registry"); - } - - // Insert session-level customized kernel registry. - kernel_registry_manager_.RegisterKernelRegistry(custom_registry); - // if (custom_schema_registries_.empty()) - // custom_schema_registries_.push_back(); - custom_schema_registries_.push_back(custom_registry); - return Status::OK(); - } - - common::Status Load(std::function&)> loader, const std::string& event_name) { - Status status = Status::OK(); - auto tp = session_profiler_.StartTime(); - try { - std::lock_guard l(session_mutex_); - if (is_model_loaded_) { // already loaded - LOGS(*session_logger_, ERROR) << "This session already contains a loaded model."; - return common::Status(common::ONNXRUNTIME, common::MODEL_LOADED, - "This session already contains a loaded model."); - } - - std::shared_ptr p_tmp_model; - status = loader(p_tmp_model); - ORT_RETURN_IF_ERROR(status); - - model_ = p_tmp_model; - - status = DoPostLoadProcessing(*model_); - ORT_RETURN_IF_ERROR(status); - - // all steps complete, mark the model as loaded. - is_model_loaded_ = true; - } catch (const std::exception& ex) { - status = Status(common::ONNXRUNTIME, common::FAIL, "Exception during loading: " + std::string(ex.what())); - } catch (...) { - LOGS(*session_logger_, ERROR) << "Unknown exception in Load()"; - status = Status(common::ONNXRUNTIME, common::RUNTIME_EXCEPTION, "Encountered unknown exception in Load()"); - } - - if (session_profiler_.FEnabled()) { - session_profiler_.EndTimeAndRecordEvent(profiling::SESSION_EVENT, event_name, tp); - } - - return status; - } - - template - common::Status Load(const T& model_uri) { - model_location_ = ToWideString(model_uri); - auto loader = [this](std::shared_ptr& model) { - return onnxruntime::Model::Load(model_location_, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr); - }; - - common::Status st = Load(loader, "model_loading_uri"); - if (!st.IsOK()) { - std::ostringstream oss; - oss << "Load model from " << ToMBString(model_uri) << " failed:" << st.ErrorMessage(); - return common::Status(st.Category(), st.Code(), oss.str()); - } - return Status::OK(); - } - - common::Status Load(const ModelProto& model_proto) { - auto loader = [this, &model_proto](std::shared_ptr& model) { - return onnxruntime::Model::Load(model_proto, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr); - }; - - return Load(loader, "model_loading_proto"); - } - - common::Status Load(std::unique_ptr p_model_proto) { - auto loader = [this, &p_model_proto](std::shared_ptr& model) { - return onnxruntime::Model::Load(std::move(p_model_proto), model, - HasLocalSchema() ? &custom_schema_registries_ : nullptr); - }; - - return Load(loader, "model_loading_proto"); - } - - common::Status Load(std::istream& model_istream) { - auto loader = [this, &model_istream](std::shared_ptr& model) { - ModelProto model_proto; - - google::protobuf::io::IstreamInputStream zero_copy_input(&model_istream); - const bool result = model_proto.ParseFromZeroCopyStream(&zero_copy_input) && model_istream.eof(); - if (!result) { - return Status(common::ONNXRUNTIME, common::INVALID_PROTOBUF, - "Failed to load model because protobuf parsing failed."); - } - - return onnxruntime::Model::Load(model_proto, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr); - }; - - return Load(loader, "model_loading_istream"); - } - - static common::Status TransformGraph(onnxruntime::Graph& graph, - const onnxruntime::GraphTransformerManager& graph_transformer_mgr, - const ExecutionProviders& providers, - KernelRegistryManager& kernel_registry_manager, - const InsertCastTransformer& insert_cast_transformer, - SessionState& session_state) { - // The transformer order: - // 1. built-in graph rewriter - // 2. each execution provider's transformer - // 3. do node placement according to kernel definition - // 4. insert copy nodes - // 5. insert cast nodes. - - // first apply global(execution provider independent), level 1(default/system/basic) graph to graph optimizations - ORT_RETURN_IF_ERROR(graph_transformer_mgr.ApplyTransformers(graph, TransformerLevel::Level1)); - - // Do partitioning based on execution providers' capability. - GraphPartitioner partitioner(kernel_registry_manager, providers); - ORT_RETURN_IF_ERROR(partitioner.Partition(graph, session_state.ExportDll(), session_state.GetMutableFuncMgr())); - - // apply transformers except default transformers - // Default transformers are required for correctness and they are owned and run by inference session - for (int i = static_cast(TransformerLevel::Level1); i < static_cast(TransformerLevel::MaxTransformerLevel); i++) { - ORT_RETURN_IF_ERROR(graph_transformer_mgr.ApplyTransformers(graph, static_cast(i))); - } - - bool modified = false; - // Insert cast node/s. - ORT_RETURN_IF_ERROR(insert_cast_transformer.Apply(graph, modified)); - - // Now every node should be already assigned to an execution provider - for (auto& node : graph.Nodes()) { - if (node.GetExecutionProviderType().empty()) { - std::ostringstream oss; - oss << "Could not find an implementation for the node "; - if (!node.Name().empty()) oss << node.Name() << ":"; - oss << node.OpType(); - if (node.Op()) { - oss << "(" << node.Op()->since_version() << ")"; - } - return Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED, oss.str()); - } - } - - std::vector provider_types; - for (auto& provider_ptr : providers) { - provider_types.push_back(provider_ptr->Type()); - } - - // Insert copy node/s. - MemcpyTransformer copy_transformer{provider_types, kernel_registry_manager}; - ORT_RETURN_IF_ERROR(copy_transformer.Apply(graph, modified)); - - return common::Status::OK(); - } - - /// Create SessionState instance for each subgraph as we need that for the GraphPartitioner - /// This will be initialized by InitializeSubgraphSessions. - common::Status CreateSubgraphSessionState(Graph& graph, SessionState& session_state) { - for (auto& node : graph.Nodes()) { - for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) { - auto& name = entry.first; - Graph* subgraph = entry.second; - ORT_ENFORCE(subgraph, "Main Graph instance should have populated all subgraphs when being resolved."); - - auto subgraph_session_state = std::make_unique(execution_providers_); - subgraph_session_state->SetProfiler(session_profiler_); - subgraph_session_state->SetLogger(*session_logger_); - - // recurse - ORT_RETURN_IF_ERROR(CreateSubgraphSessionState(*subgraph, *subgraph_session_state)); - - // add the subgraph SessionState instance to the parent graph SessionState so it can be retrieved - // by Compute() via OpKernelContextInternal. - session_state.AddSubgraphSessionState(node.Index(), name, std::move(subgraph_session_state)); - } - } - - return Status::OK(); - } - - /// iterate nodes in graph looking for ones with graph attribute/s - /// @param graph The graph to iterate - /// @param session_state The SessionState instance for 'graph'. - /// @remarks We pass in graph and session_state so we can handled nested subgraphs in the future - common::Status InitializeSubgraphSessions(Graph& graph, SessionState& session_state) { - for (auto& node : graph.Nodes()) { - for (const auto& entry : node.GetAttributeNameToMutableSubgraphMap()) { - auto& name = entry.first; - Graph& subgraph = *entry.second; - - SessionState* subgraph_session_state = session_state.GetMutableSubgraphSessionState(node.Index(), name); - ORT_ENFORCE(subgraph_session_state, "CreateSubgraphSessionState should have created an entry earlier."); - - // setup everything required to execute the subgraph and save it in subgraph_session_state - SessionStateInitializer initializer{model_location_, subgraph, *subgraph_session_state, execution_providers_, - kernel_registry_manager_}; - - ORT_RETURN_IF_ERROR(initializer.CreatePlan(&node, node.ImplicitInputDefs(), - session_options_.enable_sequential_execution)); - - ORT_RETURN_IF_ERROR(initializer.InitializeAndSave(&node.ImplicitInputDefs())); - - // LOGS(*session_logger_, VERBOSE) << std::make_pair(subgraph_info.session_state->GetExecutionPlan(), - // &*subgraph_info.session_state); - - // recurse - ORT_RETURN_IF_ERROR(InitializeSubgraphSessions(subgraph, *subgraph_session_state)); - } - } - - return Status::OK(); - } - - common::Status Initialize() { - Status status = Status::OK(); - auto tp = session_profiler_.StartTime(); - - try { - LOGS(*session_logger_, INFO) << "Initializing session."; - std::lock_guard l(session_mutex_); - if (!is_model_loaded_) { - LOGS(*session_logger_, ERROR) << "Model was not loaded"; - return common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."); - } - - if (is_inited_) { // already initialized - LOGS(*session_logger_, INFO) << "Session has already been initialized."; - return common::Status::OK(); - } - - // Register default CPUExecutionProvider if user didn't provide it through the Register() calls - if (!execution_providers_.Get(onnxruntime::kCpuExecutionProvider)) { - LOGS(*session_logger_, INFO) << "Adding default CPU execution provider."; - CPUExecutionProviderInfo epi{session_options_.enable_cpu_mem_arena}; - ORT_RETURN_IF_ERROR(execution_providers_.Add(onnxruntime::kCpuExecutionProvider, - std::make_unique(epi))); - } - - // add predefined transformers - AddPredefinedTransformers(graph_transformation_mgr_, session_options_.graph_optimization_level, transformers_to_enable_); - - onnxruntime::Graph& graph = model_->MainGraph(); - - // Collect the kernel registries from execution provider instances; - // There are 2 kinds of kernel registries with priority from high to low as below, - // 1. Custom execution provider type specific kernel registries. - // 2. common execution provider type specific kernel registries. - // The 1st and 2nd ones are shared across sessions. - // The 1st ones should have already been registered via session-level API into KernelRegistryManager. - // - // Register 2nd registries into KernelRegistryManager. - ORT_RETURN_IF_ERROR(kernel_registry_manager_.RegisterKernels(execution_providers_)); - - SessionStateInitializer session_initializer{model_location_, graph, session_state_, execution_providers_, - kernel_registry_manager_}; - - // create SessionState for subgraphs as it's needed by the transformers - ORT_RETURN_IF_ERROR(CreateSubgraphSessionState(graph, session_state_)); - - // apply any transformations to the main graph and any subgraphs - ORT_RETURN_IF_ERROR(TransformGraph(graph, graph_transformation_mgr_, - execution_providers_, kernel_registry_manager_, - insert_cast_transformer_, - session_state_)); - - // now that all the transforms are done, call Resolve on the main graph. this will recurse into the subgraphs. - ORT_RETURN_IF_ERROR(graph.Resolve()); - - ORT_RETURN_IF_ERROR(session_initializer.CreatePlan(nullptr, {}, session_options_.enable_sequential_execution)); - ORT_RETURN_IF_ERROR(session_initializer.InitializeAndSave(nullptr)); - - // handle any subgraphs - ORT_RETURN_IF_ERROR(InitializeSubgraphSessions(graph, session_state_)); - - session_state_.CalculateNodeIndexInfo(); - - is_inited_ = true; - - LOGS(*session_logger_, INFO) << "Session successfully initialized."; - } catch (const NotImplementedException& ex) { - status = ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Exception during initialization: ", ex.what()); - LOGS(*session_logger_, ERROR) << status.ErrorMessage(); - } catch (const std::exception& ex) { - status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Exception during initialization: ", ex.what()); - LOGS(*session_logger_, ERROR) << status.ErrorMessage(); - } catch (...) { - status = ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "Encountered unknown exception in Initialize()"); - LOGS(*session_logger_, ERROR) << status.ErrorMessage(); - } - - if (session_profiler_.FEnabled()) { - session_profiler_.EndTimeAndRecordEvent(profiling::SESSION_EVENT, "session_initialization", tp); - } - return status; - } - - int GetCurrentNumRuns() const { - return current_num_runs_.load(); - } - - static common::Status CheckTypes(MLDataType actual, MLDataType expected) { - if (actual == expected) { - return Status::OK(); - } - auto actual_name = std::string(typeid(*actual).name()); - auto expected_name = std::string(typeid(*expected).name()); - return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, - "Unexpected input data type. Actual: (" + actual_name + ") , expected: (" + expected_name + ")"); - } - - common::Status ValidateInputs(const std::vector& feed_names, - const std::vector& feeds) { - const auto begin_names = feed_names.cbegin(); - const auto end_names = feed_names.cend(); - std::unordered_set required_feed_ids; - for (auto& arg : required_input_def_list_) { - auto& arg_name = arg->Name(); - if (arg_name.empty()) { - continue; - } - - auto feed_names_entry = std::find(begin_names, end_names, arg_name); - if (feed_names_entry == end_names) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Missing required input: ", arg_name); - } - - auto idx = feed_names_entry - begin_names; - required_feed_ids.insert(idx); - auto& input_ml_value = feeds.at(idx); - auto expected_type = utils::GetMLDataType(*arg); - - if (input_ml_value.IsTensor()) { - auto expected_element_type = expected_type->AsTensorType()->GetElementType(); - auto input_element_type = input_ml_value.Get().DataType(); - ORT_RETURN_IF_ERROR(CheckTypes(input_element_type, expected_element_type)); - } else { - auto input_type = input_ml_value.Type(); - ORT_RETURN_IF_ERROR(CheckTypes(input_type, expected_type)); - } - } - - if (feeds.size() > required_feed_ids.size()) { - // More feeds are offered. - // In the case of overriding some initializers (which are also taken as graph inputs). - for (size_t i = 0; i < feeds.size(); ++i) { - if (required_feed_ids.count(i) > 0) { - continue; - } - auto iter = input_def_map_.find(feed_names[i]); - if (input_def_map_.end() == iter) { - std::ostringstream ostr; - std::for_each(std::begin(model_input_names_), - std::end(model_input_names_), - [&ostr](const std::string& elem) { - ostr << elem << " "; - }); - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Invalid Feed Input Names:", feed_names[i], - ". Valid input names are: ", ostr.str()); - } - - auto& input_ml_value = feeds.at(i); - ORT_ENFORCE(input_ml_value.IsTensor()); - auto input_element_type = input_ml_value.Get().DataType(); - - auto expected_type = utils::GetMLDataType(*iter->second); - auto expected_element_type = expected_type->AsTensorType()->GetElementType(); - - ORT_RETURN_IF_ERROR(CheckTypes(input_element_type, expected_element_type)); - } - } - - return Status::OK(); - } - - common::Status ValidateOutputs(const std::vector& output_names, - const std::vector* p_fetches) { - if (!p_fetches) { - return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, - "Output vector pointer is NULL"); - } - - if (output_names.empty()) { - return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, - "At least one output should be requested."); - } - - if (!p_fetches->empty() && - (output_names.size() != p_fetches->size())) { - std::ostringstream ostr; - ostr << "Output vector incorrectly sized: output_names.size(): " << output_names.size() - << "p_fetches->size(): " << p_fetches->size(); - return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, ostr.str()); - } - - bool valid = true; - std::ostringstream invalid_names; - for (const auto& name : output_names) { - if (model_output_names_.find(name) == model_output_names_.end()) { - valid = false; - invalid_names << " " << name; - } - } - - if (!valid) { - std::ostringstream ostr; - std::for_each(std::begin(model_output_names_), - std::end(model_output_names_), - [&ostr](const std::string& elem) { - ostr << elem << " "; - }); - return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, - "Invalid Output Names:" + invalid_names.str() + - " Valid output names are: " + ostr.str()); - } - - // TODO add more validation here like checking shape of the allocated buffers - - return common::Status::OK(); - } - - Status Run(const RunOptions& run_options, - const std::vector& feed_names, - const std::vector& feeds, - const std::vector& output_names, - std::vector* p_fetches) { - auto tp = session_profiler_.StartTime(); - Status retval = Status::OK(); - - try { - { - std::lock_guard l(session_mutex_); - if (!is_inited_) { - LOGS(*session_logger_, ERROR) << "Session was not initialized"; - retval = Status(common::ONNXRUNTIME, common::FAIL, "Session not initialized."); - } - } - - ORT_RETURN_IF_ERROR(ValidateInputs(feed_names, feeds)); - - // if the output vector is non-empty, ensure that its the same size as the output_names - ORT_RETURN_IF_ERROR(ValidateOutputs(output_names, p_fetches)); - - FeedsFetchesInfo info(feed_names, output_names); - ORT_RETURN_IF_ERROR(info.SetMLValueIdxs(session_state_.GetMLValueNameIdxMap())); - FeedsFetchesManager feeds_fetches_manager{std::move(info)}; - - if (!run_options.run_tag.empty()) { - LOGS(*session_logger_, INFO) << "Running with tag: " << run_options.run_tag; - } - - ++current_num_runs_; - - // TODO should we add this exec to the list of executors? i guess its not needed now? - - // scope of owned_run_logger is just the call to Execute. - // If Execute ever becomes async we need a different approach - std::unique_ptr owned_run_logger; - auto run_logger = CreateLoggerForRun(run_options, owned_run_logger); - - // info all execution providers InferenceSession:Run started - // TODO: only call OnRunStart for all providers in-use - for (auto& xp : execution_providers_) { - ORT_CHECK_AND_SET_RETVAL(xp->OnRunStart()); - } - - // execute the graph - ORT_CHECK_AND_SET_RETVAL( - utils::ExecuteGraph(session_state_, feeds_fetches_manager, feeds, *p_fetches, {}, - session_options_.enable_sequential_execution, run_options.terminate, run_logger, - false)); - - } catch (const std::exception& e) { - retval = Status(common::ONNXRUNTIME, common::FAIL, e.what()); - } catch (...) { - retval = Status(common::ONNXRUNTIME, common::RUNTIME_EXCEPTION, "Encountered unknown exception in Run()"); - } - - // info all execution providers InferenceSession:Run ended - for (auto& xp : execution_providers_) { - ORT_CHECK_AND_SET_RETVAL(xp->OnRunEnd()); - } - - --current_num_runs_; - if (session_profiler_.FEnabled()) { - session_profiler_.EndTimeAndRecordEvent(profiling::SESSION_EVENT, "model_run", tp); - } - - return retval; - } - - std::pair GetModelMetadata() const { - { - std::lock_guard l(session_mutex_); - if (!is_model_loaded_) { - LOGS(*session_logger_, ERROR) << "Model was not loaded"; - return std::make_pair(common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."), - nullptr); - } - } - - return std::make_pair(common::Status::OK(), &model_metadata_); - } - - std::pair GetModelInputs() const { - { - std::lock_guard l(session_mutex_); - if (!is_model_loaded_) { - LOGS(*session_logger_, ERROR) << "Model was not loaded"; - return std::make_pair(common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."), - nullptr); - } - } - - return std::make_pair(common::Status::OK(), &required_input_def_list_); - } - - std::pair GetModelOutputs() const { - { - std::lock_guard l(session_mutex_); - if (!is_model_loaded_) { - LOGS(*session_logger_, ERROR) << "Model was not loaded"; - return std::make_pair(common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."), - nullptr); - } - } - - return std::make_pair(common::Status::OK(), &output_def_list_); - } - - common::Status NewIOBinding(std::unique_ptr* io_binding) { - { - std::lock_guard l(session_mutex_); - if (!is_inited_) { - LOGS(*session_logger_, ERROR) << "Session was not initialized"; - return common::Status(common::ONNXRUNTIME, common::FAIL, "Session not initialized."); - } - } - - // private constructor, can't use make_unique - *io_binding = std::unique_ptr(new IOBinding(session_state_)); - return Status::OK(); - } - - common::Status Run(const RunOptions& run_options, IOBinding& io_binding) { - // TODO should Run() call io_binding.SynchronizeInputs() or should it let the callers do it? - // io_binding.SynchronizeInputs(); - return Run(run_options, io_binding.feed_names_, io_binding.feeds_, io_binding.output_names_, &io_binding.outputs_); - } - - common::Status Run(IOBinding& io_binding) { - RunOptions run_options; - return Run(run_options, io_binding); - } - - template - void StartProfiling(const std::basic_string& file_prefix) { - std::basic_ostringstream ss; - ss << file_prefix << "_" << GetCurrentTimeString() << ".json"; - session_profiler_.StartProfiling(ss.str()); - } - - void StartProfiling(const logging::Logger* logger_ptr) { - session_profiler_.StartProfiling(logger_ptr); - } - - std::string EndProfiling() { - if (is_model_loaded_) { - return session_profiler_.EndProfiling(); - } - LOGS(*session_logger_, ERROR) << "Could not write a profile because no model was loaded."; - return std::string(); - } - - private: - bool HasLocalSchema() const { - return !custom_schema_registries_.empty(); - } - - // assumes model has already been loaded before - common::Status DoPostLoadProcessing(onnxruntime::Model& model) { - // TODO add other post load processing here - common::Status status = SaveModelMetadata(model); - return status; - } - - common::Status SaveModelMetadata(const onnxruntime::Model& model) { - VLOGS(*session_logger_, 1) << "Saving model metadata"; - const onnxruntime::Graph& graph = model.MainGraph(); - - // save model metadata - model_metadata_.producer_name = model.ProducerName(); - model_metadata_.description = model.DocString(); - model_metadata_.domain = model.Domain(); - model_metadata_.version = model.ModelVersion(); - model_metadata_.custom_metadata_map = model.MetaData(); - model_metadata_.graph_name = graph.Name(); - - // save required inputs - const auto& required_inputs = graph.GetInputs(); // inputs excluding initializers - required_input_def_list_.reserve(required_inputs.size()); - required_model_input_names_.reserve(required_inputs.size()); - for (const auto& elem : required_inputs) { - required_input_def_list_.push_back(elem); - required_model_input_names_.insert(elem->Name()); - } - - // save all valid inputs - auto& all_inputs = graph.GetInputsIncludingInitializers(); - input_def_map_.reserve(all_inputs.size()); - model_input_names_.reserve(all_inputs.size()); - for (auto elem : all_inputs) { - input_def_map_.insert({elem->Name(), elem}); - model_input_names_.insert(elem->Name()); - } - - // save outputs - const auto& outputs = graph.GetOutputs(); - output_def_list_.reserve(outputs.size()); - model_output_names_.reserve(outputs.size()); - for (const auto& elem : outputs) { - output_def_list_.push_back(elem); - model_output_names_.insert(elem->Name()); - } - - VLOGS(*session_logger_, 1) << "Done saving model metadata"; - return common::Status::OK(); - } - - // Create a Logger for a single execution if possible. Otherwise use the default logger. - // If a new logger is created, it will also be stored in new_run_logger, - // which must remain valid for the duration of the execution. - // If the default logger is used, new_run_logger will remain empty. - // The returned value should be used in the execution. - const logging::Logger& CreateLoggerForRun(const RunOptions& run_options, - std::unique_ptr& new_run_logger) { - const logging::Logger* run_logger; - - // create a per-run logger if we can - if (logging_manager_ != nullptr) { - std::string run_log_id{session_options_.session_logid}; - - if (!session_options_.session_logid.empty() && !run_options.run_tag.empty()) { - run_log_id += ":"; - } - - run_log_id += run_options.run_tag; - - if (run_options.run_log_verbosity_level > 0) { - new_run_logger = logging_manager_->CreateLogger(run_log_id, - logging::Severity::kVERBOSE, - false, - run_options.run_log_verbosity_level); - } else { - new_run_logger = logging_manager_->CreateLogger(run_log_id); - } - - run_logger = new_run_logger.get(); - VLOGS(*run_logger, 1) << "Created logger for run with id of " << run_log_id; - } else { - // fallback to using default logger. this does NOT have any session or run specific id/tag in it - run_logger = session_logger_; - VLOGS(*run_logger, 1) << "Using default logger for run " << run_options.run_tag; - } - - return *run_logger; - } - - void InitLogger(logging::LoggingManager* logging_manager) { - // create logger for session, using provided logging manager if possible - if (logging_manager != nullptr) { - std::string session_logid = !session_options_.session_logid.empty() - ? session_options_.session_logid - : "InferenceSession"; // there's probably a better default... - - if (session_options_.session_log_verbosity_level > 0) { - owned_session_logger_ = logging_manager->CreateLogger(session_logid, - logging::Severity::kVERBOSE, - false, - session_options_.session_log_verbosity_level); - } else { - owned_session_logger_ = logging_manager->CreateLogger(session_logid); - } - session_logger_ = owned_session_logger_.get(); - } else { - session_logger_ = &logging::LoggingManager::DefaultLogger(); - } - - session_state_.SetLogger(*session_logger_); - } - - // Registers all the predefined transformers with transformer manager - void AddPredefinedTransformers(GraphTransformerManager& transformer_manager, - TransformerLevel graph_optimization_level, - const std::vector& custom_list) { - auto add_transformers = [&](TransformerLevel level, std::vector&& providers, std::string t_name) { - // Generate and register rewrite rules for level - auto rewrite_rules_to_register = - transformer_utils::GenerateRewriteRules(level, &custom_list); - if (!rewrite_rules_to_register.empty()) { - std::unique_ptr graph_rewrite_rules = - std::make_unique(t_name + "_RuleBasedTransformer", - "Apply rewrite rules for " + t_name); - for (auto& entry : rewrite_rules_to_register) { - graph_rewrite_rules->Register(std::move(entry)); - } - transformer_manager.Register(std::move(graph_rewrite_rules), level, - std::move(providers)); - } - - // Generate and register transformers for level - auto transformers_to_register = transformer_utils::GenerateTransformers(level, &custom_list); - for (auto& entry : transformers_to_register) { - transformer_manager.Register(std::move(entry.first), level, std::move(entry.second)); - } - }; - - if ((graph_optimization_level >= TransformerLevel::Level1) || !custom_list.empty()) { - add_transformers(TransformerLevel::Level1, {}, "Level1"); - } - - if ((graph_optimization_level >= TransformerLevel::Level2) || !custom_list.empty()) { - add_transformers(TransformerLevel::Level2, {onnxruntime::kCpuExecutionProvider}, "Level2"); - } - } - - common::Status WaitForNotification(Notification* p_executor_done, int64_t timeout_in_ms) { - if (timeout_in_ms > 0) { - ORT_NOT_IMPLEMENTED(__FUNCTION__, "timeout_in_ms >0 is not supported"); // TODO - } - p_executor_done->WaitForNotification(); - - return Status::OK(); - } - - const SessionOptions session_options_; - - onnxruntime::GraphTransformerManager graph_transformation_mgr_; - - // List of transformers to run. When this list is not empty only the transformers in this list - // will be run regardless of the level set. - // .i.e This list overrides both SessionOptions.graph_optimization_level and predefined transformers. - std::vector transformers_to_enable_; - - /// Logging manager if provided. - logging::LoggingManager* logging_manager_; - - /// Logger for this session. WARNING: Will contain nullptr if logging_manager_ is nullptr. - std::unique_ptr owned_session_logger_; - - /// convenience pointer to logger. should always be the same as session_state_.Logger(); - const logging::Logger* session_logger_; - - // Profiler for this session. - profiling::Profiler session_profiler_; - - ExecutionProviders execution_providers_; - - KernelRegistryManager kernel_registry_manager_; - std::list> custom_schema_registries_; - - // The model served by this inference session instance. - // Currently this has to be a shared ptr because the Model::Load method - // returns a shared_ptr only. Ideally factory functions should always return - // unique_ptr for maximum flexibility. Client can always upgrade it to shared_ptr - // if they need. - std::shared_ptr model_; - - // A set of executors that can run in parallel. - std::vector> executors_; // TODO do we need this vector? - - // Immutable state for each op in the model. Shared by all executors. - SessionState session_state_; - - ModelMetadata model_metadata_; - InputDefList required_input_def_list_; - std::unordered_map input_def_map_; - OutputDefList output_def_list_; - - // names of model inputs and outputs used for quick validation. - std::unordered_set required_model_input_names_; - std::unordered_set model_input_names_; - std::unordered_set model_output_names_; - - // Environment for this session - // not used now; we'll need it when we introduce threadpool - // statically allocated pointer, no need to manage its lifetime. - //Env* env_; - - // Threadpool for this session - //thread::ThreadPool thread_pool_; // not used for now; will add it later when implementing RunAsync -#ifdef USE_EIGEN_THREADPOOL - std::unique_ptr thread_pool_; -#else - std::unique_ptr thread_pool_; -#endif - - // Number of concurrently running executors - std::atomic current_num_runs_; - - mutable onnxruntime::OrtMutex session_mutex_; // to ensure only one thread can invoke Load/Initialize - bool is_model_loaded_ = false; // GUARDED_BY(session_mutex_) - bool is_inited_ = false; // GUARDED_BY(session_mutex_) - - InsertCastTransformer insert_cast_transformer_; - // The file path of where the model was loaded. e.g. /tmp/test_squeezenet/model.onnx - std::basic_string model_location_; -}; // namespace onnxruntime - -// -// InferenceSession -// -InferenceSession::InferenceSession(const SessionOptions& session_options, - logging::LoggingManager* logging_manager) - : impl_(std::make_unique(session_options, logging_manager)) { } InferenceSession::~InferenceSession() = default; -common::Status InferenceSession::Load(const std::string& model_uri) { - return impl_->Load(model_uri); +common::Status InferenceSession::RegisterExecutionProvider(std::unique_ptr p_exec_provider) { + if (p_exec_provider == nullptr) { + return Status(common::ONNXRUNTIME, common::FAIL, "Received nullptr for exec provider"); + } + + std::string provider_type = p_exec_provider->Type(); + VLOGS(*session_logger_, 1) << "Adding execution provider of type: " << provider_type; + execution_providers_.Add(provider_type, std::move(p_exec_provider)); + + return Status::OK(); } + +common::Status InferenceSession::RegisterGraphTransformer(std::unique_ptr p_graph_transformer, + const std::vector& providers, + TransformerLevel level) { + if (p_graph_transformer == nullptr) { + return Status(common::ONNXRUNTIME, common::FAIL, "Received nullptr for graph transformer"); + } + return graph_transformation_mgr_.Register(std::move(p_graph_transformer), level, providers); +} + +common::Status InferenceSession::AddCustomTransformerList(const std::vector& transformers_to_enable) { + std::copy(transformers_to_enable.begin(), transformers_to_enable.end(), + std::back_inserter(transformers_to_enable_)); + + return Status::OK(); +} + +common::Status InferenceSession::AddCustomOpDomains(const std::vector& op_domains) { + auto custom_registry = std::make_shared(); + + for (auto& domain : op_domains) { + SchemasContainer schemas_container; + + schemas_container.domain = domain->domain_; + schemas_container.baseline_opset_version = 1; + schemas_container.opset_version = 1000; + + for (auto& op : domain->custom_ops_) { + ONNX_NAMESPACE::OpSchema schema(op->GetName(op), "unknown", 0); + + auto input_count = op->GetInputTypeCount(op); + for (size_t i = 0; i < input_count; i++) { + auto type = op->GetInputType(op, i); + + schema.Input(i, "A", "Description", + DataTypeImpl::ToString(onnxruntime::DataTypeImpl::TensorTypeFromONNXEnum(type))); + } + + auto output_count = op->GetOutputTypeCount(op); + for (size_t i = 0; i < output_count; i++) { + auto type = op->GetOutputType(op, i); + + schema.Output(i, "A", "Description", + DataTypeImpl::ToString(onnxruntime::DataTypeImpl::TensorTypeFromONNXEnum(type))); + } + + schema.SinceVersion(1); + schema.AllowUncheckedAttributes(); + + schemas_container.schemas_list.push_back(schema); + + KernelDefBuilder def_builder; + def_builder.SetName(op->GetName(op)) + .SetDomain(onnxruntime::kOnnxDomain) + .SinceVersion(1) + .Provider(onnxruntime::kCpuExecutionProvider); + KernelCreateFn kernel_create_fn = [&op](const OpKernelInfo& info) -> OpKernel* { return new CustomOpKernel(info, *op); }; + KernelCreateInfo create_info(def_builder.Build(), kernel_create_fn); + + custom_registry->RegisterCustomKernel(create_info); + } + + ORT_RETURN_IF_ERROR(custom_registry->RegisterOpSet(schemas_container.schemas_list, + schemas_container.domain, + schemas_container.baseline_opset_version, + schemas_container.opset_version)); + } + RegisterCustomRegistry(custom_registry); + return Status::OK(); +} + +common::Status InferenceSession::RegisterCustomRegistry(std::shared_ptr custom_registry) { + if (custom_registry == nullptr) { + return Status(common::ONNXRUNTIME, common::FAIL, "Received nullptr for custom registry"); + } + + // Insert session-level customized kernel registry. + kernel_registry_manager_.RegisterKernelRegistry(custom_registry); + // if (custom_schema_registries_.empty()) + // custom_schema_registries_.push_back(); + custom_schema_registries_.push_back(custom_registry); + return Status::OK(); +} + +common::Status InferenceSession::Load(std::function&)> loader, const std::string& event_name) { + Status status = Status::OK(); + auto tp = session_profiler_.StartTime(); + try { + std::lock_guard l(session_mutex_); + if (is_model_loaded_) { // already loaded + LOGS(*session_logger_, ERROR) << "This session already contains a loaded model."; + return common::Status(common::ONNXRUNTIME, common::MODEL_LOADED, + "This session already contains a loaded model."); + } + + std::shared_ptr p_tmp_model; + status = loader(p_tmp_model); + ORT_RETURN_IF_ERROR(status); + + model_ = p_tmp_model; + + status = DoPostLoadProcessing(*model_); + ORT_RETURN_IF_ERROR(status); + + // all steps complete, mark the model as loaded. + is_model_loaded_ = true; + } catch (const std::exception& ex) { + status = Status(common::ONNXRUNTIME, common::FAIL, "Exception during loading: " + std::string(ex.what())); + } catch (...) { + LOGS(*session_logger_, ERROR) << "Unknown exception in Load()"; + status = Status(common::ONNXRUNTIME, common::RUNTIME_EXCEPTION, "Encountered unknown exception in Load()"); + } + + if (session_profiler_.FEnabled()) { + session_profiler_.EndTimeAndRecordEvent(profiling::SESSION_EVENT, event_name, tp); + } + + return status; +} + +template +common::Status InferenceSession::Load(const std::basic_string& model_uri) { + model_location_ = ToWideString(model_uri); + auto loader = [this](std::shared_ptr& model) { + return onnxruntime::Model::Load(model_location_, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr); + }; + + common::Status st = Load(loader, "model_loading_uri"); + if (!st.IsOK()) { + std::ostringstream oss; + oss << "Load model from " << ToMBString(model_uri) << " failed:" << st.ErrorMessage(); + return common::Status(st.Category(), st.Code(), oss.str()); + } + return Status::OK(); +} + +common::Status InferenceSession::Load(const std::string& model_uri) { + return Load(model_uri); +} + #ifdef _WIN32 common::Status InferenceSession::Load(const std::wstring& model_uri) { - return impl_->Load(model_uri); + return Load(model_uri); } #endif + +common::Status InferenceSession::Load(const ModelProto& model_proto) { + auto loader = [this, &model_proto](std::shared_ptr& model) { + return onnxruntime::Model::Load(model_proto, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr); + }; + + return Load(loader, "model_loading_proto"); +} + +common::Status InferenceSession::Load(std::unique_ptr p_model_proto) { + auto loader = [this, &p_model_proto](std::shared_ptr& model) { + return onnxruntime::Model::Load(std::move(p_model_proto), model, + HasLocalSchema() ? &custom_schema_registries_ : nullptr); + }; + + return Load(loader, "model_loading_proto"); +} + common::Status InferenceSession::Load(std::istream& model_istream) { - return impl_->Load(model_istream); + auto loader = [this, &model_istream](std::shared_ptr& model) { + ModelProto model_proto; + + google::protobuf::io::IstreamInputStream zero_copy_input(&model_istream); + const bool result = model_proto.ParseFromZeroCopyStream(&zero_copy_input) && model_istream.eof(); + if (!result) { + return Status(common::ONNXRUNTIME, common::INVALID_PROTOBUF, + "Failed to load model because protobuf parsing failed."); + } + + return onnxruntime::Model::Load(model_proto, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr); + }; + + return Load(loader, "model_loading_istream"); +} + +common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, + const onnxruntime::GraphTransformerManager& graph_transformer_mgr, + const ExecutionProviders& providers, + KernelRegistryManager& kernel_registry_manager, + const InsertCastTransformer& insert_cast_transformer, + SessionState& session_state) { + // The transformer order: + // 1. built-in graph rewriter + // 2. each execution provider's transformer + // 3. do node placement according to kernel definition + // 4. insert copy nodes + // 5. insert cast nodes. + + // first apply global(execution provider independent), level 1(default/system/basic) graph to graph optimizations + ORT_RETURN_IF_ERROR(graph_transformer_mgr.ApplyTransformers(graph, TransformerLevel::Level1)); + + // Do partitioning based on execution providers' capability. + GraphPartitioner partitioner(kernel_registry_manager, providers); + ORT_RETURN_IF_ERROR(partitioner.Partition(graph, session_state.ExportDll(), session_state.GetMutableFuncMgr())); + + // apply transformers except default transformers + // Default transformers are required for correctness and they are owned and run by inference session + for (int i = static_cast(TransformerLevel::Level1); i < static_cast(TransformerLevel::MaxTransformerLevel); i++) { + ORT_RETURN_IF_ERROR(graph_transformer_mgr.ApplyTransformers(graph, static_cast(i))); + } + + bool modified = false; + // Insert cast node/s. + ORT_RETURN_IF_ERROR(insert_cast_transformer.Apply(graph, modified)); + + // Now every node should be already assigned to an execution provider + for (auto& node : graph.Nodes()) { + if (node.GetExecutionProviderType().empty()) { + std::ostringstream oss; + oss << "Could not find an implementation for the node "; + if (!node.Name().empty()) oss << node.Name() << ":"; + oss << node.OpType(); + if (node.Op()) { + oss << "(" << node.Op()->since_version() << ")"; + } + return Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED, oss.str()); + } + } + + std::vector provider_types; + for (auto& provider_ptr : providers) { + provider_types.push_back(provider_ptr->Type()); + } + + // Insert copy node/s. + MemcpyTransformer copy_transformer{provider_types, kernel_registry_manager}; + ORT_RETURN_IF_ERROR(copy_transformer.Apply(graph, modified)); + + return common::Status::OK(); +} + +/// Create SessionState instance for each subgraph as we need that for the GraphPartitioner +/// This will be initialized by InitializeSubgraphSessions. +common::Status InferenceSession::CreateSubgraphSessionState(Graph& graph, SessionState& session_state) { + for (auto& node : graph.Nodes()) { + for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) { + auto& name = entry.first; + Graph* subgraph = entry.second; + ORT_ENFORCE(subgraph, "Main Graph instance should have populated all subgraphs when being resolved."); + + auto subgraph_session_state = std::make_unique(execution_providers_); + subgraph_session_state->SetProfiler(session_profiler_); + subgraph_session_state->SetLogger(*session_logger_); + + // recurse + ORT_RETURN_IF_ERROR(CreateSubgraphSessionState(*subgraph, *subgraph_session_state)); + + // add the subgraph SessionState instance to the parent graph SessionState so it can be retrieved + // by Compute() via OpKernelContextInternal. + session_state.AddSubgraphSessionState(node.Index(), name, std::move(subgraph_session_state)); + } + } + + return Status::OK(); +} + +/// iterate nodes in graph looking for ones with graph attribute/s +/// @param graph The graph to iterate +/// @param session_state The SessionState instance for 'graph'. +/// @remarks We pass in graph and session_state so we can handled nested subgraphs in the future +common::Status InferenceSession::InitializeSubgraphSessions(Graph& graph, SessionState& session_state) { + for (auto& node : graph.Nodes()) { + for (const auto& entry : node.GetAttributeNameToMutableSubgraphMap()) { + auto& name = entry.first; + Graph& subgraph = *entry.second; + + SessionState* subgraph_session_state = session_state.GetMutableSubgraphSessionState(node.Index(), name); + ORT_ENFORCE(subgraph_session_state, "CreateSubgraphSessionState should have created an entry earlier."); + + // setup everything required to execute the subgraph and save it in subgraph_session_state + SessionStateInitializer initializer{model_location_, subgraph, *subgraph_session_state, execution_providers_, + kernel_registry_manager_}; + + ORT_RETURN_IF_ERROR(initializer.CreatePlan(&node, node.ImplicitInputDefs(), + session_options_.enable_sequential_execution)); + + ORT_RETURN_IF_ERROR(initializer.InitializeAndSave(&node.ImplicitInputDefs())); + + // LOGS(*session_logger_, VERBOSE) << std::make_pair(subgraph_info.session_state->GetExecutionPlan(), + // &*subgraph_info.session_state); + + // recurse + ORT_RETURN_IF_ERROR(InitializeSubgraphSessions(subgraph, *subgraph_session_state)); + } + } + + return Status::OK(); } common::Status InferenceSession::Initialize() { - return impl_->Initialize(); + Status status = Status::OK(); + auto tp = session_profiler_.StartTime(); + + try { + LOGS(*session_logger_, INFO) << "Initializing session."; + std::lock_guard l(session_mutex_); + if (!is_model_loaded_) { + LOGS(*session_logger_, ERROR) << "Model was not loaded"; + return common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."); + } + + if (is_inited_) { // already initialized + LOGS(*session_logger_, INFO) << "Session has already been initialized."; + return common::Status::OK(); + } + + // Register default CPUExecutionProvider if user didn't provide it through the Register() calls + if (!execution_providers_.Get(onnxruntime::kCpuExecutionProvider)) { + LOGS(*session_logger_, INFO) << "Adding default CPU execution provider."; + CPUExecutionProviderInfo epi{session_options_.enable_cpu_mem_arena}; + ORT_RETURN_IF_ERROR(execution_providers_.Add(onnxruntime::kCpuExecutionProvider, + std::make_unique(epi))); + } + + // add predefined transformers + AddPredefinedTransformers(graph_transformation_mgr_, session_options_.graph_optimization_level, transformers_to_enable_); + + onnxruntime::Graph& graph = model_->MainGraph(); + + // Collect the kernel registries from execution provider instances; + // There are 2 kinds of kernel registries with priority from high to low as below, + // 1. Custom execution provider type specific kernel registries. + // 2. common execution provider type specific kernel registries. + // The 1st and 2nd ones are shared across sessions. + // The 1st ones should have already been registered via session-level API into KernelRegistryManager. + // + // Register 2nd registries into KernelRegistryManager. + ORT_RETURN_IF_ERROR(kernel_registry_manager_.RegisterKernels(execution_providers_)); + + SessionStateInitializer session_initializer{model_location_, graph, session_state_, execution_providers_, + kernel_registry_manager_}; + + // create SessionState for subgraphs as it's needed by the transformers + ORT_RETURN_IF_ERROR(CreateSubgraphSessionState(graph, session_state_)); + + // apply any transformations to the main graph and any subgraphs + ORT_RETURN_IF_ERROR(TransformGraph(graph, graph_transformation_mgr_, + execution_providers_, kernel_registry_manager_, + insert_cast_transformer_, + session_state_)); + + // now that all the transforms are done, call Resolve on the main graph. this will recurse into the subgraphs. + ORT_RETURN_IF_ERROR(graph.Resolve()); + + ORT_RETURN_IF_ERROR(session_initializer.CreatePlan(nullptr, {}, session_options_.enable_sequential_execution)); + ORT_RETURN_IF_ERROR(session_initializer.InitializeAndSave(nullptr)); + + // handle any subgraphs + ORT_RETURN_IF_ERROR(InitializeSubgraphSessions(graph, session_state_)); + + session_state_.CalculateNodeIndexInfo(); + + is_inited_ = true; + + LOGS(*session_logger_, INFO) << "Session successfully initialized."; + } catch (const NotImplementedException& ex) { + status = ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Exception during initialization: ", ex.what()); + LOGS(*session_logger_, ERROR) << status.ErrorMessage(); + } catch (const std::exception& ex) { + status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Exception during initialization: ", ex.what()); + LOGS(*session_logger_, ERROR) << status.ErrorMessage(); + } catch (...) { + status = ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "Encountered unknown exception in Initialize()"); + LOGS(*session_logger_, ERROR) << status.ErrorMessage(); + } + + if (session_profiler_.FEnabled()) { + session_profiler_.EndTimeAndRecordEvent(profiling::SESSION_EVENT, "session_initialization", tp); + } + return status; } -common::Status InferenceSession::Run(const RunOptions& run_options, - const std::vector& feed_names, - const std::vector& feeds, - const std::vector& output_names, - std::vector* p_fetches) { - return impl_->Run(run_options, feed_names, feeds, output_names, p_fetches); +int InferenceSession::GetCurrentNumRuns() const { + return current_num_runs_.load(); +} + +common::Status InferenceSession::CheckTypes(MLDataType actual, MLDataType expected) { + if (actual == expected) { + return Status::OK(); + } + auto actual_name = std::string(typeid(*actual).name()); + auto expected_name = std::string(typeid(*expected).name()); + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, + "Unexpected input data type. Actual: (" + actual_name + ") , expected: (" + expected_name + ")"); +} + +common::Status InferenceSession::ValidateInputs(const std::vector& feed_names, + const std::vector& feeds) { + const auto begin_names = feed_names.cbegin(); + const auto end_names = feed_names.cend(); + std::unordered_set required_feed_ids; + for (auto& arg : required_input_def_list_) { + auto& arg_name = arg->Name(); + if (arg_name.empty()) { + continue; + } + + auto feed_names_entry = std::find(begin_names, end_names, arg_name); + if (feed_names_entry == end_names) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Missing required input: ", arg_name); + } + + auto idx = feed_names_entry - begin_names; + required_feed_ids.insert(idx); + auto& input_ml_value = feeds.at(idx); + auto expected_type = utils::GetMLDataType(*arg); + + if (input_ml_value.IsTensor()) { + auto expected_element_type = expected_type->AsTensorType()->GetElementType(); + auto input_element_type = input_ml_value.Get().DataType(); + ORT_RETURN_IF_ERROR(CheckTypes(input_element_type, expected_element_type)); + } else { + auto input_type = input_ml_value.Type(); + ORT_RETURN_IF_ERROR(CheckTypes(input_type, expected_type)); + } + } + + if (feeds.size() > required_feed_ids.size()) { + // More feeds are offered. + // In the case of overriding some initializers (which are also taken as graph inputs). + for (size_t i = 0; i < feeds.size(); ++i) { + if (required_feed_ids.count(i) > 0) { + continue; + } + auto iter = input_def_map_.find(feed_names[i]); + if (input_def_map_.end() == iter) { + std::ostringstream ostr; + std::for_each(std::begin(model_input_names_), + std::end(model_input_names_), + [&ostr](const std::string& elem) { + ostr << elem << " "; + }); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Invalid Feed Input Names:", feed_names[i], + ". Valid input names are: ", ostr.str()); + } + + auto& input_ml_value = feeds.at(i); + ORT_ENFORCE(input_ml_value.IsTensor()); + auto input_element_type = input_ml_value.Get().DataType(); + + auto expected_type = utils::GetMLDataType(*iter->second); + auto expected_element_type = expected_type->AsTensorType()->GetElementType(); + + ORT_RETURN_IF_ERROR(CheckTypes(input_element_type, expected_element_type)); + } + } + + return Status::OK(); +} + +common::Status InferenceSession::ValidateOutputs(const std::vector& output_names, + const std::vector* p_fetches) { + if (!p_fetches) { + return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, + "Output vector pointer is NULL"); + } + + if (output_names.empty()) { + return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, + "At least one output should be requested."); + } + + if (!p_fetches->empty() && + (output_names.size() != p_fetches->size())) { + std::ostringstream ostr; + ostr << "Output vector incorrectly sized: output_names.size(): " << output_names.size() + << "p_fetches->size(): " << p_fetches->size(); + return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, ostr.str()); + } + + bool valid = true; + std::ostringstream invalid_names; + for (const auto& name : output_names) { + if (model_output_names_.find(name) == model_output_names_.end()) { + valid = false; + invalid_names << " " << name; + } + } + + if (!valid) { + std::ostringstream ostr; + std::for_each(std::begin(model_output_names_), + std::end(model_output_names_), + [&ostr](const std::string& elem) { + ostr << elem << " "; + }); + return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, + "Invalid Output Names:" + invalid_names.str() + + " Valid output names are: " + ostr.str()); + } + + // TODO add more validation here like checking shape of the allocated buffers + + return common::Status::OK(); +} + +Status InferenceSession::Run(const RunOptions& run_options, + const std::vector& feed_names, + const std::vector& feeds, + const std::vector& output_names, + std::vector* p_fetches) { + auto tp = session_profiler_.StartTime(); + Status retval = Status::OK(); + + try { + { + std::lock_guard l(session_mutex_); + if (!is_inited_) { + LOGS(*session_logger_, ERROR) << "Session was not initialized"; + retval = Status(common::ONNXRUNTIME, common::FAIL, "Session not initialized."); + } + } + + ORT_RETURN_IF_ERROR(ValidateInputs(feed_names, feeds)); + + // if the output vector is non-empty, ensure that its the same size as the output_names + ORT_RETURN_IF_ERROR(ValidateOutputs(output_names, p_fetches)); + + FeedsFetchesInfo info(feed_names, output_names); + ORT_RETURN_IF_ERROR(info.SetMLValueIdxs(session_state_.GetMLValueNameIdxMap())); + FeedsFetchesManager feeds_fetches_manager{std::move(info)}; + + if (!run_options.run_tag.empty()) { + LOGS(*session_logger_, INFO) << "Running with tag: " << run_options.run_tag; + } + + ++current_num_runs_; + + // TODO should we add this exec to the list of executors? i guess its not needed now? + + // scope of owned_run_logger is just the call to Execute. + // If Execute ever becomes async we need a different approach + std::unique_ptr owned_run_logger; + auto run_logger = CreateLoggerForRun(run_options, owned_run_logger); + + // info all execution providers InferenceSession:Run started + // TODO: only call OnRunStart for all providers in-use + for (auto& xp : execution_providers_) { + ORT_CHECK_AND_SET_RETVAL(xp->OnRunStart()); + } + + // execute the graph + ORT_CHECK_AND_SET_RETVAL( + utils::ExecuteGraph(session_state_, feeds_fetches_manager, feeds, *p_fetches, {}, + session_options_.enable_sequential_execution, run_options.terminate, run_logger, + false)); + + } catch (const std::exception& e) { + retval = Status(common::ONNXRUNTIME, common::FAIL, e.what()); + } catch (...) { + retval = Status(common::ONNXRUNTIME, common::RUNTIME_EXCEPTION, "Encountered unknown exception in Run()"); + } + + // info all execution providers InferenceSession:Run ended + for (auto& xp : execution_providers_) { + ORT_CHECK_AND_SET_RETVAL(xp->OnRunEnd()); + } + + --current_num_runs_; + if (session_profiler_.FEnabled()) { + session_profiler_.EndTimeAndRecordEvent(profiling::SESSION_EVENT, "model_run", tp); + } + + return retval; } common::Status InferenceSession::Run(const NameMLValMap& feeds, const std::vector& output_names, std::vector* p_fetches) { - return Run({}, feeds, output_names, p_fetches); + return Run(RunOptions(), feeds, output_names, p_fetches); } common::Status InferenceSession::Run(const RunOptions& run_options, @@ -1148,76 +803,251 @@ common::Status InferenceSession::Run(const RunOptions& run_options, } std::pair InferenceSession::GetModelMetadata() const { - return impl_->GetModelMetadata(); + { + std::lock_guard l(session_mutex_); + if (!is_model_loaded_) { + LOGS(*session_logger_, ERROR) << "Model was not loaded"; + return std::make_pair(common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."), + nullptr); + } + } + + return std::make_pair(common::Status::OK(), &model_metadata_); } std::pair InferenceSession::GetModelInputs() const { - return impl_->GetModelInputs(); + { + std::lock_guard l(session_mutex_); + if (!is_model_loaded_) { + LOGS(*session_logger_, ERROR) << "Model was not loaded"; + return std::make_pair(common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."), + nullptr); + } + } + + return std::make_pair(common::Status::OK(), &required_input_def_list_); } std::pair InferenceSession::GetModelOutputs() const { - return impl_->GetModelOutputs(); -} + { + std::lock_guard l(session_mutex_); + if (!is_model_loaded_) { + LOGS(*session_logger_, ERROR) << "Model was not loaded"; + return std::make_pair(common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."), + nullptr); + } + } -int InferenceSession::GetCurrentNumRuns() { - return impl_->GetCurrentNumRuns(); -} - -void InferenceSession::StartProfiling(const std::string& file_prefix) { - impl_->StartProfiling(file_prefix); -} - -#ifdef _WIN32 -void InferenceSession::StartProfiling(const std::wstring& file_prefix) { impl_->StartProfiling(file_prefix); } -#endif -void InferenceSession::StartProfiling(const logging::Logger* custom_logger) { - impl_->StartProfiling(custom_logger); -} - -std::string InferenceSession::EndProfiling() { - return impl_->EndProfiling(); -} - -common::Status InferenceSession::RegisterExecutionProvider(std::unique_ptr p_exec_provider) { - return impl_->RegisterExecutionProvider(std::move(p_exec_provider)); -} - -common::Status InferenceSession::RegisterGraphTransformer(std::unique_ptr p_graph_transformer, - const std::vector& providers, - TransformerLevel level) { - - return impl_->RegisterGraphTransformer(std::move(p_graph_transformer), providers, level); -} - -common::Status InferenceSession::AddCustomTransformerList(const std::vector& transformers_to_enable) { - return impl_->AddCustomTransformerList(transformers_to_enable); -} - -common::Status InferenceSession::RegisterCustomRegistry(std::shared_ptr custom_registry) { - return impl_->RegisterCustomRegistry(custom_registry); -} - -common::Status InferenceSession::Load(const ModelProto& model_proto) { - return impl_->Load(model_proto); -} - -common::Status InferenceSession::Load(std::unique_ptr p_model_proto) { - return impl_->Load(std::move(p_model_proto)); + return std::make_pair(common::Status::OK(), &output_def_list_); } common::Status InferenceSession::NewIOBinding(std::unique_ptr* io_binding) { - return impl_->NewIOBinding(io_binding); + { + std::lock_guard l(session_mutex_); + if (!is_inited_) { + LOGS(*session_logger_, ERROR) << "Session was not initialized"; + return common::Status(common::ONNXRUNTIME, common::FAIL, "Session not initialized."); + } + } + + // private constructor, can't use make_unique + *io_binding = std::unique_ptr(new IOBinding(session_state_)); + return Status::OK(); } common::Status InferenceSession::Run(const RunOptions& run_options, IOBinding& io_binding) { - return impl_->Run(run_options, io_binding); + // TODO should Run() call io_binding.SynchronizeInputs() or should it let the callers do it? + // io_binding.SynchronizeInputs(); + return Run(run_options, io_binding.feed_names_, io_binding.feeds_, io_binding.output_names_, &io_binding.outputs_); } common::Status InferenceSession::Run(IOBinding& io_binding) { - return impl_->Run(io_binding); + RunOptions run_options; + return Run(run_options, io_binding); } -common::Status InferenceSession::AddCustomOpDomains(const std::vector& ops) { - return impl_->AddCustomOpDomains(ops); +template +void InferenceSession::StartProfiling(const std::basic_string& file_prefix) { + std::basic_ostringstream ss; + ss << file_prefix << "_" << GetCurrentTimeString() << ".json"; + session_profiler_.StartProfiling(ss.str()); +} + +void InferenceSession::StartProfiling(const std::string& file_prefix) { + StartProfiling(file_prefix); +} + +#ifdef _WIN32 +void InferenceSession::StartProfiling(const std::wstring& file_prefix) { + StartProfiling(file_prefix); +} +#endif + +void InferenceSession::StartProfiling(const logging::Logger* logger_ptr) { + session_profiler_.StartProfiling(logger_ptr); +} + +std::string InferenceSession::EndProfiling() { + if (is_model_loaded_) { + return session_profiler_.EndProfiling(); + } + LOGS(*session_logger_, ERROR) << "Could not write a profile because no model was loaded."; + return std::string(); +} + +// assumes model has already been loaded before +common::Status InferenceSession::DoPostLoadProcessing(onnxruntime::Model& model) { + // TODO add other post load processing here + common::Status status = SaveModelMetadata(model); + return status; +} + +common::Status InferenceSession::SaveModelMetadata(const onnxruntime::Model& model) { + VLOGS(*session_logger_, 1) << "Saving model metadata"; + const onnxruntime::Graph& graph = model.MainGraph(); + + // save model metadata + model_metadata_.producer_name = model.ProducerName(); + model_metadata_.description = model.DocString(); + model_metadata_.domain = model.Domain(); + model_metadata_.version = model.ModelVersion(); + model_metadata_.custom_metadata_map = model.MetaData(); + model_metadata_.graph_name = graph.Name(); + + // save required inputs + const auto& required_inputs = graph.GetInputs(); // inputs excluding initializers + required_input_def_list_.reserve(required_inputs.size()); + required_model_input_names_.reserve(required_inputs.size()); + for (const auto& elem : required_inputs) { + required_input_def_list_.push_back(elem); + required_model_input_names_.insert(elem->Name()); + } + + // save all valid inputs + auto& all_inputs = graph.GetInputsIncludingInitializers(); + input_def_map_.reserve(all_inputs.size()); + model_input_names_.reserve(all_inputs.size()); + for (auto elem : all_inputs) { + input_def_map_.insert({elem->Name(), elem}); + model_input_names_.insert(elem->Name()); + } + + // save outputs + const auto& outputs = graph.GetOutputs(); + output_def_list_.reserve(outputs.size()); + model_output_names_.reserve(outputs.size()); + for (const auto& elem : outputs) { + output_def_list_.push_back(elem); + model_output_names_.insert(elem->Name()); + } + + VLOGS(*session_logger_, 1) << "Done saving model metadata"; + return common::Status::OK(); +} + +// Create a Logger for a single execution if possible. Otherwise use the default logger. +// If a new logger is created, it will also be stored in new_run_logger, +// which must remain valid for the duration of the execution. +// If the default logger is used, new_run_logger will remain empty. +// The returned value should be used in the execution. +const logging::Logger& InferenceSession::CreateLoggerForRun(const RunOptions& run_options, + std::unique_ptr& new_run_logger) { + const logging::Logger* run_logger; + + // create a per-run logger if we can + if (logging_manager_ != nullptr) { + std::string run_log_id{session_options_.session_logid}; + + if (!session_options_.session_logid.empty() && !run_options.run_tag.empty()) { + run_log_id += ":"; + } + + run_log_id += run_options.run_tag; + + if (run_options.run_log_verbosity_level > 0) { + new_run_logger = logging_manager_->CreateLogger(run_log_id, + logging::Severity::kVERBOSE, + false, + run_options.run_log_verbosity_level); + } else { + new_run_logger = logging_manager_->CreateLogger(run_log_id); + } + + run_logger = new_run_logger.get(); + VLOGS(*run_logger, 1) << "Created logger for run with id of " << run_log_id; + } else { + // fallback to using default logger. this does NOT have any session or run specific id/tag in it + run_logger = session_logger_; + VLOGS(*run_logger, 1) << "Using default logger for run " << run_options.run_tag; + } + + return *run_logger; +} + +void InferenceSession::InitLogger(logging::LoggingManager* logging_manager) { + // create logger for session, using provided logging manager if possible + if (logging_manager != nullptr) { + std::string session_logid = !session_options_.session_logid.empty() + ? session_options_.session_logid + : "InferenceSession"; // there's probably a better default... + + if (session_options_.session_log_verbosity_level > 0) { + owned_session_logger_ = logging_manager->CreateLogger(session_logid, + logging::Severity::kVERBOSE, + false, + session_options_.session_log_verbosity_level); + } else { + owned_session_logger_ = logging_manager->CreateLogger(session_logid); + } + session_logger_ = owned_session_logger_.get(); + } else { + session_logger_ = &logging::LoggingManager::DefaultLogger(); + } + + session_state_.SetLogger(*session_logger_); +} + +// Registers all the predefined transformers with transformer manager +void InferenceSession::AddPredefinedTransformers(GraphTransformerManager& transformer_manager, + TransformerLevel graph_optimization_level, + const std::vector& custom_list) { + auto add_transformers = [&](TransformerLevel level, std::vector&& providers, std::string t_name) { + // Generate and register rewrite rules for level + auto rewrite_rules_to_register = + transformer_utils::GenerateRewriteRules(level, &custom_list); + if (!rewrite_rules_to_register.empty()) { + std::unique_ptr graph_rewrite_rules = + std::make_unique(t_name + "_RuleBasedTransformer", + "Apply rewrite rules for " + t_name); + for (auto& entry : rewrite_rules_to_register) { + graph_rewrite_rules->Register(std::move(entry)); + } + transformer_manager.Register(std::move(graph_rewrite_rules), level, + std::move(providers)); + } + + // Generate and register transformers for level + auto transformers_to_register = transformer_utils::GenerateTransformers(level, &custom_list); + for (auto& entry : transformers_to_register) { + transformer_manager.Register(std::move(entry.first), level, std::move(entry.second)); + } + }; + + if ((graph_optimization_level >= TransformerLevel::Level1) || !custom_list.empty()) { + add_transformers(TransformerLevel::Level1, {}, "Level1"); + } + + if ((graph_optimization_level >= TransformerLevel::Level2) || !custom_list.empty()) { + add_transformers(TransformerLevel::Level2, {onnxruntime::kCpuExecutionProvider}, "Level2"); + } +} + +common::Status InferenceSession::WaitForNotification(Notification* p_executor_done, int64_t timeout_in_ms) { + if (timeout_in_ms > 0) { + ORT_NOT_IMPLEMENTED(__FUNCTION__, "timeout_in_ms >0 is not supported"); // TODO + } + p_executor_done->WaitForNotification(); + + return Status::OK(); } } // namespace onnxruntime diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 4be765bbc0..d9b6b6d9e3 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -7,11 +7,19 @@ #include #include "core/common/common.h" -#include "core/common/status.h" -#include "core/framework/framework_common.h" -#include "core/graph/basic_types.h" #include "core/common/logging/logging.h" +#include "core/common/profiler.h" +#include "core/common/status.h" +#include "core/framework/execution_providers.h" +#include "core/framework/framework_common.h" +#include "core/framework/iexecutor.h" +#include "core/framework/kernel_registry_manager.h" +#include "core/framework/path_lib.h" +#include "core/framework/session_state.h" +#include "core/graph/basic_types.h" #include "core/optimizer/graph_transformer_level.h" +#include "core/optimizer/graph_transformer_mgr.h" +#include "core/optimizer/insert_cast_transformer.h" namespace onnxruntime { // forward declarations class GraphTransformer; @@ -29,8 +37,8 @@ struct OrtCustomOpDomain { namespace onnxruntime { class IExecutionProvider; // forward decl class IOBinding; - class CustomRegistry; +class Notification; namespace logging { class LoggingManager; @@ -245,7 +253,7 @@ class InferenceSession { /** * Get the current number of in-progress concurrent Run calls. */ - int GetCurrentNumRuns(); + int GetCurrentNumRuns() const; /** * Start profiling on this inference session. This simply turns on profiling events to be @@ -284,10 +292,132 @@ class InferenceSession { */ common::Status Load(std::unique_ptr p_model_proto); + common::Status DoPostLoadProcessing(onnxruntime::Model& model); + + /// convenience pointer to logger. should always be the same as session_state_.Logger(); + const logging::Logger* session_logger_; + + // The model served by this inference session instance. + // Currently this has to be a shared ptr because the Model::Load method + // returns a shared_ptr only. Ideally factory functions should always return + // unique_ptr for maximum flexibility. Client can always upgrade it to shared_ptr + // if they need. + std::shared_ptr model_; + + // Immutable state for each op in the model. Shared by all executors. + SessionState session_state_; + + // names of model inputs and outputs used for quick validation. + std::unordered_set required_model_input_names_; + std::unordered_set model_input_names_; + std::unordered_set model_output_names_; + + // The file path of where the model was loaded. e.g. /tmp/test_squeezenet/model.onnx + std::basic_string model_location_; + private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(InferenceSession); - class Impl; - std::unique_ptr impl_; + bool HasLocalSchema() const { + return !custom_schema_registries_.empty(); + } + + common::Status SaveModelMetadata(const onnxruntime::Model& model); + + // Create a Logger for a single execution if possible. Otherwise use the default logger. + // If a new logger is created, it will also be stored in new_run_logger, + // which must remain valid for the duration of the execution. + // If the default logger is used, new_run_logger will remain empty. + // The returned value should be used in the execution. + const logging::Logger& CreateLoggerForRun(const RunOptions& run_options, + std::unique_ptr& new_run_logger); + + common::Status Load(std::function&)> loader, const std::string& event_name); + + common::Status TransformGraph(onnxruntime::Graph& graph, + const onnxruntime::GraphTransformerManager& graph_transformer_mgr, + const ExecutionProviders& providers, + KernelRegistryManager& kernel_registry_manager, + const InsertCastTransformer& insert_cast_transformer, + SessionState& session_state); + + common::Status CreateSubgraphSessionState(Graph& graph, SessionState& session_state); + + common::Status InitializeSubgraphSessions(Graph& graph, SessionState& session_state); + + void AddPredefinedTransformers(GraphTransformerManager& transformer_manager, + TransformerLevel graph_optimization_level, + const std::vector& custom_list); + + void InitLogger(logging::LoggingManager* logging_manager); + + static common::Status CheckTypes(MLDataType actual, MLDataType expected); + + common::Status ValidateInputs(const std::vector& feed_names, + const std::vector& feeds); + + common::Status ValidateOutputs(const std::vector& output_names, + const std::vector* p_fetches); + + common::Status WaitForNotification(Notification* p_executor_done, int64_t timeout_in_ms); + + template + common::Status Load(const std::basic_string& model_uri); + + template + void StartProfiling(const std::basic_string& file_prefix); + + const SessionOptions session_options_; + + onnxruntime::GraphTransformerManager graph_transformation_mgr_; + + // List of transformers to run. When this list is not empty only the transformers in this list + // will be run regardless of the level set. + // .i.e This list overrides both SessionOptions.graph_optimization_level and predefined transformers. + std::vector transformers_to_enable_; + + /// Logging manager if provided. + logging::LoggingManager* logging_manager_; + + /// Logger for this session. WARNING: Will contain nullptr if logging_manager_ is nullptr. + std::unique_ptr owned_session_logger_; + + // Profiler for this session. + profiling::Profiler session_profiler_; + + ExecutionProviders execution_providers_; + + KernelRegistryManager kernel_registry_manager_; + std::list> custom_schema_registries_; + + // A set of executors that can run in parallel. + std::vector> executors_; // TODO do we need this vector? + + ModelMetadata model_metadata_; + InputDefList required_input_def_list_; + std::unordered_map input_def_map_; + OutputDefList output_def_list_; + +// Environment for this session +// not used now; we'll need it when we introduce threadpool +// statically allocated pointer, no need to manage its lifetime. +//Env* env_; + +// Threadpool for this session +//thread::ThreadPool thread_pool_; // not used for now; will add it later when implementing RunAsync +#ifdef USE_EIGEN_THREADPOOL + std::unique_ptr thread_pool_; +#else + std::unique_ptr thread_pool_; +#endif + + // Number of concurrently running executors + std::atomic current_num_runs_; + + mutable onnxruntime::OrtMutex session_mutex_; // to ensure only one thread can invoke Load/Initialize + bool is_model_loaded_ = false; // GUARDED_BY(session_mutex_) + bool is_inited_ = false; // GUARDED_BY(session_mutex_) + + InsertCastTransformer insert_cast_transformer_; }; } // namespace onnxruntime