mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-03 23:49:44 +00:00
* Introduce OrtTasks to replace EventPool * return run_id to frontend * pass run_id to backward * OrtTasks support multiple bg_events * make message_queue a member of orttask * Replace MessageQueue with std::promise * Move status_promise into Task * Move terminate flag into Task * Reenable previously disabled UTs * Add unit tests * Replace condition variables with std::promise * Move to CreateBackgroundTask in the main thread * return status and output in forward_future * use throw for terminating background thread * cleanup tasks at destructor * reenable test_mixed_nnmodule_ortmodules_training * add mutex for ORTTasks functions * add mutex for bg_threads * delay tests before start * add ut for multi-task common backbone Co-authored-by: Sherlock Huang <bahuang@OrtTrainingDev3.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
2074 lines
86 KiB
C++
2074 lines
86 KiB
C++
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
#include "core/graph/onnx_protobuf.h"
|
|
#include "core/session/inference_session.h"
|
|
|
|
#include <memory>
|
|
#include <sstream>
|
|
#include <unordered_set>
|
|
#include <list>
|
|
#include <string>
|
|
#include <thread>
|
|
|
|
#include "core/common/denormal.h"
|
|
#include "core/common/logging/logging.h"
|
|
#include "core/framework/allocatormgr.h"
|
|
#include "core/framework/error_code_helper.h"
|
|
#include "core/framework/execution_frame.h"
|
|
#include "core/framework/feeds_fetches_manager.h"
|
|
#include "core/framework/graph_partitioner.h"
|
|
#include "core/framework/kernel_def_builder.h"
|
|
#include "core/framework/kernel_registry.h"
|
|
#include "core/framework/mldata_type_utils.h"
|
|
#include "core/framework/TensorSeq.h"
|
|
#include "core/framework/tensorprotoutils.h"
|
|
#include "core/framework/tensor_type_and_shape.h"
|
|
#include "core/framework/op_kernel_context_internal.h"
|
|
#include "core/framework/ort_value_pattern_planner.h"
|
|
#include "core/framework/utils.h"
|
|
#include "core/graph/graph_viewer.h"
|
|
#include "core/graph/model.h"
|
|
#include "core/optimizer/transformer_memcpy.h"
|
|
#include "core/optimizer/graph_transformer.h"
|
|
#include "core/optimizer/insert_cast_transformer.h"
|
|
#include "core/optimizer/rule_based_graph_transformer.h"
|
|
#include "core/optimizer/graph_transformer_utils.h"
|
|
#include "core/platform/Barrier.h"
|
|
#include "core/platform/ort_mutex.h"
|
|
#include "core/platform/threadpool.h"
|
|
#include "core/providers/cpu/controlflow/utils.h"
|
|
#include "core/providers/cpu/cpu_execution_provider.h"
|
|
#include "core/flatbuffers/flatbuffers_utils.h"
|
|
#ifdef USE_DML // TODO: This is necessary for the workaround in TransformGraph
|
|
#include "core/providers/dml/DmlExecutionProvider/src/GraphTransformer.h"
|
|
#endif
|
|
#include "core/session/environment.h"
|
|
#include "core/session/IOBinding.h"
|
|
#include "core/session/inference_session_utils.h"
|
|
#include "core/session/onnxruntime_session_options_config_keys.h"
|
|
#include "core/util/protobuf_parsing_utils.h"
|
|
#include "core/util/thread_utils.h"
|
|
|
|
// custom ops are not available in a minimal build unless ORT_MINIMAL_BUILD_CUSTOM_OPS is set
|
|
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
|
|
#include "core/framework/customregistry.h"
|
|
#include "core/session/custom_ops.h"
|
|
#endif
|
|
|
|
#ifdef ENABLE_TRAINING
|
|
#include "orttraining/training_ops/cpu/controlflow/ort_tasks.h"
|
|
#endif
|
|
|
|
using namespace ONNX_NAMESPACE;
|
|
using namespace onnxruntime::experimental;
|
|
using namespace onnxruntime::common;
|
|
|
|
namespace onnxruntime {
|
|
namespace {
|
|
// Yields the date/time pattern that GetCurrentTimeString() hands to OrtStrftime,
// typed for the requested character type. Only the explicit specializations
// below are defined.
template <typename T>
const T* GetDateFormatString();

// Narrow-character pattern: year-month-day_hour-minute-second.
template <>
inline const char* GetDateFormatString<char>() {
  static constexpr char kNarrowFormat[] = "%Y-%m-%d_%H-%M-%S";
  return kNarrowFormat;
}
#ifdef _WIN32
// Wide-character twin of the pattern above (only compiled on Windows).
template <>
inline const wchar_t* GetDateFormatString<wchar_t>() {
  static constexpr wchar_t kWideFormat[] = L"%Y-%m-%d_%H-%M-%S";
  return kWideFormat;
}
#endif
|
|
// TODO: use LoggingManager::GetTimestamp and date::operator<<
|
|
// (see ostream_sink.cc for an example)
|
|
// to simplify this and match the log file timestamp format.
|
|
template <typename T>
|
|
inline std::basic_string<T> GetCurrentTimeString() {
|
|
auto now = std::chrono::system_clock::now();
|
|
auto in_time_t = std::chrono::system_clock::to_time_t(now);
|
|
std::tm local_tm; // NOLINT
|
|
|
|
#ifdef _WIN32
|
|
ORT_ENFORCE(localtime_s(&local_tm, &in_time_t) == 0);
|
|
#else
|
|
localtime_r(&in_time_t, &local_tm);
|
|
#endif
|
|
|
|
T time_str[32];
|
|
OrtStrftime<T>(time_str, sizeof(time_str), GetDateFormatString<T>(), &local_tm);
|
|
return std::basic_string<T>(time_str);
|
|
}
|
|
|
|
} // namespace
|
|
|
|
// Counter shared by all InferenceSession instances; starts at 1 and is advanced with
// fetch_add(1) in ConstructorCommon to give each session its session_id_.
std::atomic<uint32_t> InferenceSession::global_session_id_{1};
|
|
|
|
// The current model versions for saving the ort format models
|
|
// This version is NOT onnxruntime version
|
|
// Only update this version when there is a file format change that breaks compatibility.
|
|
// Once this model version is updated, the kSupportedOrtModelVersions in IsOrtModelVersionSupported
|
|
// below will also need to be updated.
|
|
// See onnxruntime/core/session/flatbuffers/schema/README.md for more details on versioning.
|
|
// Version 1 - history begins
|
|
// Version 2 - add serialization/deserialization of sparse_initializer
|
|
// Version 3 - add `graph_doc_string` to Model
|
|
// This string is written as the ort_version field by SaveToOrtFormat and is one of the
// versions accepted by IsOrtModelVersionSupported.
static constexpr const char* kOrtModelVersion = "3";
|
|
|
|
#if defined(ENABLE_ORT_FORMAT_LOAD)
|
|
// Check if the given ort model version is supported in this build
|
|
static bool IsOrtModelVersionSupported(const std::string& ort_model_version) {
|
|
// The ort model versions we will support in this build
|
|
// This may contain more versions than the kOrtModelVersion, based on the compatibilities
|
|
static const std::unordered_set<std::string> kSupportedOrtModelVersions{
|
|
std::string("1.4.0"), // This is a special model version for existing converted model
|
|
std::string("1"),
|
|
std::string("2"),
|
|
std::string(kOrtModelVersion),
|
|
};
|
|
|
|
return kSupportedOrtModelVersions.find(ort_model_version) != kSupportedOrtModelVersions.cend();
|
|
}
|
|
#endif // defined(ENABLE_ORT_FORMAT_LOAD)
|
|
|
|
// Decides which SessionOptions the session will actually run with and writes them to
// `finalized_session_options`.
//
// In a full build the environment variable named by
// inference_session_utils::kOrtLoadConfigFromModelEnvVar selects the source:
//   - unset/empty or "0": the user supplied options are copied through unchanged;
//   - "1": the options are parsed from the ORT config json held in `model_proto`
//          (which must already be parsed, otherwise ORT_ENFORCE fires);
//   - anything else: INVALID_ARGUMENT is returned.
// In a minimal build the user supplied options are always used.
static Status FinalizeSessionOptions(const SessionOptions& user_provided_session_options,
                                     const ONNX_NAMESPACE::ModelProto& model_proto,
                                     bool is_model_proto_parsed,
                                     /*out*/ SessionOptions& finalized_session_options) {
#if !defined(ORT_MINIMAL_BUILD)
  const logging::Logger& default_logger = logging::LoggingManager::DefaultLogger();

  // The environment is expected to be initialized by the time we get here.
  const Env& env_instance = Env::Default();

  const std::string env_value =
      env_instance.GetEnvironmentVar(inference_session_utils::kOrtLoadConfigFromModelEnvVar);

  const bool read_config_from_model = (env_value == "1");

  // A non-empty value must be exactly "0" or "1".
  if (!env_value.empty() && !read_config_from_model && env_value != "0") {
    std::ostringstream message;
    message << "The only supported values for the environment variable "
            << inference_session_utils::kOrtLoadConfigFromModelEnvVar << " are '0' and '1'. "
            << "The environment variable contained the value: " << env_value;
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, message.str());
  }

  if (!read_config_from_model) {
    // Nothing to read from the model: take the caller's options as they are.
    finalized_session_options = user_provided_session_options;
    return Status::OK();
  }

  LOGS(default_logger, INFO) << "Reading the provided model for the ORT config";

  // The model may carry an ORT config json holding some or all session options.
  SessionOptions options_from_model;

  // Should only trip if this internal class' APIs are called in the wrong order: the
  // model has to be parsed before it can be searched for an ORT config.
  ORT_ENFORCE(is_model_proto_parsed, "ModelProto needs to be parsed to check for ORT config within it");

  // session_logger_ does not exist yet, so the default logger is used for parsing.
  inference_session_utils::JsonConfigParser config_parser(default_logger);

  const auto json_status = config_parser.ParseOrtConfigJsonInModelProto(model_proto);
  if (!json_status.IsOK()) {
    return json_status;
  }

  const auto options_status = config_parser.ParseSessionOptionsFromModelProto(options_from_model);
  if (!options_status.IsOK()) {
    return options_status;
  }

  finalized_session_options = options_from_model;
  return Status::OK();
#else
  ORT_UNUSED_PARAMETER(model_proto);
  ORT_UNUSED_PARAMETER(is_model_proto_parsed);
  finalized_session_options = user_provided_session_options;
  return Status::OK();
#endif  // !defined(ORT_MINIMAL_BUILD)
}
|
|
|
|
void InferenceSession::ConstructorCommon(const SessionOptions& session_options,
|
|
const Environment& session_env) {
|
|
auto status = FinalizeSessionOptions(session_options, model_proto_, is_model_proto_parsed_, session_options_);
|
|
ORT_ENFORCE(status.IsOK(), "Could not finalize session options while constructing the inference session. Error Message: ",
|
|
status.ErrorMessage());
|
|
|
|
// The call to InitLogger depends on the final state of session_options_. Hence it should be invoked
|
|
// after the invocation of FinalizeSessionOptions.
|
|
InitLogger(logging_manager_); // this sets session_logger_ so that it can be used for logging after this point.
|
|
|
|
#if !defined(ORT_MINIMAL_BUILD)
|
|
// Update the number of steps for the graph transformer manager using the "finalized" session options
|
|
ORT_ENFORCE(graph_transformation_mgr_.SetSteps(session_options_.max_num_graph_transformation_steps).IsOK());
|
|
#endif
|
|
|
|
bool set_denormal_as_zero = session_options_.GetConfigOrDefault(kOrtSessionOptionsConfigSetDenormalAsZero, "0") == "1";
|
|
|
|
// The only first session option for flush-to-zero and denormal-as-zero is effective to main thread and OpenMP threads.
|
|
{
|
|
static std::once_flag once;
|
|
|
|
std::call_once(once, [&] {
|
|
#ifdef _OPENMP
|
|
InitializeWithDenormalAsZero(set_denormal_as_zero);
|
|
#endif
|
|
SetDenormalAsZero(set_denormal_as_zero);
|
|
|
|
LOGS(*session_logger_, INFO) << "Flush-to-zero and denormal-as-zero are " << ((set_denormal_as_zero) ? "on" : "off");
|
|
});
|
|
}
|
|
|
|
use_per_session_threads_ = session_options.use_per_session_threads;
|
|
|
|
if (use_per_session_threads_) {
|
|
LOGS(*session_logger_, INFO) << "Creating and using per session threadpools since use_per_session_threads_ is true";
|
|
{
|
|
OrtThreadPoolParams to = session_options_.intra_op_param;
|
|
if (to.name == nullptr) {
|
|
to.name = ORT_TSTR("intra-op");
|
|
}
|
|
to.set_denormal_as_zero = set_denormal_as_zero;
|
|
// If the thread pool can use all the processors, then
|
|
// we set affinity of each thread to each processor.
|
|
to.auto_set_affinity = to.thread_pool_size == 0 &&
|
|
session_options_.execution_mode == ExecutionMode::ORT_SEQUENTIAL &&
|
|
to.affinity_vec_len == 0;
|
|
thread_pool_ =
|
|
concurrency::CreateThreadPool(&Env::Default(), to, concurrency::ThreadPoolType::INTRA_OP);
|
|
}
|
|
if (session_options_.execution_mode == ExecutionMode::ORT_PARALLEL) {
|
|
OrtThreadPoolParams to = session_options_.inter_op_param;
|
|
// If the thread pool can use all the processors, then
|
|
// we set thread affinity.
|
|
to.auto_set_affinity =
|
|
to.thread_pool_size == 0 && session_options_.execution_mode == ExecutionMode::ORT_SEQUENTIAL;
|
|
if (to.name == nullptr)
|
|
to.name = ORT_TSTR("intra-op");
|
|
to.set_denormal_as_zero = set_denormal_as_zero;
|
|
inter_op_thread_pool_ =
|
|
concurrency::CreateThreadPool(&Env::Default(), to, concurrency::ThreadPoolType::INTER_OP);
|
|
if (inter_op_thread_pool_ == nullptr) {
|
|
LOGS(*session_logger_, INFO) << "Failed to create the inter-op thread pool for the parallel executor, setting ExecutionMode to SEQUENTIAL";
|
|
session_options_.execution_mode = ExecutionMode::ORT_SEQUENTIAL;
|
|
}
|
|
}
|
|
} else {
|
|
LOGS(*session_logger_, INFO) << "Using global/env threadpools since use_per_session_threads_ is false";
|
|
intra_op_thread_pool_from_env_ = session_env.GetIntraOpThreadPool();
|
|
inter_op_thread_pool_from_env_ = session_env.GetInterOpThreadPool();
|
|
ORT_ENFORCE(session_env.EnvCreatedWithGlobalThreadPools(),
|
|
"When the session is not configured to use per session"
|
|
" threadpools, the env must be created with the the CreateEnvWithGlobalThreadPools API.");
|
|
}
|
|
|
|
session_profiler_.Initialize(session_logger_);
|
|
if (session_options_.enable_profiling) {
|
|
StartProfiling(session_options_.profile_file_prefix);
|
|
}
|
|
|
|
telemetry_ = {};
|
|
// a monotonically increasing session id for use in telemetry
|
|
session_id_ = global_session_id_.fetch_add(1);
|
|
allocator_manager_ = std::make_shared<onnxruntime::AllocatorManager>();
|
|
}
|
|
|
|
// Creates a session with no model attached yet; a model is supplied afterwards through
// one of the Load() overloads.
InferenceSession::InferenceSession(const SessionOptions& session_options, const Environment& session_env)
    :
