Refactor InferenceSession class (#654)

* Refactor InferenceSession interface

* Make some member and func private

* more protected members

* more protected

* reorder class members

* reordering

* reordering

The InferenceSession was implemented in the pImpl idiom, which hides the actual implementation. There are requirements to expose the implementation to other new classes, so this change is to pave the way.

The main change is to abandon the pImpl idiom of InferenceSession.
This commit is contained in:
Tao Qin 2019-03-25 14:09:33 -07:00 committed by GitHub
parent c8f1da28c4
commit 39fb68b761
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 959 additions and 999 deletions

File diff suppressed because it is too large Load diff

View file

@ -7,11 +7,19 @@
#include <unordered_map>
#include "core/common/common.h"
#include "core/common/status.h"
#include "core/framework/framework_common.h"
#include "core/graph/basic_types.h"
#include "core/common/logging/logging.h"
#include "core/common/profiler.h"
#include "core/common/status.h"
#include "core/framework/execution_providers.h"
#include "core/framework/framework_common.h"
#include "core/framework/iexecutor.h"
#include "core/framework/kernel_registry_manager.h"
#include "core/framework/path_lib.h"
#include "core/framework/session_state.h"
#include "core/graph/basic_types.h"
#include "core/optimizer/graph_transformer_level.h"
#include "core/optimizer/graph_transformer_mgr.h"
#include "core/optimizer/insert_cast_transformer.h"
namespace onnxruntime { // forward declarations
class GraphTransformer;
@ -29,8 +37,8 @@ struct OrtCustomOpDomain {
namespace onnxruntime {
class IExecutionProvider; // forward decl
class IOBinding;
class CustomRegistry;
class Notification;
namespace logging {
class LoggingManager;
@ -245,7 +253,7 @@ class InferenceSession {
/**
* Get the current number of in-progress concurrent Run calls.
*/
int GetCurrentNumRuns();
int GetCurrentNumRuns() const;
/**
* Start profiling on this inference session. This simply turns on profiling events to be
@ -284,10 +292,132 @@ class InferenceSession {
*/
common::Status Load(std::unique_ptr<ONNX_NAMESPACE::ModelProto> p_model_proto);
common::Status DoPostLoadProcessing(onnxruntime::Model& model);
/// convenience pointer to logger. should always be the same as session_state_.Logger();
const logging::Logger* session_logger_;
// The model served by this inference session instance.
// Currently this has to be a shared ptr because the Model::Load method
// returns a shared_ptr only. Ideally factory functions should always return
// unique_ptr for maximum flexibility. Clients can always upgrade it to a shared_ptr
// if they need to.
std::shared_ptr<onnxruntime::Model> model_;
// Immutable state for each op in the model. Shared by all executors.
SessionState session_state_;
// names of model inputs and outputs used for quick validation.
std::unordered_set<std::string> required_model_input_names_;
std::unordered_set<std::string> model_input_names_;
std::unordered_set<std::string> model_output_names_;
// The file path of where the model was loaded. e.g. /tmp/test_squeezenet/model.onnx
std::basic_string<PATH_CHAR_TYPE> model_location_;
private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(InferenceSession);
class Impl;
std::unique_ptr<Impl> impl_;
// Reports whether custom_schema_registries_ currently holds at least one entry.
bool HasLocalSchema() const {
  const bool no_custom_registries = custom_schema_registries_.empty();
  return !no_custom_registries;
}
common::Status SaveModelMetadata(const onnxruntime::Model& model);
// Create a Logger for a single execution if possible. Otherwise use the default logger.
// If a new logger is created, it will also be stored in new_run_logger,
// which must remain valid for the duration of the execution.
// If the default logger is used, new_run_logger will remain empty.
// The returned value should be used in the execution.
const logging::Logger& CreateLoggerForRun(const RunOptions& run_options,
std::unique_ptr<logging::Logger>& new_run_logger);
common::Status Load(std::function<common::Status(std::shared_ptr<Model>&)> loader, const std::string& event_name);
common::Status TransformGraph(onnxruntime::Graph& graph,
const onnxruntime::GraphTransformerManager& graph_transformer_mgr,
const ExecutionProviders& providers,
KernelRegistryManager& kernel_registry_manager,
const InsertCastTransformer& insert_cast_transformer,
SessionState& session_state);
common::Status CreateSubgraphSessionState(Graph& graph, SessionState& session_state);
common::Status InitializeSubgraphSessions(Graph& graph, SessionState& session_state);
void AddPredefinedTransformers(GraphTransformerManager& transformer_manager,
TransformerLevel graph_optimization_level,
const std::vector<std::string>& custom_list);
void InitLogger(logging::LoggingManager* logging_manager);
static common::Status CheckTypes(MLDataType actual, MLDataType expected);
common::Status ValidateInputs(const std::vector<std::string>& feed_names,
const std::vector<MLValue>& feeds);
common::Status ValidateOutputs(const std::vector<std::string>& output_names,
const std::vector<MLValue>* p_fetches);
common::Status WaitForNotification(Notification* p_executor_done, int64_t timeout_in_ms);
template <typename T>
common::Status Load(const std::basic_string<T>& model_uri);
template <typename T>
void StartProfiling(const std::basic_string<T>& file_prefix);
const SessionOptions session_options_;
onnxruntime::GraphTransformerManager graph_transformation_mgr_;
// List of transformers to run. When this list is not empty only the transformers in this list
// will be run regardless of the level set.
// i.e. this list overrides both SessionOptions.graph_optimization_level and the predefined transformers.
std::vector<std::string> transformers_to_enable_;
/// Logging manager if provided.
logging::LoggingManager* logging_manager_;
/// Logger for this session. WARNING: Will contain nullptr if logging_manager_ is nullptr.
std::unique_ptr<logging::Logger> owned_session_logger_;
// Profiler for this session.
profiling::Profiler session_profiler_;
ExecutionProviders execution_providers_;
KernelRegistryManager kernel_registry_manager_;
std::list<std::shared_ptr<onnxruntime::IOnnxRuntimeOpSchemaCollection>> custom_schema_registries_;
// A set of executors that can run in parallel.
std::vector<std::unique_ptr<IExecutor>> executors_; // TODO do we need this vector?
ModelMetadata model_metadata_;
InputDefList required_input_def_list_;
std::unordered_map<std::string, const NodeArg*> input_def_map_;
OutputDefList output_def_list_;
// Environment for this session
// not used now; we'll need it when we introduce threadpool
// statically allocated pointer, no need to manage its lifetime.
//Env* env_;
// Threadpool for this session
//thread::ThreadPool thread_pool_; // not used for now; will add it later when implementing RunAsync
#ifdef USE_EIGEN_THREADPOOL
std::unique_ptr<Eigen::NonBlockingThreadPool> thread_pool_;
#else
std::unique_ptr<TaskThreadPool> thread_pool_;
#endif
// Number of concurrently running executors.
// Explicitly zero-initialized: before C++20 a default-constructed std::atomic<int>
// holds an indeterminate value, so reading or incrementing it without a prior
// store would be undefined. A constructor mem-initializer, if one exists, still
// takes precedence over this default.
std::atomic<int> current_num_runs_{0};
mutable onnxruntime::OrtMutex session_mutex_; // to ensure only one thread can invoke Load/Initialize
bool is_model_loaded_ = false; // GUARDED_BY(session_mutex_)
bool is_inited_ = false; // GUARDED_BY(session_mutex_)
InsertCastTransformer insert_cast_transformer_;
};
} // namespace onnxruntime