From 3efd9a73bb783da7f52ab61c84c166c3a4bbdbed Mon Sep 17 00:00:00 2001 From: Edward Chen <18449977+edgchen1@users.noreply.github.com> Date: Wed, 3 Aug 2022 16:28:26 -0700 Subject: [PATCH] Refactor InferenceSession Load member functions. (#12430) Fix comparison of path characters when checking for ".ort" suffix. Some clean up of InferenceSession Load functions. - Reduce duplication between std::string/std::wstring versions. - Renaming for clarity. --- onnxruntime/core/common/path_string.h | 23 ++- .../core/flatbuffers/flatbuffers_utils.cc | 9 ++ .../core/flatbuffers/flatbuffers_utils.h | 15 +- onnxruntime/core/session/inference_session.cc | 131 +++++------------- onnxruntime/core/session/inference_session.h | 31 ++--- .../test/framework/cuda/fence_cuda_test.cc | 2 +- winml/adapter/winml_adapter_session.cpp | 4 +- 7 files changed, 83 insertions(+), 132 deletions(-) diff --git a/onnxruntime/core/common/path_string.h b/onnxruntime/core/common/path_string.h index 616e622887..bdef867c9d 100644 --- a/onnxruntime/core/common/path_string.h +++ b/onnxruntime/core/common/path_string.h @@ -6,6 +6,13 @@ #include #include +// for std::tolower or std::towlower +#ifdef _WIN32 +#include +#else +#include +#endif + #include "core/common/common.h" #include "core/session/onnxruntime_c_api.h" @@ -16,22 +23,30 @@ using PathChar = ORTCHAR_T; // string type for filesystem paths using PathString = std::basic_string; +inline PathString ToPathString(const PathString& s) { + return s; +} + #ifdef _WIN32 + static_assert(std::is_same::value, "PathString is not std::wstring!"); inline PathString ToPathString(const std::string& s) { return ToWideString(s); } -inline PathString ToPathString(const std::wstring& s) { - return s; +inline PathChar ToLowerPathChar(PathChar c) { + return std::towlower(c); } + #else + static_assert(std::is_same::value, "PathString is not std::string!"); -inline PathString ToPathString(const std::string& s) { - return s; +inline PathChar ToLowerPathChar(PathChar c) { + return std::tolower(c); } + #endif } // namespace onnxruntime diff --git a/onnxruntime/core/flatbuffers/flatbuffers_utils.cc b/onnxruntime/core/flatbuffers/flatbuffers_utils.cc index 8382c427f6..2d926daf32 100644 --- a/onnxruntime/core/flatbuffers/flatbuffers_utils.cc +++ b/onnxruntime/core/flatbuffers/flatbuffers_utils.cc @@ -302,6 +302,15 @@ Status LoadOpsetImportOrtFormat(const flatbuffers::Vector 4 && + filename[len - 4] == ORT_TSTR('.') && + ToLowerPathChar(filename[len - 3]) == ORT_TSTR('o') && + ToLowerPathChar(filename[len - 2]) == ORT_TSTR('r') && + ToLowerPathChar(filename[len - 1]) == ORT_TSTR('t'); +} + bool IsOrtFormatModelBytes(const void* bytes, int num_bytes) { return num_bytes > 8 && // check buffer is large enough to contain identifier so we don't read random memory fbs::InferenceSessionBufferHasIdentifier(bytes); diff --git a/onnxruntime/core/flatbuffers/flatbuffers_utils.h b/onnxruntime/core/flatbuffers/flatbuffers_utils.h index 8f8681131c..570cec7404 100644 --- a/onnxruntime/core/flatbuffers/flatbuffers_utils.h +++ b/onnxruntime/core/flatbuffers/flatbuffers_utils.h @@ -3,9 +3,10 @@ #pragma once -#include #include -#include + +#include "core/common/path_string.h" +#include "core/common/status.h" namespace ONNX_NAMESPACE { class ValueInfoProto; @@ -58,15 +59,7 @@ onnxruntime::common::Status LoadOpsetImportOrtFormat( std::unordered_map& domain_to_version); // check if filename ends in .ort -template -bool IsOrtFormatModel(const std::basic_string& filename) { - auto len = filename.size(); - return len > 4 && - filename[len - 4] == '.' && - std::tolower(filename[len - 3]) == 'o' && - std::tolower(filename[len - 2]) == 'r' && - std::tolower(filename[len - 1]) == 't'; -} +bool IsOrtFormatModel(const PathString& filename); // check if bytes has the flatbuffer ORT identifier bool IsOrtFormatModelBytes(const void* bytes, int num_bytes); diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 87b4e30a1e..2eced63620 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -386,8 +386,8 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, #if !defined(ORT_MINIMAL_BUILD) InferenceSession::InferenceSession(const SessionOptions& session_options, const Environment& session_env, - const std::string& model_uri) - : model_location_(ToWideString(model_uri)), + const PathString& model_uri) + : model_location_(model_uri), graph_transformation_mgr_(session_options.max_num_graph_transformation_steps), insert_cast_transformer_("CastFloat16Transformer"), logging_manager_(session_env.GetLoggingManager()), @@ -403,18 +403,8 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, const #ifdef _WIN32 InferenceSession::InferenceSession(const SessionOptions& session_options, const Environment& session_env, - const std::wstring& model_uri) - : graph_transformation_mgr_(session_options.max_num_graph_transformation_steps), - insert_cast_transformer_("CastFloat16Transformer"), - logging_manager_(session_env.GetLoggingManager()), - environment_(session_env) { - model_location_ = ToWideString(model_uri); - auto status = Model::Load(model_location_, model_proto_); - ORT_ENFORCE(status.IsOK(), "Given model could not be parsed while creating inference session. Error message: ", - status.ErrorMessage()); - is_model_proto_parsed_ = true; - // Finalize session options and initialize assets of this session instance - ConstructorCommon(session_options, session_env); + const std::string& model_uri) + : InferenceSession(session_options, session_env, ToPathString(model_uri)) { } #endif @@ -592,7 +582,7 @@ common::Status InferenceSession::RegisterGraphTransformer( return graph_transformation_mgr_.Register(std::move(p_graph_transformer), level); } -common::Status InferenceSession::SaveToOrtFormat(const std::basic_string& filepath) const { +common::Status InferenceSession::SaveToOrtFormat(const PathString& filepath) const { ORT_RETURN_IF_NOT(FLATBUFFERS_LITTLEENDIAN, "ort format only supports little-endian machines"); // Get the byte size of the ModelProto and round it to the next MB and use it as flatbuffers' init_size @@ -632,8 +622,8 @@ common::Status InferenceSession::SaveToOrtFormat(const std::basic_string&)> loader, - const std::string& event_name) { +common::Status InferenceSession::LoadWithLoader(std::function&)> loader, + const std::string& event_name) { Status status = Status::OK(); TimePoint tp; if (session_profiler_.IsEnabled()) { @@ -666,8 +656,9 @@ common::Status InferenceSession::Load(std::function -common::Status InferenceSession::Load(const std::basic_string& model_uri) { - model_location_ = ToWideString(model_uri); +common::Status InferenceSession::LoadOnnxModel(const PathString& model_uri) { + model_location_ = model_uri; auto loader = [this](std::shared_ptr& model) { #ifdef ENABLE_LANGUAGE_INTEROP_OPS LoadInterOp(model_location_, interop_domains_, [&](const char* msg) { LOGS(*session_logger_, WARNING) << msg; }); @@ -694,7 +684,7 @@ common::Status InferenceSession::Load(const std::basic_string& model_uri) { ModelOptions(true, strict_shape_type_inference)); }; - common::Status st = Load(loader, "model_loading_uri"); + common::Status st = LoadWithLoader(loader, "model_loading_uri"); if (!st.IsOK()) { std::ostringstream oss; oss << "Load model from " << ToUTF8String(model_uri) << " failed:" << st.ErrorMessage(); @@ -712,7 +702,7 @@ common::Status InferenceSession::FilterEnabledOptimizers(InlinedHashSet(model_uri); + return LoadOnnxModel(model_uri); #else return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "ONNX format model is not supported in this build."); #endif } #ifdef _WIN32 -common::Status InferenceSession::Load(const std::wstring& model_uri) { - std::string model_type = session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigLoadModelFormat, ""); - bool has_explicit_type = !model_type.empty(); - - if ((has_explicit_type && model_type == "ORT") || - (!has_explicit_type && fbs::utils::IsOrtFormatModel(model_uri))) { - return LoadOrtModel(model_uri); - } - -#if !defined(ORT_MINIMAL_BUILD) - if (is_model_proto_parsed_) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "ModelProto corresponding to the model to be loaded has already been parsed. " - "Invoke Load()."); - } - - return Load(model_uri); -#else - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "ONNX format model is not supported in this build."); -#endif +common::Status InferenceSession::Load(const std::string& model_uri) { + return Load(ToPathString(model_uri)); } #endif @@ -797,7 +769,7 @@ common::Status InferenceSession::Load(const void* model_data, int model_data_len ModelOptions(true, strict_shape_type_inference)); }; - return Load(loader, "model_loading_array"); + return LoadWithLoader(loader, "model_loading_array"); #else return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "ONNX format model is not supported in this build."); #endif @@ -805,7 +777,7 @@ common::Status InferenceSession::Load(const void* model_data, int model_data_len #if !defined(ORT_MINIMAL_BUILD) -common::Status InferenceSession::Load(const ModelProto& model_proto) { +common::Status InferenceSession::LoadOnnxModel(ModelProto model_proto) { if (is_model_proto_parsed_) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ModelProto corresponding to the model to be loaded has already been parsed. " @@ -821,37 +793,17 @@ common::Status InferenceSession::Load(const ModelProto& model_proto) { #endif const bool strict_shape_type_inference = session_options_.config_options.GetConfigOrDefault( kOrtSessionOptionsConfigStrictShapeTypeInference, "0") == "1"; - // This call will create a copy of model_proto and the constructed model instance will own the copy thereafter - return onnxruntime::Model::Load(model_proto, PathString(), model, + // This call will move model_proto to the constructed model instance + return onnxruntime::Model::Load(std::move(model_proto), PathString(), model, HasLocalSchema() ? &custom_schema_registries_ : nullptr, *session_logger_, ModelOptions(true, strict_shape_type_inference)); }; - return Load(loader, "model_loading_proto"); + return LoadWithLoader(loader, "model_loading_proto"); } -common::Status InferenceSession::Load(std::unique_ptr p_model_proto) { - if (is_model_proto_parsed_) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "ModelProto corresponding to the model to be loaded has already been parsed. " - "Invoke Load()."); - } - - auto loader = [this, &p_model_proto](std::shared_ptr& model) { -#ifdef ENABLE_LANGUAGE_INTEROP_OPS - LoadInterOp(*p_model_proto, interop_domains_, [&](const char* msg) { LOGS(*session_logger_, WARNING) << msg; }); - for (const auto& domain : interop_domains_) { - ORT_RETURN_IF_ERROR(AddCustomOpDomains({domain.get()})); - } -#endif - const bool strict_shape_type_inference = session_options_.config_options.GetConfigOrDefault( - kOrtSessionOptionsConfigStrictShapeTypeInference, "0") == "1"; - return onnxruntime::Model::Load(std::move(*p_model_proto), PathString(), model, - HasLocalSchema() ? &custom_schema_registries_ : nullptr, *session_logger_, - ModelOptions(true, strict_shape_type_inference)); - }; - - return Load(loader, "model_loading_proto"); +common::Status InferenceSession::LoadOnnxModel(std::unique_ptr p_model_proto) { + return LoadOnnxModel(std::move(*p_model_proto)); } common::Status InferenceSession::Load(std::istream& model_istream, bool allow_released_opsets_only) { @@ -882,7 +834,7 @@ common::Status InferenceSession::Load(std::istream& model_istream, bool allow_re *session_logger_, model_opts); }; - return Load(loader, "model_loading_istream"); + return LoadWithLoader(loader, "model_loading_istream"); } common::Status InferenceSession::Load() { @@ -907,7 +859,7 @@ common::Status InferenceSession::Load() { ModelOptions(true, strict_shape_type_inference)); }; - return Load(loader, "model_loading_from_saved_proto"); + return LoadWithLoader(loader, "model_loading_from_saved_proto"); } common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, @@ -986,14 +938,11 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, } #endif // !defined(ORT_MINIMAL_BUILD) -template -static Status LoadOrtModelBytes(const std::basic_string& model_uri, - std::basic_string& model_location, +static Status LoadOrtModelBytes(const PathString& model_uri, gsl::span& bytes, std::vector& bytes_data_holder) { size_t num_bytes = 0; - model_location = ToWideString(model_uri); - ORT_RETURN_IF_ERROR(Env::Default().GetFileLength(model_location.c_str(), num_bytes)); + ORT_RETURN_IF_ERROR(Env::Default().GetFileLength(model_uri.c_str(), num_bytes)); bytes_data_holder.resize(num_bytes); @@ -1011,30 +960,18 @@ static Status LoadOrtModelBytes(const std::basic_string& model_uri, return Status::OK(); } -Status InferenceSession::LoadOrtModel(const std::string& model_uri) { - return LoadOrtModel( +Status InferenceSession::LoadOrtModel(const PathString& model_uri) { + return LoadOrtModelWithLoader( [&]() { + model_location_ = model_uri; ORT_RETURN_IF_ERROR( - LoadOrtModelBytes(model_uri, model_location_, - ort_format_model_bytes_, ort_format_model_bytes_data_holder_)); + LoadOrtModelBytes(model_location_, ort_format_model_bytes_, ort_format_model_bytes_data_holder_)); return Status::OK(); }); } -#ifdef WIN32 -Status InferenceSession::LoadOrtModel(const std::wstring& model_uri) { - return LoadOrtModel( - [&]() { - ORT_RETURN_IF_ERROR( - LoadOrtModelBytes(model_uri, model_location_, - ort_format_model_bytes_, ort_format_model_bytes_data_holder_)); - return Status::OK(); - }); -} -#endif - Status InferenceSession::LoadOrtModel(const void* model_data, int model_data_len) { - return LoadOrtModel([&]() { + return LoadOrtModelWithLoader([&]() { const auto use_ort_model_bytes_directly = GetSessionOptions().config_options.GetConfigOrDefault(kOrtSessionOptionsConfigUseORTModelBytesDirectly, "0"); if (use_ort_model_bytes_directly != "1") { @@ -1052,7 +989,7 @@ Status InferenceSession::LoadOrtModel(const void* model_data, int model_data_len }); } -Status InferenceSession::LoadOrtModel(std::function load_ort_format_model_bytes) { +Status InferenceSession::LoadOrtModelWithLoader(std::function load_ort_format_model_bytes) { static_assert(FLATBUFFERS_LITTLEENDIAN, "ORT format only supports little-endian machines"); std::lock_guard l(session_mutex_); diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index a41e1c38fd..6676b521a3 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -9,6 +9,7 @@ #include "core/common/common.h" #include "core/common/inlined_containers.h" #include "core/common/logging/logging.h" +#include "core/common/path_string.h" #include "core/common/profiler.h" #include "core/common/status.h" #include "core/framework/execution_providers.h" @@ -154,11 +155,11 @@ class InferenceSession { */ InferenceSession(const SessionOptions& session_options, const Environment& session_env, - const std::string& model_uri); + const PathString& model_uri); #ifdef _WIN32 InferenceSession(const SessionOptions& session_options, const Environment& session_env, - const std::wstring& model_uri); + const std::string& model_uri); #endif /** @@ -256,9 +257,9 @@ class InferenceSession { * @param model_uri absolute path of the model file. * @return OK if success. */ - common::Status Load(const std::string& model_uri) ORT_MUST_USE_RESULT; + common::Status Load(const PathString& model_uri) ORT_MUST_USE_RESULT; #ifdef _WIN32 - common::Status Load(const std::wstring& model_uri) ORT_MUST_USE_RESULT; + common::Status Load(const std::string& model_uri) ORT_MUST_USE_RESULT; #endif /** * Load an ONNX or ORT format model. @@ -482,17 +483,17 @@ class InferenceSession { * @param protobuf object corresponding to the model file. model_proto will be copied by the API. * @return OK if success. */ - common::Status Load(const ONNX_NAMESPACE::ModelProto& model_proto) ORT_MUST_USE_RESULT; + common::Status LoadOnnxModel(ONNX_NAMESPACE::ModelProto model_proto) ORT_MUST_USE_RESULT; /** * Load an ONNX model. * @param protobuf object corresponding to the model file. This is primarily to support large models. * @return OK if success. */ - common::Status Load(std::unique_ptr p_model_proto) ORT_MUST_USE_RESULT; + common::Status LoadOnnxModel(std::unique_ptr p_model_proto) ORT_MUST_USE_RESULT; - common::Status Load(std::function&)> loader, - const std::string& event_name) ORT_MUST_USE_RESULT; + common::Status LoadWithLoader(std::function&)> loader, + const std::string& event_name) ORT_MUST_USE_RESULT; common::Status DoPostLoadProcessing(onnxruntime::Model& model) ORT_MUST_USE_RESULT; @@ -541,7 +542,7 @@ class InferenceSession { std::unordered_set model_output_names_; // The file path of where the model was loaded. e.g. /tmp/test_squeezenet/model.onnx - std::basic_string model_location_; + PathString model_location_; // The list of execution providers. ExecutionProviders execution_providers_; @@ -556,14 +557,13 @@ class InferenceSession { #if !defined(ORT_MINIMAL_BUILD) - template - common::Status Load(const std::basic_string& model_uri) ORT_MUST_USE_RESULT; + common::Status LoadOnnxModel(const PathString& model_uri) ORT_MUST_USE_RESULT; bool HasLocalSchema() const { return !custom_schema_registries_.empty(); } - common::Status SaveToOrtFormat(const std::basic_string& filepath) const; + common::Status SaveToOrtFormat(const PathString& filepath) const; #endif /** @@ -571,10 +571,7 @@ class InferenceSession { * @param model_uri absolute path of the model file. * @return OK if success. */ - common::Status LoadOrtModel(const std::string& model_uri) ORT_MUST_USE_RESULT; -#ifdef _WIN32 - common::Status LoadOrtModel(const std::wstring& model_uri) ORT_MUST_USE_RESULT; -#endif + common::Status LoadOrtModel(const PathString& model_uri) ORT_MUST_USE_RESULT; /** * Load an ORT format model. @@ -586,7 +583,7 @@ class InferenceSession { */ common::Status LoadOrtModel(const void* model_data, int model_data_len) ORT_MUST_USE_RESULT; - common::Status LoadOrtModel(std::function load_ort_format_model_bytes) ORT_MUST_USE_RESULT; + common::Status LoadOrtModelWithLoader(std::function load_ort_format_model_bytes) ORT_MUST_USE_RESULT; // Create a Logger for a single execution if possible. Otherwise use the default logger. // If a new logger is created, it will also be stored in new_run_logger, diff --git a/onnxruntime/test/framework/cuda/fence_cuda_test.cc b/onnxruntime/test/framework/cuda/fence_cuda_test.cc index 9f3bd945e2..390dd339a1 100644 --- a/onnxruntime/test/framework/cuda/fence_cuda_test.cc +++ b/onnxruntime/test/framework/cuda/fence_cuda_test.cc @@ -42,7 +42,7 @@ class FenceCudaTestInferenceSession : public InferenceSession { FenceCudaTestInferenceSession(const SessionOptions& so, const Environment& env) : InferenceSession(so, env) {} Status LoadModel(onnxruntime::Model& model) { auto model_proto = model.ToProto(); - auto st = Load(model_proto); + auto st = LoadOnnxModel(std::move(model_proto)); return st; } }; diff --git a/winml/adapter/winml_adapter_session.cpp b/winml/adapter/winml_adapter_session.cpp index 94238b2659..92aeddb633 100644 --- a/winml/adapter/winml_adapter_session.cpp +++ b/winml/adapter/winml_adapter_session.cpp @@ -30,7 +30,7 @@ class InferenceSessionProtectedLoadAccessor : public onnxruntime::InferenceSessi public: onnxruntime::common::Status Load(std::unique_ptr p_model_proto) { - return onnxruntime::InferenceSession::Load(std::move(p_model_proto)); + return onnxruntime::InferenceSession::LoadOnnxModel(std::move(p_model_proto)); } const onnxruntime::SessionState& GetSessionState() { return onnxruntime::InferenceSession::GetSessionState(); @@ -296,4 +296,4 @@ ORT_API_STATUS_IMPL(winmla::SessionGetNamedDimensionsOverrides, _In_ OrtSession* named_dimension_overrides = override_map.GetView(); return nullptr; API_IMPL_END -} \ No newline at end of file +}