#if !defined(ORT_MINIMAL_BUILD)
      graph_transformation_mgr_(session_options.max_num_graph_transformation_steps),
      insert_cast_transformer_("CastFloat16Transformer"),
#endif
      logging_manager_(session_env.GetLoggingManager()),
      environment_(session_env) {
  // Everything else (options, logger, thread pools, profiling) is set up by the shared helper.
  ConstructorCommon(session_options, session_env);
}
|
|
|
|
#if !defined(ORT_MINIMAL_BUILD)
|
|
// Creates a session and eagerly parses the ModelProto found at `model_uri`.
// Parsing failure is reported through ORT_ENFORCE.
InferenceSession::InferenceSession(const SessionOptions& session_options, const Environment& session_env,
                                   const std::string& model_uri)
    : model_location_(ToWideString(model_uri)),
      graph_transformation_mgr_(session_options.max_num_graph_transformation_steps),
      insert_cast_transformer_("CastFloat16Transformer"),
      logging_manager_(session_env.GetLoggingManager()),
      environment_(session_env) {
  auto status = Model::Load(model_location_, model_proto_);
  ORT_ENFORCE(status.IsOK(), "Given model could not be parsed while creating inference session. Error message: ",
              status.ErrorMessage());

  // Record that model_proto_ is populated before the options are finalized, because
  // ConstructorCommon passes this flag on to FinalizeSessionOptions.
  is_model_proto_parsed_ = true;

  ConstructorCommon(session_options, session_env);
}
|
|
|
|
#ifdef _WIN32
|
|
// Wide-string variant (Windows only): creates a session and eagerly parses the
// ModelProto found at `model_uri`. Parsing failure is reported through ORT_ENFORCE.
InferenceSession::InferenceSession(const SessionOptions& session_options,
                                   const Environment& session_env,
                                   const std::wstring& model_uri)
    : model_location_(ToWideString(model_uri)),
      graph_transformation_mgr_(session_options.max_num_graph_transformation_steps),
      insert_cast_transformer_("CastFloat16Transformer"),
      logging_manager_(session_env.GetLoggingManager()),
      environment_(session_env) {
  auto status = Model::Load(model_location_, model_proto_);
  ORT_ENFORCE(status.IsOK(), "Given model could not be parsed while creating inference session. Error message: ",
              status.ErrorMessage());

  // Record that model_proto_ is populated before the options are finalized, because
  // ConstructorCommon passes this flag on to FinalizeSessionOptions.
  is_model_proto_parsed_ = true;

  ConstructorCommon(session_options, session_env);
}
|
|
#endif
|
|
|
|
// Creates a session and eagerly parses the ModelProto read from `model_istream`.
// Parsing failure is reported through ORT_ENFORCE; the message now carries the loader's
// error text, matching the constructors that take a model path.
InferenceSession::InferenceSession(const SessionOptions& session_options, const Environment& session_env,
                                   std::istream& model_istream)
    : graph_transformation_mgr_(session_options.max_num_graph_transformation_steps),
      insert_cast_transformer_("CastFloat16Transformer"),
      logging_manager_(session_env.GetLoggingManager()),
      environment_(session_env) {
  Status st = Model::Load(model_istream, &model_proto_);
  ORT_ENFORCE(st.IsOK(), "Could not parse model successfully while constructing the inference session",
              ". Error message: ", st.ErrorMessage());
  is_model_proto_parsed_ = true;
  // Finalize session options and initialize assets of this session instance
  ConstructorCommon(session_options, session_env);
}
|
|
|
|
// Creates a session and eagerly parses the ModelProto serialized in the byte range
// [model_data, model_data + model_data_len). Parsing failure is reported through
// ORT_ENFORCE.
InferenceSession::InferenceSession(const SessionOptions& session_options, const Environment& session_env,
                                   const void* model_data, int model_data_len)
    : graph_transformation_mgr_(session_options.max_num_graph_transformation_steps),
      insert_cast_transformer_("CastFloat16Transformer"),
      logging_manager_(session_env.GetLoggingManager()),
      environment_(session_env) {
  const bool result = model_proto_.ParseFromArray(model_data, model_data_len);
  ORT_ENFORCE(result, "Could not parse model successfully while constructing the inference session");

  // Record that model_proto_ is populated before the options are finalized, because
  // ConstructorCommon passes this flag on to FinalizeSessionOptions.
  is_model_proto_parsed_ = true;

  ConstructorCommon(session_options, session_env);
}
|
|
|
|
#endif // !defined(ORT_MINIMAL_BUILD)
|
|
|
|
// Tears the session down: ends profiling if it was enabled by the session options
// (logging instead of propagating any failure), cancels background training tasks that
// have not completed, and emits the optional instrumentation / memory-profile output.
InferenceSession::~InferenceSession() {
  if (session_options_.enable_profiling) {
    ORT_TRY {
      EndProfiling();
    }
    ORT_CATCH(const std::exception& e) {
      // TODO: there is currently no channel to hand this error to the API user.
      // Profiling is started and stopped implicitly because it is a session option; it
      // may be worth refactoring so it is started and stopped explicitly through C-API
      // functions instead.
      ORT_HANDLE_EXCEPTION([&]() {
        LOGS(*session_logger_, ERROR) << "Error during EndProfiling(): " << e.what();
      });
    }
    ORT_CATCH(...) {
      LOGS(*session_logger_, ERROR) << "Unknown error during EndProfiling()";
    }
  }

#ifdef ENABLE_TRAINING
  // TODO: Properly cancel outstanding background tasks.
  // What follows only covers a background thread that is waiting for backward inputs;
  // it may also be in other states, such as running Forward() or running Backward().
  std::vector<int64_t> pending_run_ids;
  {
    // Snapshot the ids under the lock, then act on them after releasing it.
    std::lock_guard<std::mutex> guard(bg_threads_mutex_);
    for (const auto& entry : bg_threads_) {
      pending_run_ids.push_back(entry.first);
    }
  }
  for (const int64_t run_id : pending_run_ids) {
    if (onnxruntime::contrib::OrtTasks::GetInstance().TaskIsCompleted(run_id)) {
      continue;
    }
    CancelBackgroundTask(run_id);
  }
#endif

#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
  if (session_activity_started_)
    TraceLoggingWriteStop(session_activity, "OrtInferenceSessionActivity");
#endif
#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
  MemoryInfo::GenerateMemoryProfile();
#endif
}
|
|
|
|
// Takes ownership of an execution provider and adds it to this session.
//
// Must be called before the session is initialized. While registering, session options
// that the provider cannot work with are adjusted (and the change is logged):
//   - DML:  memory pattern is disabled and execution is forced to sequential;
//   - CUDA: execution is forced to sequential; if a TensorRT provider is already
//           registered, its compute stream is handed to the CUDA provider.
// The provider's allocator and data transfer (if any) are registered as well.
//
// @return FAIL for a null provider or an already initialized session; otherwise the
//         status of registering the data transfer / adding the provider.
common::Status InferenceSession::RegisterExecutionProvider(std::unique_ptr<IExecutionProvider> p_exec_provider) {
  if (!p_exec_provider) {
    return Status(common::ONNXRUNTIME, common::FAIL, "Received nullptr for exec provider");
  }

  std::lock_guard<onnxruntime::OrtMutex> session_lock(session_mutex_);

  if (is_inited_) {
    // The graph has already been partitioned, so no node could ever be assigned to a
    // provider added at this point.
    LOGS(*session_logger_, ERROR) << "Execution providers must be registered before the session is initialized. ";
    return common::Status(common::ONNXRUNTIME, common::FAIL,
                          "Execution providers must be registered before the session is initialized.");
  }

  const std::string& ep_type = p_exec_provider->Type();

  p_exec_provider->RegisterAllocator(allocator_manager_);

  // Some option values (default or user provided) do not work with every EP. Instead of
  // expecting the user to know, fix them up here and log what was changed.
  const bool parallel_requested = session_options_.execution_mode != ExecutionMode::ORT_SEQUENTIAL;

  if (ep_type == onnxruntime::kDmlExecutionProvider) {
    // DML memory is not byte addressable, so memory patterns cannot be used.
    if (session_options_.enable_mem_pattern) {
      LOGS(*session_logger_, WARNING)
          << "Having memory pattern enabled is not supported while using the DML Execution Provider. "
          << "So disabling it for this session since it uses the DML Execution Provider.";
      session_options_.enable_mem_pattern = false;
    }

    // DML cannot run under the parallel executor.
    if (parallel_requested) {
      LOGS(*session_logger_, WARNING)
          << "Parallel execution mode does not support the DML Execution Provider. "
          << "So making the execution mode sequential for this session since it uses the DML Execution Provider.";
      session_options_.execution_mode = ExecutionMode::ORT_SEQUENTIAL;
    }
  }

  if (ep_type == onnxruntime::kCudaExecutionProvider) {
    // CUDA cannot run under the parallel executor. Re-read the mode here because the
    // DML branch above may have changed it.
    if (session_options_.execution_mode != ExecutionMode::ORT_SEQUENTIAL) {
      LOGS(*session_logger_, WARNING)
          << "Parallel execution mode does not support the CUDA Execution Provider. "
          << "So making the execution mode sequential for this session since it uses the CUDA Execution Provider.";
      session_options_.execution_mode = ExecutionMode::ORT_SEQUENTIAL;
    }

    if (auto tensorrt_ep = execution_providers_.Get(kTensorrtExecutionProvider)) {
      p_exec_provider->SetComputeStream(tensorrt_ep->GetComputeStream());
    }
  }

  VLOGS(*session_logger_, 1) << "Adding execution provider of type: " << ep_type;

  if (auto data_transfer = p_exec_provider->GetDataTransfer()) {
    auto transfer_status = data_transfer_mgr_.RegisterDataTransfer(std::move(data_transfer));
    if (!transfer_status.IsOK()) {
      return transfer_status;
    }
  }

  p_exec_provider->SetLogger(session_logger_);
  return execution_providers_.Add(ep_type, std::move(p_exec_provider));
}
|
|
|
|
// Custom Op support
|
|
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
|
|
// Builds a CustomRegistry from the given custom op domains and registers it with this
// session. Returns the first error from either step, otherwise OK.
common::Status InferenceSession::AddCustomOpDomains(const std::vector<OrtCustomOpDomain*>& op_domains) {
  std::shared_ptr<CustomRegistry> registry_for_domains;

  ORT_RETURN_IF_ERROR_SESSIONID_(CreateCustomRegistry(op_domains, registry_for_domains));
  ORT_RETURN_IF_ERROR_SESSIONID_(RegisterCustomRegistry(registry_for_domains));

  return Status::OK();
}
|
|
|
|
// Attaches a custom registry to this session: the registry is remembered, its kernel
// registry is added to the kernel registry manager and, in a full build, its op schema
// registry is appended to the session's schema registries.
//
// @return FAIL when `custom_registry` is null, otherwise OK.
common::Status InferenceSession::RegisterCustomRegistry(std::shared_ptr<CustomRegistry> custom_registry) {
  if (!custom_registry) {
    return Status(common::ONNXRUNTIME, common::FAIL, "Received nullptr for custom registry");
  }

  custom_registries_.push_back(custom_registry);

  // Session-level custom kernels.
  kernel_registry_manager_.RegisterKernelRegistry(custom_registry->GetKernelRegistry());

#if !defined(ORT_MINIMAL_BUILD)
  custom_schema_registries_.push_back(custom_registry->GetOpschemaRegistry());
#endif

  return Status::OK();
}
|
|
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
|
|
|
|
#if !defined(ORT_MINIMAL_BUILD)
|
|
// Hands a graph transformer to the transformation manager at the given level.
// Must be called before the session is initialized.
//
// @return FAIL for a null transformer or an already initialized session; otherwise the
//         status returned by the transformation manager.
common::Status InferenceSession::RegisterGraphTransformer(
    std::unique_ptr<onnxruntime::GraphTransformer> p_graph_transformer, TransformerLevel level) {
  if (!p_graph_transformer) {
    return Status(common::ONNXRUNTIME, common::FAIL, "Received nullptr for graph transformer");
  }

  std::lock_guard<onnxruntime::OrtMutex> session_lock(session_mutex_);

  if (is_inited_) {
    // The graph has already been transformed, so a transformer added now would never run.
    LOGS(*session_logger_, ERROR) << "Graph transformers must be registered before the session is initialized.";
    return common::Status(common::ONNXRUNTIME, common::FAIL,
                          "Graph transformers must be registered before the session is initialized.");
  }

  return graph_transformation_mgr_.Register(std::move(p_graph_transformer), level);
}
|
|
|
|
// Appends the given transformer names to transformers_to_enable_, keeping their order.
// Always returns OK.
common::Status InferenceSession::AddCustomTransformerList(const std::vector<std::string>& transformers_to_enable) {
  transformers_to_enable_.insert(transformers_to_enable_.end(),
                                 transformers_to_enable.begin(), transformers_to_enable.end());
  return Status::OK();
}
|
|
|
|
// Serializes the loaded model and the session state into an ORT format flatbuffer and
// writes it to `filepath` (opened in binary mode).
//
// @param filepath  destination file
// @return OK on success. An error status is returned when the machine is not
//         little-endian, when serializing the model or session state fails, or when the
//         file cannot be opened or written (previously such I/O failures were ignored
//         and OK was returned).
common::Status InferenceSession::SaveToOrtFormat(const std::basic_string<ORTCHAR_T>& filepath) const {
  ORT_RETURN_IF_NOT(FLATBUFFERS_LITTLEENDIAN, "ort format only supports little-edian machines");

  // Get the byte size of the ModelProto and round it to the next MB and use it as flatbuffers' init_size
  // TODO: Investigate whether we should set a max size, and clarify the cost of having a buffer smaller than
  // what the total flatbuffers serialized size will be.
  constexpr size_t m_bytes = 1024 * 1024;
  size_t fbs_buffer_size = std::max(m_bytes, model_->ToProto().ByteSizeLong());
  fbs_buffer_size = ((fbs_buffer_size + m_bytes - 1) / m_bytes) * m_bytes;
  flatbuffers::FlatBufferBuilder builder(fbs_buffer_size);

  auto ort_model_version = builder.CreateString(kOrtModelVersion);
  flatbuffers::Offset<fbs::Model> model;
  ORT_RETURN_IF_ERROR(
      model_->SaveToOrtFormat(builder, model));

  flatbuffers::Offset<fbs::SessionState> session_state;
  ORT_RETURN_IF_ERROR(
      session_state_->SaveToOrtFormat(builder, session_state));

  fbs::InferenceSessionBuilder sb(builder);
  sb.add_ort_version(ort_model_version);
  sb.add_model(model);
  sb.add_session_state(session_state);
  auto session = sb.Finish();
  builder.Finish(session, fbs::InferenceSessionIdentifier());

  {
    std::ofstream file(filepath, std::ios::binary);
    ORT_RETURN_IF_NOT(file.is_open(), "Failed to open file for saving ORT format model: ", ToMBString(filepath));

    const uint8_t* buf = builder.GetBufferPointer();
    // Convert the builder's size straight to the stream's size type instead of going
    // through int, which could turn a large size negative.
    const auto size = static_cast<std::streamsize>(builder.GetSize());
    file.write(reinterpret_cast<const char*>(buf), size);
    file.close();

    // Check after close() so that errors surfacing while flushing are caught as well.
    const bool write_succeeded = !file.fail();
    ORT_RETURN_IF_NOT(write_succeeded, "Failed to write ORT format model to file: ", ToMBString(filepath));
  }

  return Status::OK();
}
|
|
|
|
// Common driver behind the public Load() overloads.
//
// Under the session mutex it rejects a second load, runs `loader` to obtain the Model,
// stores it in model_, runs DoPostLoadProcessing and finally marks the session as
// loaded. Exceptions are converted into a failure Status. When the profiler is enabled
// a session event named `event_name` is recorded on the paths that reach the end of the
// function (the early error returns inside the try block skip it).
//
// @param loader      callable that fills in the Model to use
// @param event_name  name stored in telemetry and used for the profiling event
common::Status InferenceSession::Load(std::function<common::Status(std::shared_ptr<Model>&)> loader,
                                      const std::string& event_name) {
  Status status = Status::OK();

  TimePoint load_start;
  if (session_profiler_.IsEnabled()) {
    load_start = session_profiler_.StartTime();
  }

  ORT_TRY {
    std::lock_guard<onnxruntime::OrtMutex> session_lock(session_mutex_);

    if (is_model_loaded_) {
      // A session holds exactly one model.
      LOGS(*session_logger_, ERROR) << "This session already contains a loaded model.";
      return common::Status(common::ONNXRUNTIME, common::MODEL_LOADED, "This session already contains a loaded model.");
    }

    std::shared_ptr<onnxruntime::Model> loaded_model;
    status = loader(loaded_model);
    ORT_RETURN_IF_ERROR_SESSIONID_(status);

    model_ = loaded_model;

    status = DoPostLoadProcessing(*model_);
    ORT_RETURN_IF_ERROR_SESSIONID_(status);

    // Every step succeeded: only now is the model considered loaded.
    is_model_loaded_ = true;

    telemetry_.event_name_ = event_name;
  }
  ORT_CATCH(const std::exception& ex) {
    ORT_HANDLE_EXCEPTION([&]() {
      status = Status(common::ONNXRUNTIME, common::FAIL, "Exception during loading: " + std::string(ex.what()));
    });
  }
  ORT_CATCH(...) {
    LOGS(*session_logger_, ERROR) << "Unknown exception in Load()";
    status = Status(common::ONNXRUNTIME, common::RUNTIME_EXCEPTION, "Encountered unknown exception in Load()");
  }

  if (session_profiler_.IsEnabled()) {
    session_profiler_.EndTimeAndRecordEvent(profiling::SESSION_EVENT, event_name, load_start);
  }

  return status;
}
|
|
|
|
template <typename T>
|
|
common::Status InferenceSession::Load(const std::basic_string<T>& model_uri) {
|
|
model_location_ = ToWideString(model_uri);
|
|
auto loader = [this](std::shared_ptr<onnxruntime::Model>& model) {
|
|
#ifdef ENABLE_LANGUAGE_INTEROP_OPS
|
|
LoadInterOp(model_location_, interop_domains_, [&](const char* msg) { LOGS(*session_logger_, WARNING) << msg; });
|
|
for (const auto& domain : interop_domains_) {
|
|
ORT_RETURN_IF_ERROR(AddCustomOpDomains({domain.get()}));
|
|
}
|
|
#endif
|
|
return onnxruntime::Model::Load(model_location_, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr,
|
|
*session_logger_);
|
|
};
|
|
|
|
common::Status st = Load(loader, "model_loading_uri");
|
|
if (!st.IsOK()) {
|
|
std::ostringstream oss;
|
|
oss << "Load model from " << ToMBString(model_uri) << " failed:" << st.ErrorMessage();
|
|
return common::Status(st.Category(), st.Code(), oss.str());
|
|
}
|
|
return Status::OK();
|
|
}
|
|
|
|
#endif // !defined(ORT_MINIMAL_BUILD)
|
|
|
|
// Loads the model stored at `model_uri`.
//
// The format is taken from the kOrtSessionOptionsConfigLoadModelFormat session config
// entry when it is set ("ORT" selects the ORT format); otherwise the file is probed with
// IsOrtFormatModel. ORT format models go to LoadOrtModel when that support is compiled
// in; anything else is loaded as ONNX in a full build. Unsupported combinations return
// INVALID_ARGUMENT, and a session whose ModelProto was already parsed returns FAIL.
common::Status InferenceSession::Load(const std::string& model_uri) {
  const std::string requested_format =
      session_options_.GetConfigOrDefault(kOrtSessionOptionsConfigLoadModelFormat, "");

  // Only probe the file when no format was requested explicitly.
  const bool load_as_ort_format = requested_format.empty()
                                      ? experimental::utils::IsOrtFormatModel(model_uri)
                                      : (requested_format == "ORT");

  if (load_as_ort_format) {
#if defined(ENABLE_ORT_FORMAT_LOAD)
    return LoadOrtModel(model_uri);
#else
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "ORT format model is not supported in this build.");
#endif
  }

#if !defined(ORT_MINIMAL_BUILD)
  if (is_model_proto_parsed_) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
                           "ModelProto corresponding to the model to be loaded has already been parsed. "
                           "Invoke Load().");
  }

  return Load<char>(model_uri);
#else
  return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "ONNX format model is not supported in this build.");
#endif
}
|
|
|
|
#ifdef _WIN32
|
|
// Wide-string variant (Windows only) of loading the model stored at `model_uri`.
//
// The format is taken from the kOrtSessionOptionsConfigLoadModelFormat session config
// entry when it is set ("ORT" selects the ORT format); otherwise the file is probed with
// IsOrtFormatModel. ORT format models go to LoadOrtModel when that support is compiled
// in; anything else is loaded as ONNX in a full build. Unsupported combinations return
// INVALID_ARGUMENT, and a session whose ModelProto was already parsed returns FAIL.
common::Status InferenceSession::Load(const std::wstring& model_uri) {
  const std::string requested_format =
      session_options_.GetConfigOrDefault(kOrtSessionOptionsConfigLoadModelFormat, "");

  // Only probe the file when no format was requested explicitly.
  const bool load_as_ort_format = requested_format.empty()
                                      ? experimental::utils::IsOrtFormatModel(model_uri)
                                      : (requested_format == "ORT");

  if (load_as_ort_format) {
#if defined(ENABLE_ORT_FORMAT_LOAD)
    return LoadOrtModel(model_uri);
#else
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "ORT format model is not supported in this build.");
#endif
  }

#if !defined(ORT_MINIMAL_BUILD)
  if (is_model_proto_parsed_) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
                           "ModelProto corresponding to the model to be loaded has already been parsed. "
                           "Invoke Load().");
  }

  return Load<PATH_CHAR_TYPE>(model_uri);
#else
  return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "ONNX format model is not supported in this build.");
#endif
}
|
|
#endif
|
|
|
|
// Loads a model from the in-memory bytes [model_data, model_data + model_data_len).
//
// The format is taken from the kOrtSessionOptionsConfigLoadModelFormat session config
// entry when it is set ("ORT" selects the ORT format); otherwise the bytes are probed
// with IsOrtFormatModelBytes. ORT format bytes go to LoadOrtModel when that support is
// compiled in; anything else is parsed as an ONNX ModelProto in a full build.
// Unsupported combinations return INVALID_ARGUMENT, and a session whose ModelProto was
// already parsed returns FAIL.
common::Status InferenceSession::Load(const void* model_data, int model_data_len) {
  const std::string requested_format =
      session_options_.GetConfigOrDefault(kOrtSessionOptionsConfigLoadModelFormat, "");

  // Only probe the bytes when no format was requested explicitly.
  const bool load_as_ort_format = requested_format.empty()
                                      ? experimental::utils::IsOrtFormatModelBytes(model_data, model_data_len)
                                      : (requested_format == "ORT");

  if (load_as_ort_format) {
#if defined(ENABLE_ORT_FORMAT_LOAD)
    return LoadOrtModel(model_data, model_data_len);
#else
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "ORT format model is not supported in this build.");
#endif
  }

#if !defined(ORT_MINIMAL_BUILD)
  if (is_model_proto_parsed_) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
                           "ModelProto corresponding to the model to be loaded has already been parsed. "
                           "Invoke Load().");
  }

  auto load_from_bytes = [this, model_data, model_data_len](std::shared_ptr<onnxruntime::Model>& model) {
    ModelProto parsed_proto;

    if (!parsed_proto.ParseFromArray(model_data, model_data_len)) {
      return Status(common::ONNXRUNTIME, common::INVALID_PROTOBUF,
                    "Failed to load model because protobuf parsing failed.");
    }
#ifdef ENABLE_LANGUAGE_INTEROP_OPS
    LoadInterOp(parsed_proto, interop_domains_, [&](const char* msg) { LOGS(*session_logger_, WARNING) << msg; });
    for (const auto& domain : interop_domains_) {
      ORT_RETURN_IF_ERROR(AddCustomOpDomains({domain.get()}));
    }
#endif

    // The parsed proto is moved into the Model; there is no file path for in-memory data.
    auto local_registries = HasLocalSchema() ? &custom_schema_registries_ : nullptr;
    return onnxruntime::Model::Load(std::move(parsed_proto), PathString(), model,
                                    local_registries, *session_logger_);
  };

  return Load(load_from_bytes, "model_loading_array");
#else
  return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "ONNX format model is not supported in this build.");
#endif
}
|
|
|
|
#if !defined(ORT_MINIMAL_BUILD)
|
|
|
|
// Loads a model from a caller-owned ModelProto.
//
// Fails if this session already holds a parsed ModelProto (is_model_proto_parsed_); in that case
// the parameterless Load() must be used instead.
//
// @param model_proto  The proto to load; it is passed to Model::Load by const reference.
// @return Status of the load.
common::Status InferenceSession::Load(const ModelProto& model_proto) {
  if (is_model_proto_parsed_) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
                           "ModelProto corresponding to the model to be loaded has already been parsed. "
                           "Invoke Load().");
  }

  auto build_from_proto = [this, &model_proto](std::shared_ptr<onnxruntime::Model>& loaded_model) {
#ifdef ENABLE_LANGUAGE_INTEROP_OPS
    LoadInterOp(model_proto, interop_domains_, [&](const char* msg) { LOGS(*session_logger_, WARNING) << msg; });
    for (const auto& domain : interop_domains_) {
      ORT_RETURN_IF_ERROR(AddCustomOpDomains({domain.get()}));
    }
#endif
    auto* local_registries = HasLocalSchema() ? &custom_schema_registries_ : nullptr;

    // Model::Load copies model_proto here; the constructed Model owns that copy from then on.
    return onnxruntime::Model::Load(model_proto, PathString(), loaded_model,
                                    local_registries, *session_logger_);
  };

  return Load(build_from_proto, "model_loading_proto");
}
|
|
|
|
// Loads a model from a ModelProto whose ownership is handed to the session.
//
// Fails if this session already holds a parsed ModelProto (is_model_proto_parsed_); in that case
// the parameterless Load() must be used instead.
//
// @param p_model_proto  The proto to load. Its contents are moved into the Model being built.
//                       Must not be null; a null pointer yields INVALID_ARGUMENT instead of being
//                       dereferenced.
// @return Status of the load.
common::Status InferenceSession::Load(std::unique_ptr<ModelProto> p_model_proto) {
  if (is_model_proto_parsed_) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
                           "ModelProto corresponding to the model to be loaded has already been parsed. "
                           "Invoke Load().");
  }

  auto loader = [this, &p_model_proto](std::shared_ptr<onnxruntime::Model>& model) -> common::Status {
    // The null check lives inside the loader so that the checks performed by Load(loader, ...)
    // before it invokes the loader still take precedence over this one.
    if (p_model_proto == nullptr) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "ModelProto pointer passed to Load() is null.");
    }
#ifdef ENABLE_LANGUAGE_INTEROP_OPS
    LoadInterOp(*p_model_proto, interop_domains_, [&](const char* msg) { LOGS(*session_logger_, WARNING) << msg; });
    for (const auto& domain : interop_domains_) {
      ORT_RETURN_IF_ERROR(AddCustomOpDomains({domain.get()}));
    }
#endif
    // The proto's contents are moved out; p_model_proto itself is released when this function returns.
    return onnxruntime::Model::Load(std::move(*p_model_proto), PathString(), model,
                                    HasLocalSchema() ? &custom_schema_registries_ : nullptr, *session_logger_);
  };

  return Load(loader, "model_loading_proto");
}
|
|
|
|
// Loads a model by reading a serialized ModelProto from an input stream.
//
// Fails if this session already holds a parsed ModelProto (is_model_proto_parsed_); in that case
// the parameterless Load() must be used instead.
//
// @param model_istream  Stream the ModelProto is read from; it is consumed when the loader runs.
// @return Status of the load.
common::Status InferenceSession::Load(std::istream& model_istream) {
  if (is_model_proto_parsed_) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
                           "ModelProto corresponding to the model to be loaded has already been parsed. "
                           "Invoke Load().");
  }

  auto read_and_build = [this, &model_istream](std::shared_ptr<onnxruntime::Model>& loaded_model) {
    ModelProto stream_proto;
    // Any failure to read the proto from the stream is returned unchanged.
    ORT_RETURN_IF_ERROR(Model::Load(model_istream, &stream_proto));
#ifdef ENABLE_LANGUAGE_INTEROP_OPS
    LoadInterOp(stream_proto, interop_domains_, [&](const char* msg) { LOGS(*session_logger_, WARNING) << msg; });
    for (const auto& domain : interop_domains_) {
      ORT_RETURN_IF_ERROR(AddCustomOpDomains({domain.get()}));
    }
#endif
    auto* local_registries = HasLocalSchema() ? &custom_schema_registries_ : nullptr;
    return onnxruntime::Model::Load(std::move(stream_proto), PathString(), loaded_model,
                                    local_registries, *session_logger_);
  };

  return Load(read_and_build, "model_loading_istream");
}
|
|
|
|
// Loads the model from the ModelProto the session already holds (model_proto_).
//
// Fails unless is_model_proto_parsed_ is set.
//
// @return Status of the load.
common::Status InferenceSession::Load() {
  if (!is_model_proto_parsed_) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
                           "ModelProto corresponding to the model to be loaded has not been parsed yet. "
                           "This API should be called in conjunction with a ctor that takes a model abstraction.");
  }

  auto build_from_saved_proto = [this](std::shared_ptr<onnxruntime::Model>& loaded_model) {
#ifdef ENABLE_LANGUAGE_INTEROP_OPS
    LoadInterOp(this->model_proto_, interop_domains_, [&](const char* msg) { LOGS(*session_logger_, WARNING) << msg; });
    for (const auto& domain : interop_domains_) {
      ORT_RETURN_IF_ERROR(AddCustomOpDomains({domain.get()}));
    }
#endif
    auto* local_registries = HasLocalSchema() ? &custom_schema_registries_ : nullptr;

    // The session no longer needs its parsed proto at this point, so it is moved into the Model.
    return Model::Load(std::move(this->model_proto_), model_location_, loaded_model,
                       local_registries, *session_logger_);
  };

  return Load(build_from_saved_proto, "model_loading_from_saved_proto");
}
|
|
|
|
// Runs the graph transformation pipeline on `graph`.
//
// The transformer order:
//   1. built-in graph rewriter
//   2. each execution provider's transformer
//   3. do node placement according to kernel definition
//   4. insert copy nodes
//   5. insert cast nodes.
//
// @param graph                       Graph to transform in place.
// @param graph_transformer_mgr       Manager holding the registered transformers for each level.
// @param providers                   Execution providers used for partitioning and for the copy transformer.
// @param kernel_registry_manager     Kernel registries used by the partitioner and the copy transformer.
// @param insert_cast_transformer     Transformer that inserts cast nodes.
// @param session_state               Supplies ExportDll() and the mutable function manager to the partitioner.
// @param saving_model_in_ort_format  When true, partitioning runs in kAssignOnly mode.
// @return OK, the first failing transformer/partitioner status, or NOT_IMPLEMENTED if a node ends up
//         without an execution provider.
common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph,
                                                const onnxruntime::GraphTransformerManager& graph_transformer_mgr,
                                                const ExecutionProviders& providers,
                                                KernelRegistryManager& kernel_registry_manager,
                                                const InsertCastTransformer& insert_cast_transformer,
                                                SessionState& session_state,
                                                bool saving_model_in_ort_format) {
  // first apply global(execution provider independent), level 1(default/system/basic) graph to graph optimizations
  ORT_RETURN_IF_ERROR_SESSIONID_(
      graph_transformer_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *session_logger_));

#ifdef USE_DML
  // TODO: this is a temporary workaround to apply the DML EP's custom graph transformer prior to partitioning. This
  // transformer applies DML-specific fusions that go beyond what ORT offers by default. Ideally the DML EP should
  // apply these transforms during partitioning, but the full mutable Graph object isn't exposed to
  // IExecutionProvider::GetCapability, which is necessary for the DML EP's transforms.
  //
  // To prevent this from interfering with other EPs, we only apply this transform if the DML EP is the only one that's
  // registered (aside from the CPU EP, which is always registered by default.)
  if (execution_providers_.Get(kDmlExecutionProvider) && execution_providers_.NumProviders() <= 2) {
    Dml::GraphTransformer dml_transformer(onnxruntime::kDmlExecutionProvider,
                                          execution_providers_.Get(kDmlExecutionProvider));

    bool modified = false;
    // Propagate a failure instead of silently continuing with a graph the transformer could not process.
    ORT_RETURN_IF_ERROR_SESSIONID_(dml_transformer.Apply(graph, modified, *session_logger_));
  }
#endif

  // if saving model to ORT format we only assign nodes a custom EP can handle and don't compile them.
  // we do this to preserve the original nodes in the model but prevent optimizers from changing them.
  // at runtime, the ORT format model will re-do the partitioning/compilation of these nodes, which may change
  // to cover fewer nodes due to device capabilities.
  auto mode = saving_model_in_ort_format ? GraphPartitioner::Mode::kAssignOnly
                                         : GraphPartitioner::Mode::kNormal;

  // Do partitioning based on execution providers' capability.
  GraphPartitioner partitioner(kernel_registry_manager, providers);
  ORT_RETURN_IF_ERROR_SESSIONID_(partitioner.Partition(graph, session_state.ExportDll(),
                                                       session_state.GetMutableFuncMgr(), mode));

  // apply transformers except default transformers
  // Default transformers are required for correctness and they are owned and run by inference session
  for (int i = static_cast<int>(TransformerLevel::Level1); i <= static_cast<int>(TransformerLevel::MaxLevel); i++) {
    ORT_RETURN_IF_ERROR_SESSIONID_(
        graph_transformer_mgr.ApplyTransformers(graph, static_cast<TransformerLevel>(i), *session_logger_));
  }

  bool modified = false;
  // Insert cast node/s.
  ORT_RETURN_IF_ERROR_SESSIONID_(insert_cast_transformer.Apply(graph, modified, *session_logger_));

  // Now every node should be already assigned to an execution provider
  std::unordered_map<std::string, std::vector<std::string>> node_placements;
  const bool is_verbose_mode = session_logger_->GetSeverity() == logging::Severity::kVERBOSE;
  for (auto& node : graph.Nodes()) {
    const auto& node_provider = node.GetExecutionProviderType();
    if (node_provider.empty()) {
      // No provider claimed this node: report it by name (when it has one), op type and version.
      std::ostringstream oss;
      oss << "Could not find an implementation for the node ";
      if (!node.Name().empty())
        oss << node.Name() << ":";
      oss << node.OpType() << "(" << node.SinceVersion() << ")";

      return Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED, oss.str());
    }

    if (is_verbose_mode) {  // TODO: should we disable this if the number of nodes are above a certain threshold?
      std::string node_str = node.OpType();
      node_str += " (";
      node_str += node.Name();
      node_str += ")";
      node_placements[node_provider].push_back(node_str);
    }
  }

  // print placement info
  if (is_verbose_mode) {
    LOGS(*session_logger_, VERBOSE) << "Node placements";
    if (node_placements.size() == 1) {
      LOGS(*session_logger_, VERBOSE) << "All nodes have been placed on [" << node_placements.begin()->first << "].";
    } else {
      for (const auto& pr : node_placements) {
        std::ostringstream all_nodes_str;
        std::copy(pr.second.begin(), pr.second.end(), std::ostream_iterator<std::string>(all_nodes_str, ", "));
        LOGS(*session_logger_, VERBOSE) << " Provider: [" << pr.first << "]"
                                        << ": [" << all_nodes_str.str() << "]";
      }
    }
  }

  std::vector<std::string> provider_types;
  for (auto& provider_ptr : providers) {
    provider_types.push_back(provider_ptr->Type());
  }

  // Insert copy node/s.
  MemcpyTransformer copy_transformer{provider_types, kernel_registry_manager};
  ORT_RETURN_IF_ERROR_SESSIONID_(copy_transformer.Apply(graph, modified, *session_logger_));

  return common::Status::OK();
}
|
|
#endif // !defined(ORT_MINIMAL_BUILD)
|
|
|
|
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
|
|
// Runs the graph partitioner in kOrtFormatLoad mode on `graph`.
//
// Any compiled-kernel hashes the partitioner reports are handed to the session state.
//
// @param graph                    Graph to partition.
// @param providers                Execution providers given to the partitioner.
// @param kernel_registry_manager  Kernel registries given to the partitioner.
// @param session_state            Supplies ExportDll()/the function manager and receives the hashes.
// @return OK, or the partitioner's failure status.
Status InferenceSession::PartitionOrtFormatModel(onnxruntime::Graph& graph,
                                                 const ExecutionProviders& providers,
                                                 KernelRegistryManager& kernel_registry_manager,
                                                 SessionState& session_state) const {
  // Filled in by the partitioner.
  std::unordered_map<std::string, uint64_t> hashes_from_partitioning;

  GraphPartitioner ort_format_partitioner(kernel_registry_manager, providers);
  ORT_RETURN_IF_ERROR_SESSIONID_(ort_format_partitioner.Partition(graph, session_state.ExportDll(),
                                                                  session_state.GetMutableFuncMgr(),
                                                                  GraphPartitioner::Mode::kOrtFormatLoad,
                                                                  &hashes_from_partitioning));

  if (hashes_from_partitioning.empty()) {
    return Status::OK();
  }

  session_state.SetCompiledKernelHashes(std::move(hashes_from_partitioning));
  return Status::OK();
}
|
|
#endif
|
|
|
|
#if defined(ENABLE_ORT_FORMAT_LOAD)
|
|
template <typename T>
|
|
static Status LoadOrtModelBytes(const std::basic_string<T>& model_uri,
|
|
std::basic_string<ORTCHAR_T>& model_location,
|
|
std::vector<uint8_t>& bytes) {
|
|
size_t num_bytes = 0;
|
|
model_location = ToWideString(model_uri);
|
|
ORT_RETURN_IF_ERROR(Env::Default().GetFileLength(model_location.c_str(), num_bytes));
|
|
|
|
bytes.resize(num_bytes);
|
|
|
|
std::ifstream bytes_stream(model_uri, std::ifstream::in | std::ifstream::binary);
|
|
bytes_stream.read(reinterpret_cast<char*>(bytes.data()), num_bytes);
|
|
|
|
if (!bytes_stream) {
|
|
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
|
|
"Load model from ", ToMBString(model_uri), " failed. Only ",
|
|
bytes_stream.gcount(), "/", num_bytes, " bytes were able to be read.");
|
|
}
|
|
|
|
return Status::OK();
|
|
}
|
|
|
|
// Loads an ORT format model from the file at `model_uri`.
//
// The file is read into ort_format_model_bytes_ (and model_location_ is set) by the callback that
// LoadOrtModel(std::function<Status()>) invokes.
//
// @param model_uri  Path of the ORT format model file.
// @return Status of the load.
Status InferenceSession::LoadOrtModel(const std::string& model_uri) {
  return LoadOrtModel(
      [&]() {
        return LoadOrtModelBytes(model_uri, model_location_, ort_format_model_bytes_);
      });
}
|
|
|
|
#ifdef WIN32
|
|
// Wide-string path variant of LoadOrtModel.
//
// The file is read into ort_format_model_bytes_ (and model_location_ is set) by the callback that
// LoadOrtModel(std::function<Status()>) invokes.
//
// @param model_uri  Path of the ORT format model file.
// @return Status of the load.
Status InferenceSession::LoadOrtModel(const std::wstring& model_uri) {
  return LoadOrtModel(
      [&]() {
        return LoadOrtModelBytes(model_uri, model_location_, ort_format_model_bytes_);
      });
}
|
|
#endif
|
|
|
|
// Loads an ORT format model from a caller-owned memory buffer.
//
// @param model_data      Pointer to the serialized ORT format model.
// @param model_data_len  Number of bytes at model_data. Must not be negative.
// @return Status of the load; INVALID_ARGUMENT if the length is negative or the pointer is null
//         while the length is non-zero.
Status InferenceSession::LoadOrtModel(const void* model_data, int model_data_len) {
  return LoadOrtModel([&]() -> Status {
    // A negative length would be converted to a huge size_t by resize(), and a null pointer with a
    // non-zero length would be read from by copy_n(). Reject both before touching the buffer.
    if (model_data_len < 0 || (model_data == nullptr && model_data_len != 0)) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                             "Invalid ORT format model buffer: data is null or length is negative.");
    }

    // copy bytes as we need them to be available when InferenceSession::Initialize is called later.
    //
    // TODO: Provide Load API where we can take ownership of memory to avoid the copy,
    // and/or a combined Load+Initialize where we don't need this temporary copy.
    ort_format_model_bytes_.resize(static_cast<size_t>(model_data_len));
    std::copy_n(reinterpret_cast<const uint8_t*>(model_data), model_data_len, ort_format_model_bytes_.data());

    return Status::OK();
  });
}
|
|
|
|
// Shared implementation for loading an ORT format model.
//
// Holds session_mutex_ for the whole call. Rejects the request if a model is already loaded or the
// session is already initialized, then invokes the callback that fills ort_format_model_bytes_,
// verifies the flatbuffer, validates every section this function needs, and only then builds the
// Model and installs it in model_.
//
// @param load_ort_format_model_bytes  Callback that populates ort_format_model_bytes_.
// @return OK on success; MODEL_LOADED if the session already has a model or is initialized; otherwise
//         the first validation/deserialization failure.
Status InferenceSession::LoadOrtModel(std::function<Status()> load_ort_format_model_bytes) {
  static_assert(FLATBUFFERS_LITTLEENDIAN, "ORT format only supports little-endian machines");

  std::lock_guard<onnxruntime::OrtMutex> l(session_mutex_);

  if (is_model_loaded_) {  // already loaded
    Status status(common::ONNXRUNTIME, common::MODEL_LOADED, "This session already contains a loaded model.");
    LOGS(*session_logger_, ERROR) << status.ErrorMessage();
    return status;
  }

  if (is_inited_) {
    Status status(common::ONNXRUNTIME, common::MODEL_LOADED, "This session has already been initialized.");
    LOGS(*session_logger_, ERROR) << status.ErrorMessage();
    return status;
  }

  ORT_RETURN_IF_ERROR(load_ort_format_model_bytes());

  // Verify the ort_format_model_bytes_ is a valid InferenceSessionBuffer before we access the data
  flatbuffers::Verifier verifier(ort_format_model_bytes_.data(), ort_format_model_bytes_.size());
  ORT_RETURN_IF_NOT(fbs::VerifyInferenceSessionBuffer(verifier), "ORT model verification failed.");

  const auto* fbs_session = fbs::GetInferenceSession(ort_format_model_bytes_.data());
  ORT_RETURN_IF(nullptr == fbs_session, "InferenceSession is null. Invalid ORT format model.");

  // Check version mismatch, for now we will only proceed when runtime version matches the model's ort version
  const auto* fbs_ort_model_version = fbs_session->ort_version();
  ORT_RETURN_IF(fbs_ort_model_version == nullptr, "Serialized version info is null. Invalid ORT format model.");
  ORT_RETURN_IF_NOT(IsOrtModelVersionSupported(fbs_ort_model_version->str()),
                    "The ORT format model version [", fbs_ort_model_version->str(),
                    "] is not supported this build ", ORT_VERSION);

  const auto* fbs_model = fbs_session->model();
  ORT_RETURN_IF(nullptr == fbs_model, "Missing Model. Invalid ORT format model.");

  // The serialized session state is only consumed later, but its presence is checked here, before
  // the model is built and installed, so that a rejected buffer never leaves model_ populated.
  const auto* fbs_sess_state = fbs_session->session_state();
  ORT_RETURN_IF(nullptr == fbs_sess_state, "SessionState is null. Invalid ORT format model.");

  // need to go from unique_ptr to shared_ptr when moving into model_
  std::unique_ptr<Model> tmp_model;
#if !defined(ORT_MINIMAL_BUILD)
  ORT_RETURN_IF_ERROR(Model::LoadFromOrtFormat(*fbs_model,
                                               HasLocalSchema() ? &custom_schema_registries_ : nullptr,
                                               *session_logger_, tmp_model));

#else
  ORT_RETURN_IF_ERROR(Model::LoadFromOrtFormat(*fbs_model, *session_logger_, tmp_model));
#endif

  ORT_RETURN_IF_ERROR(SaveModelMetadata(*tmp_model));
  model_ = std::move(tmp_model);

  is_model_loaded_ = true;

  return Status::OK();
}
|
|
#endif // defined(ENABLE_ORT_FORMAT_LOAD)
|
|
|
|
bool InferenceSession::IsInitialized() const {
|
|
std::lock_guard<onnxruntime::OrtMutex> l(session_mutex_);
|
|
return is_inited_;
|
|
}
|
|
|
|
// Returns true if `type_proto` is a FLOAT16 tensor type, or a sequence/map type whose element/value
// type (checked recursively) is one. Every other type case yields false.
static bool ModelHasFP16InputsHelper(const onnx::TypeProto& type_proto) {
  switch (type_proto.value_case()) {
    case ::onnx::TypeProto::ValueCase::kTensorType:
      return type_proto.has_tensor_type() &&
             type_proto.tensor_type().elem_type() ==
                 ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16;

    case ::onnx::TypeProto::ValueCase::kSequenceType:
      // Recurse into the sequence's element type.
      return type_proto.has_sequence_type() &&
             ModelHasFP16InputsHelper(type_proto.sequence_type().elem_type());

    case ::onnx::TypeProto::ValueCase::kMapType:
      // Only the map's value type is inspected.
      return type_proto.has_map_type() &&
             ModelHasFP16InputsHelper(type_proto.map_type().value_type());

    default:
      return false;
  }
}
|
|
|
|
static bool ModelHasFP16Inputs(const Graph& graph) {
|
|
for (auto& input : graph.GetInputs()) {
|
|
if (input->Exists() && ModelHasFP16InputsHelper(*(input->TypeAsProto()))) {
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
// Prepares a loaded model for execution.
//
// In order: checks that a model is loaded and the session is not yet initialized, registers a default
// CPU execution provider if none was registered, optionally switches providers to the environment's
// shared allocators, creates the SessionState, registers kernels, transforms (or, for ORT format
// models, partitions) the graph, finalizes the session state, optionally saves the optimized model,
// and logs telemetry. Exceptions thrown along the way are converted to a failure Status.
//
// @return OK on success (including when the session was already initialized), otherwise the first
//         failure encountered.
common::Status InferenceSession::Initialize() {
  Status status = Status::OK();

  TimePoint profile_start;
  if (session_profiler_.IsEnabled()) {
    profile_start = session_profiler_.StartTime();
  }

  ORT_TRY {
    LOGS(*session_logger_, INFO) << "Initializing session.";
    const Env& env = Env::Default();
    env.GetTelemetryProvider().LogSessionCreationStart();

    bool cpu_ep_registered = false;

    {
      // Pre-checks run under the mutex; it is dropped again at the end of this scope.
      std::lock_guard<onnxruntime::OrtMutex> precheck_lock(session_mutex_);

      if (!is_model_loaded_) {
        LOGS(*session_logger_, ERROR) << "Model was not loaded";
        return common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded.");
      }

      if (is_inited_) {
        // Nothing to do: a second Initialize() call is treated as success.
        LOGS(*session_logger_, INFO) << "Session has already been initialized.";
        return common::Status::OK();
      }

      cpu_ep_registered = execution_providers_.Get(onnxruntime::kCpuExecutionProvider) != nullptr;
    }

    // Register default CPUExecutionProvider if user didn't provide it through the Register() calls.
    // RegisterExecutionProvider locks the session_mutex_ so we can't be holding it when we call that
    if (!cpu_ep_registered) {
      LOGS(*session_logger_, INFO) << "Adding default CPU execution provider.";
      CPUExecutionProviderInfo cpu_ep_info{session_options_.enable_cpu_mem_arena};
      auto default_cpu_ep = onnxruntime::make_unique<CPUExecutionProvider>(cpu_ep_info);
      ORT_RETURN_IF_ERROR_SESSIONID_(RegisterExecutionProvider(std::move(default_cpu_ep)));
    }

    // re-acquire mutex
    std::lock_guard<onnxruntime::OrtMutex> session_lock(session_mutex_);

    // At this time we know all the providers that will be part of this session.
    // Read shared allocators from the environment and update them in the respective providers.
    //
    // The reason for updating the providers is so that when the session state is created the allocators
    // are setup appropriately keyed by OrtMemoryInfo with delegates going to the respective providers.
    // Secondly, the GetAllocator() method inside IExecutionProvider is still used in various places, hence
    // it doesn't make sense to just update the allocator map inside session state with these shared allocators; doing
    // so would cause inconsistency between the allocator map inside session state and that inside the providers.
    // TODO: we could refactor the allocators to not require the call to GetAllocator but that change is much bigger
    // since we've to take into account the per-thread cuda allocators.
    // TODO (contd.) We could also possibly absorb the per-thread logic in a new allocator decorator that derives
    // from IAllocator to keep things clean.
    const std::string env_allocators_flag =
        session_options_.GetConfigOrDefault(kOrtSessionOptionsConfigUseEnvAllocators, "0");
    if (env_allocators_flag == "1") {
      LOGS(*session_logger_, INFO) << "This session will use the allocator registered with the environment.";
      UpdateProvidersWithSharedAllocators();
    }

#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
    TraceLoggingWriteStart(session_activity, "OrtInferenceSessionActivity");
    session_activity_started_ = true;
#endif

    // now that we have all the execution providers, create the session state
    session_state_ = onnxruntime::make_unique<SessionState>(
        model_->MainGraph(),
        execution_providers_,
        session_options_.enable_mem_pattern && session_options_.execution_mode == ExecutionMode::ORT_SEQUENTIAL,
        GetIntraOpThreadPoolToUse(),
        GetInterOpThreadPoolToUse(),
        data_transfer_mgr_,
        *session_logger_,
        session_profiler_,
        session_options_.use_deterministic_compute);

    onnxruntime::Graph& graph = model_->MainGraph();

    // Collect the kernel registries from execution provider instances;
    // There are 2 kinds of kernel registries with priority from high to low as below,
    // 1. Custom execution provider type specific kernel registries.
    // 2. common execution provider type specific kernel registries.
    // Kernel registries are shared across sessions.
    // The 1st ones should have already been registered via session-level API into KernelRegistryManager.
    //
    // Register 2nd registries into KernelRegistryManager.
    ORT_RETURN_IF_ERROR_SESSIONID_(kernel_registry_manager_.RegisterKernels(execution_providers_));

    // Non-empty ORT format bytes mean the model came in through the ORT format load path.
    const bool loading_ort_format = !ort_format_model_bytes_.empty();
    const bool saving_model = !session_options_.optimized_model_filepath.empty();
    bool saving_ort_format = false;
    if (saving_model) {
      // An explicit save-format setting wins; otherwise the output path decides.
      const std::string save_format =
          session_options_.GetConfigOrDefault(kOrtSessionOptionsConfigSaveModelFormat, "");
      saving_ort_format = save_format.empty()
                              ? experimental::utils::IsOrtFormatModel(session_options_.optimized_model_filepath)
                              : save_format == "ORT";
    }

#if !defined(ORT_MINIMAL_BUILD)
    if (!loading_ort_format) {
      // add predefined transformers
      AddPredefinedTransformers(graph_transformation_mgr_, session_options_.graph_optimization_level,
                                transformers_to_enable_);

      // apply any transformations to the main graph and any subgraphs
      ORT_RETURN_IF_ERROR_SESSIONID_(TransformGraph(graph, graph_transformation_mgr_,
                                                    execution_providers_, kernel_registry_manager_,
                                                    insert_cast_transformer_,
                                                    *session_state_,
                                                    saving_ort_format));

      // now that all the transforms are done, call Resolve on the main graph. this will recurse into the subgraphs.
      ORT_RETURN_IF_ERROR_SESSIONID_(graph.Resolve());

      // Update temporary copies of metadata, input- and output definitions to the same state as the resolved graph
      ORT_RETURN_IF_ERROR_SESSIONID_(SaveModelMetadata(*model_));
    } else
#endif  // !defined(ORT_MINIMAL_BUILD)
    {
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
      // nodes are already partitioned, but a custom EP may compile some at runtime.
      // run the partitioning to allow that to happen.
      //
      // We always have the CPU EP, so only need to run this if some other EP is enabled
      if (execution_providers_.NumProviders() > 1) {
        ORT_RETURN_IF_ERROR_SESSIONID_(PartitionOrtFormatModel(graph, execution_providers_, kernel_registry_manager_,
                                                               *session_state_));
      }
#endif
    }

    // Only ORT format models carry a serialized session state.
    const experimental::fbs::SessionState* serialized_session_state = nullptr;
    if (loading_ort_format) {
      serialized_session_state = fbs::GetInferenceSession(ort_format_model_bytes_.data())->session_state();
    }

    // need to keep the initializers if saving the optimized model
    const bool remove_initializers = !saving_model;
    ORT_RETURN_IF_ERROR_SESSIONID_(
        session_state_->FinalizeSessionState(model_location_, kernel_registry_manager_,
                                             session_options_,
                                             serialized_session_state,
                                             remove_initializers,
                                             saving_ort_format));

#if !defined(ORT_MINIMAL_BUILD)
    if (saving_model) {
      if (session_state_->GetFuncMgr().NumFuncs() > 0) {
        ORT_RETURN_IF_ERROR_SESSIONID_(
            ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
                            "Unable to serialize model as it contains compiled nodes. "
                            "Please disable any execution providers which generate compiled nodes."));
      }

      if (session_options_.graph_optimization_level >= TransformerLevel::Level3) {
        LOGS(*session_logger_, WARNING)
            << "Serializing optimized model with Graph Optimization level greater than ORT_ENABLE_EXTENDED. "
               "The generated model may contain hardware and execution provider specific optimizations, "
               "and should only be used in the same environment the model was optimized for.";
      }

      if (saving_ort_format) {
        ORT_RETURN_IF_ERROR_SESSIONID_(SaveToOrtFormat(session_options_.optimized_model_filepath));
      } else {
        ORT_RETURN_IF_ERROR_SESSIONID_(Model::Save(*model_, session_options_.optimized_model_filepath));
      }
    }
#endif  // !defined(ORT_MINIMAL_BUILD)

    session_state_->ResolveMemoryPatternFlag();
    is_inited_ = true;

    // we don't directly use the ORT format bytes currently, so free those now
    std::vector<uint8_t>().swap(ort_format_model_bytes_);

    // and log telemetry
    const bool model_has_fp16_inputs = ModelHasFP16Inputs(graph);
    env.GetTelemetryProvider().LogSessionCreation(
        session_id_, model_->IrVersion(), model_->ProducerName(), model_->ProducerVersion(), model_->Domain(),
        model_->MainGraph().DomainToVersionMap(), model_->MainGraph().Name(), model_->MetaData(),
        telemetry_.event_name_, execution_providers_.GetIds(), model_has_fp16_inputs);
    LOGS(*session_logger_, INFO) << "Session successfully initialized.";
  }
  ORT_CATCH(const NotImplementedException& ex) {
    ORT_HANDLE_EXCEPTION([&]() {
      status = ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Exception during initialization: ", ex.what());
      LOGS(*session_logger_, ERROR) << status.ErrorMessage();
    });
  }
  ORT_CATCH(const std::exception& ex) {
    ORT_HANDLE_EXCEPTION([&]() {
      status = ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "Exception during initialization: ", ex.what());
      LOGS(*session_logger_, ERROR) << status.ErrorMessage();
    });
  }
  ORT_CATCH(...) {
    status = ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "Encountered unknown exception in Initialize()");
    LOGS(*session_logger_, ERROR) << status.ErrorMessage();
  }

  if (session_profiler_.IsEnabled()) {
    session_profiler_.EndTimeAndRecordEvent(profiling::SESSION_EVENT, "session_initialization", profile_start);
  }

  if (status.IsOK()) {
    // Every provider is notified; the first non-OK result is kept as the function's status.
    for (auto& provider : execution_providers_) {
      auto provider_status = provider->OnSessionInitializationEnd();
      if (status.IsOK()) {
        status = provider_status;
      }
    }
  }

  return status;
}
|
|
|
|
// This method should be called from within Initialize() only and before the creation of the session state.
|
|
// This ensures all providers have been registered in the session and the session state is consistent with the providers.
|
|
void InferenceSession::UpdateProvidersWithSharedAllocators() {
|
|
using namespace std;
|
|
const auto& provider_ids = execution_providers_.GetIds();
|
|
for (const auto& one_shared_alloc : environment_.GetRegisteredSharedAllocators()) {
|
|
for (const auto& id : provider_ids) {
|
|
auto* provider_ptr = execution_providers_.Get(id);
|
|
provider_ptr->ReplaceAllocator(one_shared_alloc);
|
|
}
|
|
}
|
|
}
|
|
|
|
int InferenceSession::GetCurrentNumRuns() const {
|
|
return current_num_runs_.load();
|
|
}
|
|
|
|
const std::vector<std::string>& InferenceSession::GetRegisteredProviderTypes() const {
|
|
return execution_providers_.GetIds();
|
|
}
|
|
|
|
// Returns the provider options map held by execution_providers_.
const ProviderOptionsMap& InferenceSession::GetAllProviderOptions() const {
  return execution_providers_.GetAllProviderOptions();
}
|
|
|
|
// Returns a read-only reference to this session's options.
const SessionOptions& InferenceSession::GetSessionOptions() const {
  return session_options_;
}
|
|
|
|
// Returns a read-only reference to this session's data transfer manager.
const DataTransferManager& InferenceSession::GetDataTransferManager() const {
  return data_transfer_mgr_;
}
|
|
|
|
// Compares the shape of a fed input against the shape expected for it.
//
// Ranks must match. Dimensions must match as well, except where the expected dimension is negative
// (a symbolic dimension), which accepts any value.
//
// @param input_name      Name of the input, used in error messages.
// @param input_shape     Shape of the value that was fed.
// @param expected_shape  Shape to check against.
// @return OK when compatible, otherwise INVALID_ARGUMENT describing the rank or dimension mismatches.
common::Status InferenceSession::CheckShapes(const std::string& input_name, const TensorShape& input_shape,
                                             const TensorShape& expected_shape) const {
  const auto actual_rank = input_shape.NumDimensions();
  const auto expected_rank = expected_shape.NumDimensions();
  if (actual_rank != expected_rank) {
    std::ostringstream rank_msg;
    rank_msg << "Invalid rank for input: " << input_name << " Got: " << actual_rank << " Expected: " << expected_rank
             << " Please fix either the inputs or the model.";
    return Status(ONNXRUNTIME, INVALID_ARGUMENT, rank_msg.str());
  }

  // Collect every mismatching dimension so that all of them are reported at once.
  std::vector<size_t> mismatched_dims;
  for (size_t dim = 0; dim < actual_rank; ++dim) {
    const bool is_symbolic = expected_shape[dim] < 0;
    if (!is_symbolic && input_shape[dim] != expected_shape[dim]) {
      mismatched_dims.push_back(dim);
    }
  }

  if (mismatched_dims.empty()) {
    return Status::OK();
  }

  std::ostringstream dims_msg;
  dims_msg << "Got invalid dimensions for input: " << input_name << " for the following indices\n";
  for (const size_t dim : mismatched_dims) {
    dims_msg << " index: " << dim << " Got: " << input_shape[dim] << " Expected: " << expected_shape[dim] << "\n";
  }
  dims_msg << " Please fix either the inputs or the model.";
  return Status(ONNXRUNTIME, INVALID_ARGUMENT, dims_msg.str());
}
|
|
|
|
// Checks that two data types are the same object.
//
// @param actual     Type of the value that was supplied.
// @param expected   Type that was expected.
// @param base_type  Label for the kind of container, used only in the error message.
// @return OK if `actual == expected`, otherwise INVALID_ARGUMENT naming both types.
static common::Status CheckTypes(MLDataType actual, MLDataType expected, const std::string& base_type) {
  if (actual != expected) {
    std::ostringstream mismatch_msg;
    mismatch_msg << "Unexpected input data type. Actual: ("
                 << base_type << "(" << DataTypeImpl::ToString(actual)
                 << ")) , expected: ("
                 << base_type << "(" << DataTypeImpl::ToString(expected)
                 << "))";

    return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, mismatch_msg.str());
  }

  return Status::OK();
}
|
|
|
|
// Validates the feeds of a run request against the entries of input_def_map_.
//
// For every feed: the name must be present in input_def_map_, and the value's kind and element type
// must match the expected type. Tensor and sparse tensor feeds additionally have their shape checked
// when the expected shape has at least one dimension.
//
// @param feed_names  Names of the inputs being fed, parallel to `feeds`.
// @param feeds       Values being fed.
// @return OK if every feed is acceptable, otherwise INVALID_ARGUMENT (or the CheckTypes/CheckShapes
//         failure) for the first offending feed.
common::Status InferenceSession::ValidateInputs(const std::vector<std::string>& feed_names,
                                                const std::vector<OrtValue>& feeds) const {
  if (feed_names.size() != feeds.size()) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Size mismatch: feed_names has ", feed_names.size(),
                           " elements, but feeds has ", feeds.size(), " elements.");
  }

  for (size_t i = 0; i < feeds.size(); ++i) {
    const auto& feed_name = feed_names[i];

    auto iter = input_def_map_.find(feed_name);
    if (input_def_map_.end() == iter) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid Feed Input Name:", feed_name);
    }

    auto expected_type = iter->second.ml_data_type;
    auto& input_ml_value = feeds[i];
    if (input_ml_value.IsTensor()) {
      // check for type
      if (!expected_type->IsTensorType()) {
        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input with name: ", feed_name,
                               " is not expected to be of type tensor.");
      }
      auto expected_element_type = expected_type->AsTensorType()->GetElementType();
      auto input_element_type = input_ml_value.Get<Tensor>().DataType();
      ORT_RETURN_IF_ERROR_SESSIONID_(CheckTypes(input_element_type, expected_element_type, "tensor"));

      // check for shape; an expected shape with no dimensions is not checked
      const auto& expected_shape = iter->second.tensor_shape;
      if (expected_shape.NumDimensions() > 0) {
        const auto& input_shape = input_ml_value.Get<Tensor>().Shape();
        ORT_RETURN_IF_ERROR_SESSIONID_(CheckShapes(feed_name, input_shape, expected_shape));
      }
    } else if (input_ml_value.IsSparseTensor()) {
      if (!expected_type->IsSparseTensorType()) {
        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input with name: ", feed_name,
                               " is not expected to be of type sparse tensor.");
      }
      auto expected_element_type = expected_type->AsSparseTensorType()->GetElementType();
      const SparseTensor& sparse_tensor = input_ml_value.Get<SparseTensor>();
      auto input_element_type = sparse_tensor.Values().DataType();
      ORT_RETURN_IF_ERROR_SESSIONID_(CheckTypes(input_element_type, expected_element_type, "sparse_tensor"));

      // Check shape; an expected shape with no dimensions is not checked
      const auto& expected_shape = iter->second.tensor_shape;
      if (expected_shape.NumDimensions() > 0) {
        const auto& input_shape = sparse_tensor.Shape();
        ORT_RETURN_IF_ERROR_SESSIONID_(CheckShapes(feed_name, input_shape, expected_shape));
      }
    } else if (input_ml_value.IsTensorSequence()) {
      if (!expected_type->IsTensorSequenceType()) {
        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input with name: ", feed_name,
                               " is not expected to be of type tensor sequence.");
      }
      // Only the element type is compared for sequences.
      auto expected_element_type = expected_type->AsSequenceTensorBase()->GetElementType();
      auto input_element_type = input_ml_value.Get<TensorSeq>().DataType();
      ORT_RETURN_IF_ERROR_SESSIONID_(CheckTypes(input_element_type, expected_element_type, "seq"));
    } else {
      // Any other kind of value: compare the value's type directly.
      auto input_type = input_ml_value.Type();
      ORT_RETURN_IF_ERROR_SESSIONID_(CheckTypes(input_type, expected_type, ""));
    }
  }

  return Status::OK();
}
|
|
|
|
// Validates the output-related arguments passed to Run().
//
// Checks performed, in order:
//   1. p_fetches must not be null.
//   2. output_names must not be empty.
//   3. If *p_fetches is non-empty, it must have exactly one entry per requested output name.
//   4. Every requested output name must be present in model_output_names_.
//
// Returns Status::OK() when all checks pass; otherwise an INVALID_ARGUMENT status describing
// the first check that failed.
common::Status InferenceSession::ValidateOutputs(const std::vector<std::string>& output_names,
                                                 const std::vector<OrtValue>* p_fetches) const {
  if (p_fetches == nullptr) {
    return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Output vector pointer is NULL");
  }

  if (output_names.empty()) {
    return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "At least one output should be requested.");
  }

  // An empty fetches vector is allowed; only a non-empty one has to match the number of names.
  if (!p_fetches->empty() && (output_names.size() != p_fetches->size())) {
    std::ostringstream ostr;
    // Separate the two counts with a space so they do not run together in the message.
    ostr << "Output vector incorrectly sized: output_names.size(): " << output_names.size()
         << " p_fetches->size(): " << p_fetches->size();
    return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, ostr.str());
  }

  for (const auto& name : output_names) {
    if (model_output_names_.find(name) == model_output_names_.end()) {
      return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Invalid Output Name:" + name);
    }
  }

  // TODO add more validation here like checking shape of the allocated buffers

  return common::Status::OK();
}
|
|
|
|
// Runs the model once.
//
// Parameters:
//   run_options           - per-run settings (tag, logging levels, terminate flag, ...).
//   feed_names / feeds    - parallel vectors of input names and values; checked by ValidateInputs().
//   output_names          - names of the outputs to fetch; checked by ValidateOutputs().
//   p_fetches             - receives the outputs; must not be null.
//   p_fetches_device_info - optional; when non-null, entry i is used as the target device for
//                           output i and it must contain at least output_names.size() entries.
//
// Returns the first failure status encountered (validation, OnRunStart, graph execution,
// OnRunEnd, or a caught exception converted to a Status), or Status::OK().
//
// Note that the validation failures inside the try block return immediately, i.e. before
// current_num_runs_ is incremented and without running the telemetry/profiling epilogue.
Status InferenceSession::Run(const RunOptions& run_options,
                             const std::vector<std::string>& feed_names, const std::vector<OrtValue>& feeds,
                             const std::vector<std::string>& output_names, std::vector<OrtValue>* p_fetches,
                             const std::vector<OrtDevice>* p_fetches_device_info) {
  // NOTE(review): tp is only assigned when profiling is enabled, yet it is also passed to
  // TimeDiffMicroSeconds() for telemetry below - confirm the default-constructed value is intended.
  TimePoint tp;
  if (session_profiler_.IsEnabled()) {
    tp = session_profiler_.StartTime();
  }

#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
  TraceLoggingActivity<telemetry_provider_handle> ortrun_activity;
  ortrun_activity.SetRelatedActivity(session_activity);
  TraceLoggingWriteStart(ortrun_activity, "OrtRun");
#endif
  Status retval = Status::OK();
  const Env& env = Env::Default();

  // Providers whose OnRunStart() succeeded; exactly these get OnRunEnd() called afterwards.
  std::vector<IExecutionProvider*> exec_providers_to_stop;
  exec_providers_to_stop.reserve(execution_providers_.NumProviders());

  ORT_TRY {
    if (!is_inited_) {
      LOGS(*session_logger_, ERROR) << "Session was not initialized";
      return Status(common::ONNXRUNTIME, common::FAIL, "Session not initialized.");
    }

    // log evaluation start to trace logging provider
    env.GetTelemetryProvider().LogEvaluationStart();

    ORT_RETURN_IF_ERROR_SESSIONID_(ValidateInputs(feed_names, feeds));
    ORT_RETURN_IF_ERROR_SESSIONID_(ValidateOutputs(output_names, p_fetches));

    // The device info vector is indexed once per requested output below, so reject a vector
    // that is too short instead of reading past its end.
    if (p_fetches_device_info && p_fetches_device_info->size() < output_names.size()) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Size mismatch: p_fetches_device_info has ",
                             p_fetches_device_info->size(), " elements, but output_names has ",
                             output_names.size(), " elements.");
    }

    FeedsFetchesInfo info(feed_names, output_names, session_state_->GetOrtValueNameIdxMap());
    FeedsFetchesManager feeds_fetches_manager{std::move(info)};

    if (p_fetches_device_info) {
      // populate the target device info. ignored if pre-allocated fetches are provided
      const auto& fetch_device_info = *p_fetches_device_info;
      auto& fetch_info = feeds_fetches_manager.GetMutableFetchesDeviceCopyInfo();

      for (size_t i = 0, end = output_names.size(); i < end; ++i) {
        fetch_info[i].target_device = fetch_device_info[i];
      }
    }

    if (!run_options.run_tag.empty()) {
      LOGS(*session_logger_, INFO) << "Running with tag: " << run_options.run_tag;
    }

    ++current_num_runs_;

    // scope of owned_run_logger is just the call to Execute.
    // If Execute ever becomes async we need a different approach
    std::unique_ptr<logging::Logger> owned_run_logger;
    auto run_logger = CreateLoggerForRun(run_options, owned_run_logger);

    // info all execution providers InferenceSession:Run started
    // TODO: only call OnRunStart for all providers in-use
    for (auto& xp : execution_providers_) {
      // call OnRunStart and add to exec_providers_to_stop if successful
      auto start_func = [&xp, &exec_providers_to_stop]() {
        auto status = xp->OnRunStart();
        if (status.IsOK())
          exec_providers_to_stop.push_back(xp.get());

        return status;
      };

      ORT_CHECK_AND_SET_RETVAL(start_func());
    }

#if !defined(ORT_MINIMAL_BUILD)
    if (run_options.only_execute_path_to_fetches) {
      session_state_->UpdateToBeExecutedNodes(feeds_fetches_manager.GetFeedsFetchesInfo().fetches_mlvalue_idxs);
    }
#endif

    // execute the graph
    ORT_CHECK_AND_SET_RETVAL(utils::ExecuteGraph(*session_state_, feeds_fetches_manager, feeds, *p_fetches,
                                                 session_options_.execution_mode, run_options.terminate, run_logger,
                                                 run_options.only_execute_path_to_fetches));
  }
  ORT_CATCH(const std::exception& e) {
    ORT_HANDLE_EXCEPTION([&]() {
      retval = Status(common::ONNXRUNTIME, common::FAIL, e.what());
    });
  }
  ORT_CATCH(...) {
    retval = Status(common::ONNXRUNTIME, common::RUNTIME_EXCEPTION, "Encountered unknown exception in Run()");
  }

  // info all execution providers InferenceSession:Run ended
  for (auto* xp : exec_providers_to_stop) {
    auto status = xp->OnRunEnd();
    ORT_CHECK_AND_SET_RETVAL(status);
  }

  --current_num_runs_;

  // keep track of telemetry
  ++telemetry_.total_runs_since_last_;
  telemetry_.total_run_duration_since_last_ += TimeDiffMicroSeconds(tp);

  // time to send telemetry?
  if (TimeDiffMicroSeconds(telemetry_.time_sent_last_) > telemetry_.kDurationBetweenSending) {
    // send the telemetry
    env.GetTelemetryProvider().LogRuntimePerf(session_id_, telemetry_.total_runs_since_last_,
                                              telemetry_.total_run_duration_since_last_);
    // reset counters
    telemetry_.time_sent_last_ = std::chrono::high_resolution_clock::now();
    telemetry_.total_runs_since_last_ = 0;
    telemetry_.total_run_duration_since_last_ = 0;
  }

  // log evaluation stop to trace logging provider
  env.GetTelemetryProvider().LogEvaluationStop();

  // send out profiling events (optional)
  if (session_profiler_.IsEnabled()) {
    session_profiler_.EndTimeAndRecordEvent(profiling::SESSION_EVENT, "model_run", tp);
  }
#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
  TraceLoggingWriteStop(ortrun_activity, "OrtRun");
#endif
  return retval;
}
|
|
|
|
// Convenience overload: runs with default-constructed RunOptions and forwards to the
// overload that accepts explicit run options together with a name->value feed map.
common::Status InferenceSession::Run(const NameMLValMap& feeds, const std::vector<std::string>& output_names,
                                     std::vector<OrtValue>* p_fetches) {
  RunOptions default_run_options;
  return Run(default_run_options, feeds, output_names, p_fetches);
}
|
|
|
|
// Splits a name->value feed map into parallel name/value vectors and forwards to the
// vector-based Run() overload. No fetch device information is supplied (nullptr).
common::Status InferenceSession::Run(const RunOptions& run_options, const NameMLValMap& feeds_map,
                                     const std::vector<std::string>& output_names, std::vector<OrtValue>* p_fetches) {
  const auto feed_count = feeds_map.size();

  std::vector<std::string> names;
  names.reserve(feed_count);

  std::vector<OrtValue> values;
  values.reserve(feed_count);

  // names[i] and values[i] come from the same map entry, so the two vectors stay aligned.
  for (auto it = feeds_map.begin(); it != feeds_map.end(); ++it) {
    names.push_back(it->first);
    values.push_back(it->second);
  }

  return Run(run_options, names, values, output_names, p_fetches, nullptr);
}
|
|
|
|
std::pair<common::Status, const ModelMetadata*> InferenceSession::GetModelMetadata() const {
|
|
{
|
|
std::lock_guard<onnxruntime::OrtMutex> l(session_mutex_);
|
|
if (!is_model_loaded_) {
|
|
LOGS(*session_logger_, ERROR) << "Model was not loaded";
|
|
return std::make_pair(common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."), nullptr);
|
|
}
|
|
}
|
|
|
|
return std::make_pair(common::Status::OK(), &model_metadata_);
|
|
}
|
|
|
|
std::pair<common::Status, const InputDefList*> InferenceSession::GetModelInputs() const {
|
|
{
|
|
std::lock_guard<onnxruntime::OrtMutex> l(session_mutex_);
|
|
if (!is_model_loaded_) {
|
|
LOGS(*session_logger_, ERROR) << "Model was not loaded";
|
|
return std::make_pair(common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."), nullptr);
|
|
}
|
|
}
|
|
|
|
// return required inputs (excludes any inputs used for overriding initializers)
|
|
return std::make_pair(common::Status::OK(), &model_->MainGraph().GetInputs());
|
|
}
|
|
|
|
std::pair<common::Status, const InputDefList*> InferenceSession::GetOverridableInitializers() const {
|
|
{
|
|
std::lock_guard<onnxruntime::OrtMutex> l(session_mutex_);
|
|
if (!is_model_loaded_) {
|
|
LOGS(*session_logger_, ERROR) << "Model was not loaded";
|
|
return std::make_pair(common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."), nullptr);
|
|
}
|
|
}
|
|
|
|
// returns a list of initializers that can be overriden.
|
|
return std::make_pair(common::Status::OK(), &model_->MainGraph().GetOverridableInitializers());
|
|
}
|
|
|
|
std::pair<common::Status, const OutputDefList*> InferenceSession::GetModelOutputs() const {
|
|
{
|
|
std::lock_guard<onnxruntime::OrtMutex> l(session_mutex_);
|
|
if (!is_model_loaded_) {
|
|
LOGS(*session_logger_, ERROR) << "Model was not loaded";
|
|
return std::make_pair(common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."), nullptr);
|
|
}
|
|
}
|
|
|
|
return std::make_pair(common::Status::OK(), &output_def_list_);
|
|
}
|
|
|
|
// Creates a new IOBinding bound to this session's state and stores it in *io_binding.
// Fails with a FAIL status if the session has not been initialized yet; the initialized
// flag is read while holding session_mutex_.
common::Status InferenceSession::NewIOBinding(std::unique_ptr<IOBinding>* io_binding) {
  {
    std::lock_guard<onnxruntime::OrtMutex> guard(session_mutex_);
    if (!is_inited_) {
      LOGS(*session_logger_, ERROR) << "Session was not initialized";
      return common::Status(common::ONNXRUNTIME, common::FAIL, "Session not initialized.");
    }
  }

  // private constructor, can't use make_unique
  io_binding->reset(new IOBinding(*session_state_));
  return Status::OK();
}
|
|
|
|
// Runs the model using the inputs, outputs and output device information held by io_binding,
// forwarding to the vector-based Run() overload.
common::Status InferenceSession::Run(const RunOptions& run_options, IOBinding& io_binding) {
  // TODO should Run() call io_binding.SynchronizeInputs() or should it let the callers do it?
  // io_binding.SynchronizeInputs();
  const auto& input_names = io_binding.GetInputNames();
  const auto& inputs = io_binding.GetInputs();
  const auto& output_names = io_binding.GetOutputNames();
  auto& outputs = io_binding.GetOutputs();
  const auto& outputs_device_info = io_binding.GetOutputsDeviceInfo();

  return Run(run_options, input_names, inputs, output_names, &outputs, &outputs_device_info);
}
|
|
|
|
// Convenience overload: runs the IOBinding-based Run() with default-constructed RunOptions.
common::Status InferenceSession::Run(IOBinding& io_binding) {
  return Run(RunOptions(), io_binding);
}
|
|
|
|
#ifdef ENABLE_TRAINING
|
|
// Starts Run() on a background thread and blocks until forward outputs are available.
//
// On return, run_id identifies the background task (derived from the hash of the background
// thread's id) and user_outputs holds whatever WaitForForwardOutputs() produced.
// If the forward status is not OK, the background thread is joined, its task is removed and
// that status is returned; otherwise the thread stays registered in bg_threads_ under run_id
// and Status::OK() is returned.
common::Status InferenceSession::RunInBackgroundAndWaitForYield(const RunOptions& run_options, IOBinding& io_binding,
                                                                std::vector<OrtValue>& user_outputs, int64_t& run_id) {
  using onnxruntime::contrib::OrtTasks;

  // The background thread blocks on this future until its task has been created below.
  std::promise<void> task_ready;
  std::future<void> task_ready_future = task_ready.get_future();

  // run_options and io_binding are handed to the background thread by reference.
  // This is ok because they are ORTModule's members, and they are persistent through forward and backward calls.
  auto bg_thread = std::thread(
      [this](std::future<void> ready, const RunOptions& bg_run_options, IOBinding& bg_io_binding) {
        // wait until the task is properly set up
        ready.get();

        common::Status status = Run(bg_run_options, bg_io_binding.GetInputNames(), bg_io_binding.GetInputs(),
                                    bg_io_binding.GetOutputNames(), &bg_io_binding.GetOutputs(),
                                    &bg_io_binding.GetOutputsDeviceInfo());

        OrtTasks::GetInstance().SetStatus(status);

        // If the forward outputs still haven't been consumed at this point, i.e. the forward function hasn't
        // completed, Run() returned before hitting YieldOp because of an exception during the forward subgraph
        // execution. In that case we need to wake up the foreground thread and pass along the failed status,
        // otherwise the foreground thread would be stuck waiting for forward_outputs.
        if (OrtTasks::GetInstance().ForwardOutputsIsValid()) {
          ORT_ENFORCE(!status.IsOK());
          // signal main thread for background thread completion
          OrtTasks::GetInstance().SetForwardOutputs(status, {});
        }
      },
      std::move(task_ready_future), std::cref(run_options), std::ref(io_binding));

  run_id = static_cast<int64_t>(std::hash<std::thread::id>()(bg_thread.get_id()));
  {
    std::lock_guard<std::mutex> lock(bg_threads_mutex_);
    bg_threads_[run_id] = std::move(bg_thread);
  }

  OrtTasks::GetInstance().CreateBackgroundTask(run_id);

  LOGS(*session_logger_, VERBOSE) << "InferenceSession::Forward() call created a task with run_id " << run_id;

  // background task is set up, unblock the background thread so it can continue
  task_ready.set_value();

  // Wait for data/signal from
  // 1. Yield op, if the bg thread successfully reached Yield's signal point
  // 2. The end of the bg thread, if it hit exceptions and returned earlier
  auto forward_outputs = OrtTasks::GetInstance().WaitForForwardOutputs(run_id);
  const Status& forward_status = forward_outputs.first;
  user_outputs = std::move(forward_outputs.second);

  if (forward_status.IsOK()) {
    return Status::OK();
  }

  // background thread has completed without hitting Yield Op: take ownership of it, join it
  // and drop its task before reporting the failure.
  std::thread thread;
  {
    std::lock_guard<std::mutex> lock(bg_threads_mutex_);
    auto entry = bg_threads_.find(run_id);
    if (entry != bg_threads_.end()) {
      thread = std::move(entry->second);
      bg_threads_.erase(entry);
    }
  }
  ORT_ENFORCE(thread.joinable());
  thread.join();
  OrtTasks::GetInstance().RemoveTask(run_id);
  return forward_status;
}
|
|
|
|
// Resumes the background task identified by run_id by handing it backward_output_grads,
// waits for the task's status, then joins the background thread and removes the task.
// Returns the status reported by the background task.
common::Status InferenceSession::ContinueRunInBackground(int64_t run_id, const std::vector<OrtValue>& backward_output_grads) {
  using onnxruntime::contrib::OrtTasks;

  LOGS(*session_logger_, VERBOSE) << "Running InferenceSession::Backward() with run_id " << run_id;

  // resume background thread (terminate flag is false)
  OrtTasks::GetInstance().SetBackwardInputs(run_id, backward_output_grads, false);

  Status bg_thread_status = OrtTasks::GetInstance().WaitForStatus(run_id);

  // take ownership of the thread registered under run_id and unregister it
  std::thread bg_thread;
  {
    std::lock_guard<std::mutex> lock(bg_threads_mutex_);
    auto entry = bg_threads_.find(run_id);
    if (entry != bg_threads_.end()) {
      bg_thread = std::move(entry->second);
      bg_threads_.erase(entry);
    }
  }

  // wait for bg_thread to complete
  ORT_ENFORCE(bg_thread.joinable());
  bg_thread.join();
  OrtTasks::GetInstance().RemoveTask(run_id);

  return bg_thread_status;
}
|
|
|
|
void InferenceSession::CancelBackgroundTask(int64_t run_id) {
|
|
LOGS(*session_logger_, WARNING) << "Canceling background task with run_id " << run_id;
|
|
|
|
// resume background thread with terminate = true
|
|
onnxruntime::contrib::OrtTasks::GetInstance().SetBackwardInputs(run_id, {}, true);
|
|
|
|
// wait for bg_thread to complete
|
|
std::thread bg_thread;
|
|
{
|
|
std::lock_guard<std::mutex> lock(bg_threads_mutex_);
|
|
std::swap(bg_thread, bg_threads_[run_id]);
|
|
bg_threads_.erase(run_id);
|
|
}
|
|
ORT_ENFORCE(bg_thread.joinable());
|
|
bg_thread.join();
|
|
onnxruntime::contrib::OrtTasks::GetInstance().RemoveTask(run_id);
|
|
}
|
|
#endif
|
|
|
|
template <typename T>
|
|
void InferenceSession::StartProfiling(const std::basic_string<T>& file_prefix) {
|
|
std::basic_ostringstream<T> ss;
|
|
ss << file_prefix << "_" << GetCurrentTimeString<T>() << ".json";
|
|
session_profiler_.StartProfiling(ss.str());
|
|
}
|
|
|
|
void InferenceSession::StartProfiling(const std::string& file_prefix) {
|
|
StartProfiling<char>(file_prefix);
|
|
}
|
|
|
|
#ifdef _WIN32
|
|
void InferenceSession::StartProfiling(const std::wstring& file_prefix) {
|
|
StartProfiling<PATH_CHAR_TYPE>(file_prefix);
|
|
}
|
|
#endif
|
|
|
|
// Starts profiling with the given logger by forwarding it to the session profiler.
void InferenceSession::StartProfiling(const logging::Logger* logger_ptr) {
  auto& profiler = session_profiler_;
  profiler.StartProfiling(logger_ptr);
}
|
|
|
|
std::string InferenceSession::EndProfiling() {
|
|
if (is_model_loaded_) {
|
|
if (session_profiler_.IsEnabled()) {
|
|
return session_profiler_.EndProfiling();
|
|
} else {
|
|
LOGS(*session_logger_, VERBOSE) << "Profiler is disabled.";
|
|
return std::string();
|
|
}
|
|
}
|
|
LOGS(*session_logger_, ERROR) << "Could not write a profile because no model was loaded.";
|
|
return std::string();
|
|
}
|
|
|
|
// Read-only access to the session's profiler.
const profiling::Profiler& InferenceSession::GetProfiling() const {
  const profiling::Profiler& profiler = session_profiler_;
  return profiler;
}
|
|
|
|
// Looks up the allocator for mem_info by delegating to the session state.
AllocatorPtr InferenceSession::GetAllocator(const OrtMemoryInfo& mem_info) const {
  const auto& state = *session_state_;
  return state.GetAllocator(mem_info);
}
|
|
|
|
#if !defined(ORT_MINIMAL_BUILD)
|
|
// assumes model has already been loaded before
|
|
common::Status InferenceSession::DoPostLoadProcessing(onnxruntime::Model& model) {
|
|
// TODO add other post load processing here
|
|
common::Status status = SaveModelMetadata(model);
|
|
return status;
|
|
}
|
|
#endif
|
|
|
|
// Caches information about the model on the session:
//   - model_metadata_      : producer, descriptions, domain, version, custom metadata, graph name
//   - required_inputs_     : names of the main graph's inputs
//   - input_def_map_       : per-input definition, data type and shape for every valid feed
//   - output_def_list_     : copy of the main graph's outputs
//   - model_output_names_  : names of those outputs
// Always returns Status::OK().
common::Status InferenceSession::SaveModelMetadata(const onnxruntime::Model& model) {
  VLOGS(*session_logger_, 1) << "Saving model metadata";
  const onnxruntime::Graph& main_graph = model.MainGraph();

  // save model metadata
  model_metadata_.producer_name = model.ProducerName();
  model_metadata_.description = model.DocString();
  model_metadata_.graph_description = model.GraphDocString();
  model_metadata_.domain = model.Domain();
  model_metadata_.version = model.ModelVersion();
  model_metadata_.custom_metadata_map = model.MetaData();
  model_metadata_.graph_name = main_graph.Name();

  required_inputs_.clear();
  for (const auto* required_input : main_graph.GetInputs()) {
    required_inputs_.insert(required_input->Name());
  }

  // Choose which inputs are valid feeds:
  //  - for IR 4 or higher it is optional to have a matching graph input for an initializer, and if one exists
  //    the initializer is explicitly overridable, so initializers are included.
  //  - for IR < 4 we don't allow overriding initializers so that they can be treated as constant; exclude them
  //    from the list of valid inputs by just using the GetInputs() list.
  const InputDefList& valid_inputs = main_graph.CanOverrideInitializer()
                                         ? main_graph.GetInputsIncludingInitializers()
                                         : main_graph.GetInputs();

  input_def_map_.clear();
  input_def_map_.reserve(valid_inputs.size());
  for (const auto* input_def : valid_inputs) {
    auto ml_type = utils::GetMLDataType(*input_def);
    auto shape_proto = input_def->Shape();
    // an input without shape information is recorded with a default-constructed TensorShape
    TensorShape shape =
        shape_proto ? utils::GetTensorShapeFromTensorShapeProto(*shape_proto) : TensorShape();
    input_def_map_.insert({input_def->Name(), InputDefMetaData(input_def, ml_type, std::move(shape))});
  }

  // save outputs
  const auto& graph_outputs = main_graph.GetOutputs();
  output_def_list_ = graph_outputs;  // A direct copy of outputs

  model_output_names_.clear();
  model_output_names_.reserve(graph_outputs.size());
  for (const auto* output_def : graph_outputs) {
    model_output_names_.insert(output_def->Name());
  }

  VLOGS(*session_logger_, 1) << "Done saving model metadata";
  return common::Status::OK();
}
|
|
|
|
// Create a Logger for a single execution if possible. Otherwise use the default logger.
|
|
// If a new logger is created, it will also be stored in new_run_logger,
|
|
// which must remain valid for the duration of the execution.
|
|
// If the default logger is used, new_run_logger will remain empty.
|
|
// The returned value should be used in the execution.
|
|
const logging::Logger& InferenceSession::CreateLoggerForRun(const RunOptions& run_options,
|
|
std::unique_ptr<logging::Logger>& new_run_logger) {
|
|
const logging::Logger* run_logger;
|
|
|
|
// create a per-run logger if we can
|
|
if (logging_manager_ != nullptr) {
|
|
std::string run_log_id{session_options_.session_logid};
|
|
|
|
if (!session_options_.session_logid.empty() && !run_options.run_tag.empty()) {
|
|
run_log_id += ":";
|
|
}
|
|
|
|
run_log_id += run_options.run_tag;
|
|
|
|
logging::Severity severity = logging::Severity::kWARNING;
|
|
if (run_options.run_log_severity_level == -1) {
|
|
severity = session_logger_->GetSeverity();
|
|
} else {
|
|
ORT_ENFORCE(run_options.run_log_severity_level >= 0 &&
|
|
run_options.run_log_severity_level <= static_cast<int>(logging::Severity::kFATAL),
|
|
"Invalid run log severity level. Not a valid onnxruntime::logging::Severity value: ",
|
|
run_options.run_log_severity_level);
|
|
severity = static_cast<logging::Severity>(run_options.run_log_severity_level);
|
|
}
|
|
|
|
new_run_logger = logging_manager_->CreateLogger(run_log_id, severity, false, run_options.run_log_verbosity_level);
|
|
|
|
run_logger = new_run_logger.get();
|
|
VLOGS(*run_logger, 1) << "Created logger for run with id of " << run_log_id;
|
|
} else {
|
|
// fallback to using default logger. this does NOT have any session or run specific id/tag in it
|
|
run_logger = session_logger_;
|
|
VLOGS(*run_logger, 1) << "Using default logger for run " << run_options.run_tag;
|
|
}
|
|
|
|
return *run_logger;
|
|
}
|
|
|
|
// Sets up session_logger_.
//
// When logging_manager is non-null, a session-specific logger is created from it (owned by
// owned_session_logger_) using the session's log id, severity and verbosity options.
// A session_log_severity_level of -1 means "use the default logger's severity"; any other
// value must be a valid logging::Severity or ORT_ENFORCE fails.
// When logging_manager is null, the process-wide default logger is used instead.
void InferenceSession::InitLogger(logging::LoggingManager* logging_manager) {
  // create logger for session, using provided logging manager if possible
  if (logging_manager != nullptr) {
    logging::Severity severity = logging::Severity::kWARNING;
    if (session_options_.session_log_severity_level == -1) {
      severity = logging::LoggingManager::DefaultLogger().GetSeverity();
    } else {
      ORT_ENFORCE(session_options_.session_log_severity_level >= 0 &&
                      session_options_.session_log_severity_level <= static_cast<int>(logging::Severity::kFATAL),
                  "Invalid session log severity level. Not a valid onnxruntime::logging::Severity value: ",
                  session_options_.session_log_severity_level);
      severity = static_cast<logging::Severity>(session_options_.session_log_severity_level);
    }

    // Create the logger from the manager that was just null-checked (the parameter), rather
    // than from the logging_manager_ member, which this function never validates.
    owned_session_logger_ = logging_manager->CreateLogger(session_options_.session_logid, severity, false,
                                                          session_options_.session_log_verbosity_level);
    session_logger_ = owned_session_logger_.get();
  } else {
    session_logger_ = &logging::LoggingManager::DefaultLogger();
  }
}
|
|
|
|
#if !defined(ORT_MINIMAL_BUILD)
|
|
|
|
// Registers all the predefined transformers with transformer manager
|
|
void InferenceSession::AddPredefinedTransformers(GraphTransformerManager& transformer_manager,
|
|
TransformerLevel graph_optimization_level,
|
|
const std::vector<std::string>& custom_list) {
|
|
auto add_transformers = [&](TransformerLevel level) {
|
|
// Generate and register transformers for level
|
|
auto transformers_to_register =
|
|
optimizer_utils::GenerateTransformers(level, session_options_,
|
|
*execution_providers_.Get(onnxruntime::kCpuExecutionProvider),
|
|
custom_list);
|
|
for (auto& entry : transformers_to_register) {
|
|
transformer_manager.Register(std::move(entry), level);
|
|
}
|
|
};
|
|
|
|
ORT_ENFORCE(graph_optimization_level <= TransformerLevel::MaxLevel,
|
|
"Exceeded max transformer level. Current level is set to " +
|
|
std::to_string(static_cast<uint32_t>(graph_optimization_level)));
|
|
|
|
for (int i = static_cast<int>(TransformerLevel::Level1); i <= static_cast<int>(TransformerLevel::MaxLevel); i++) {
|
|
TransformerLevel level = static_cast<TransformerLevel>(i);
|
|
if ((graph_optimization_level >= level) || !custom_list.empty()) {
|
|
add_transformers(level);
|
|
}
|
|
}
|
|
}
|
|
|
|
#endif // !defined(ORT_MINIMAL_BUILD)
|
|
|
|
// Blocks until p_executor_done is notified and then returns Status::OK().
// Timeouts are not supported: a positive timeout_in_ms triggers ORT_NOT_IMPLEMENTED.
common::Status InferenceSession::WaitForNotification(Notification* p_executor_done, int64_t timeout_in_ms) {
  const bool timeout_requested = timeout_in_ms > 0;
  if (timeout_requested) {
    ORT_NOT_IMPLEMENTED(__FUNCTION__, "timeout_in_ms >0 is not supported");  // TODO
  }

  p_executor_done->Wait();
  return Status::OK();
}
|
|
|
|
// Remembers the session and asks it for a new IOBinding, stored in binding_.
// Enforces that creating the binding succeeded.
SessionIOBinding::SessionIOBinding(InferenceSession* session)
    : sess_(session) {
  ORT_ENFORCE(session->NewIOBinding(&binding_).IsOK());
}
|
|
|
|
// Returns the session pointer this binding was constructed with.
InferenceSession* SessionIOBinding::GetInferenceSession() {
  InferenceSession* session = sess_;
  return session;
}
|
|
|
|
// Returns a non-owning pointer to the IOBinding held in binding_.
IOBinding* SessionIOBinding::Get() {
  IOBinding* binding = binding_.get();
  return binding;
}
|
|
|
|
} // namespace onnxruntime
|