Refactor InferenceSession Load member functions. (#12430)

Fix comparison of path characters when checking for ".ort" suffix.

Some clean up of InferenceSession Load functions.
- Reduce duplication between std::string/std::wstring versions.
- Renaming for clarity.
This commit is contained in:
Edward Chen 2022-08-03 16:28:26 -07:00 committed by GitHub
parent 97268e023c
commit 3efd9a73bb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 83 additions and 132 deletions

View file

@ -6,6 +6,13 @@
#include <string>
#include <type_traits>
// for std::tolower or std::towlower
#ifdef _WIN32
#include <cwctype>
#else
#include <cctype>
#endif
#include "core/common/common.h"
#include "core/session/onnxruntime_c_api.h"
@ -16,22 +23,30 @@ using PathChar = ORTCHAR_T;
// string type for filesystem paths
using PathString = std::basic_string<PathChar>;
// Identity overload: the argument is already a PathString, so a copy of it is returned unchanged.
inline PathString ToPathString(const PathString& s) {
return s;
}
#ifdef _WIN32
static_assert(std::is_same<PathString, std::wstring>::value, "PathString is not std::wstring!");
// Windows-only overload: turns a narrow std::string into a PathString by delegating to ToWideString.
// NOTE(review): the input encoding ToWideString expects is not visible here -- confirm before relying on it.
inline PathString ToPathString(const std::string& s) {
return ToWideString(s);
}
inline PathString ToPathString(const std::wstring& s) {
return s;
// Lower-cases one path character on Windows, where PathChar is the wide character type.
// std::towlower hands back a wint_t, so the result is converted back to PathChar explicitly.
inline PathChar ToLowerPathChar(PathChar c) {
  const auto lowered = std::towlower(c);
  return static_cast<PathChar>(lowered);
}
#else
static_assert(std::is_same<PathString, std::string>::value, "PathString is not std::string!");
inline PathString ToPathString(const std::string& s) {
return s;
// Lower-cases one path character on non-Windows platforms, where PathChar is char.
// std::tolower takes an int whose value must be representable as unsigned char (or equal EOF);
// handing it a negative plain char -- any byte >= 0x80 where char is signed -- is undefined
// behavior, so the character is routed through unsigned char before the call and the int
// result is narrowed back to PathChar explicitly.
inline PathChar ToLowerPathChar(PathChar c) {
  return static_cast<PathChar>(std::tolower(static_cast<unsigned char>(c)));
}
#endif
} // namespace onnxruntime

View file

@ -302,6 +302,15 @@ Status LoadOpsetImportOrtFormat(const flatbuffers::Vector<flatbuffers::Offset<fb
return Status::OK();
}
// Reports whether `filename` ends in ".ort". The dot is matched exactly and each of the three
// letters is compared after ToLowerPathChar; names of four characters or fewer never match.
bool IsOrtFormatModel(const PathString& filename) {
  const auto n = filename.size();
  if (n <= 4) {
    return false;
  }
  if (filename[n - 4] != ORT_TSTR('.')) {
    return false;
  }
  if (ToLowerPathChar(filename[n - 3]) != ORT_TSTR('o')) {
    return false;
  }
  if (ToLowerPathChar(filename[n - 2]) != ORT_TSTR('r')) {
    return false;
  }
  return ToLowerPathChar(filename[n - 1]) == ORT_TSTR('t');
}
bool IsOrtFormatModelBytes(const void* bytes, int num_bytes) {
return num_bytes > 8 && // check buffer is large enough to contain identifier so we don't read random memory
fbs::InferenceSessionBufferHasIdentifier(bytes);

View file

@ -3,9 +3,10 @@
#pragma once
#include <cctype>
#include <unordered_map>
#include <core/common/status.h>
#include "core/common/path_string.h"
#include "core/common/status.h"
namespace ONNX_NAMESPACE {
class ValueInfoProto;
@ -58,15 +59,7 @@ onnxruntime::common::Status LoadOpsetImportOrtFormat(
std::unordered_map<std::string, int>& domain_to_version);
// check if filename ends in .ort
// Case-insensitive ".ort" suffix check for any character type T.
// Letters are folded with an ASCII-only helper rather than std::tolower: std::tolower is only
// defined for values representable as unsigned char (or EOF), so feeding it wide characters or
// negative plain chars is undefined behavior.
// Returns true only when the name is longer than four characters and its last four are ".ort"
// (letters in either case).
template <typename T>
bool IsOrtFormatModel(const std::basic_string<T>& filename) {
  // Maps 'A'..'Z' onto 'a'..'z' and leaves every other value untouched.
  const auto to_lower_ascii = [](T c) -> T {
    return (c >= T('A') && c <= T('Z')) ? static_cast<T>(c + (T('a') - T('A'))) : c;
  };

  const auto len = filename.size();
  return len > 4 &&
         filename[len - 4] == T('.') &&
         to_lower_ascii(filename[len - 3]) == T('o') &&
         to_lower_ascii(filename[len - 2]) == T('r') &&
         to_lower_ascii(filename[len - 1]) == T('t');
}
bool IsOrtFormatModel(const PathString& filename);
// check if bytes has the flatbuffer ORT identifier
bool IsOrtFormatModelBytes(const void* bytes, int num_bytes);

View file

@ -386,8 +386,8 @@ InferenceSession::InferenceSession(const SessionOptions& session_options,
#if !defined(ORT_MINIMAL_BUILD)
InferenceSession::InferenceSession(const SessionOptions& session_options, const Environment& session_env,
const std::string& model_uri)
: model_location_(ToWideString(model_uri)),
const PathString& model_uri)
: model_location_(model_uri),
graph_transformation_mgr_(session_options.max_num_graph_transformation_steps),
insert_cast_transformer_("CastFloat16Transformer"),
logging_manager_(session_env.GetLoggingManager()),
@ -403,18 +403,8 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, const
#ifdef _WIN32
InferenceSession::InferenceSession(const SessionOptions& session_options,
const Environment& session_env,
const std::wstring& model_uri)
: graph_transformation_mgr_(session_options.max_num_graph_transformation_steps),
insert_cast_transformer_("CastFloat16Transformer"),
logging_manager_(session_env.GetLoggingManager()),
environment_(session_env) {
model_location_ = ToWideString(model_uri);
auto status = Model::Load(model_location_, model_proto_);
ORT_ENFORCE(status.IsOK(), "Given model could not be parsed while creating inference session. Error message: ",
status.ErrorMessage());
is_model_proto_parsed_ = true;
// Finalize session options and initialize assets of this session instance
ConstructorCommon(session_options, session_env);
const std::string& model_uri)
: InferenceSession(session_options, session_env, ToPathString(model_uri)) {
}
#endif
@ -592,7 +582,7 @@ common::Status InferenceSession::RegisterGraphTransformer(
return graph_transformation_mgr_.Register(std::move(p_graph_transformer), level);
}
common::Status InferenceSession::SaveToOrtFormat(const std::basic_string<ORTCHAR_T>& filepath) const {
common::Status InferenceSession::SaveToOrtFormat(const PathString& filepath) const {
ORT_RETURN_IF_NOT(FLATBUFFERS_LITTLEENDIAN, "ort format only supports little-endian machines");
// Get the byte size of the ModelProto and round it to the next MB and use it as flatbuffers' init_size
@ -632,8 +622,8 @@ common::Status InferenceSession::SaveToOrtFormat(const std::basic_string<ORTCHAR
return Status::OK();
}
common::Status InferenceSession::Load(std::function<common::Status(std::shared_ptr<Model>&)> loader,
const std::string& event_name) {
common::Status InferenceSession::LoadWithLoader(std::function<common::Status(std::shared_ptr<Model>&)> loader,
const std::string& event_name) {
Status status = Status::OK();
TimePoint tp;
if (session_profiler_.IsEnabled()) {
@ -666,8 +656,9 @@ common::Status InferenceSession::Load(std::function<common::Status(std::shared_p
});
}
ORT_CATCH(...) {
LOGS(*session_logger_, ERROR) << "Unknown exception in Load()";
status = Status(common::ONNXRUNTIME, common::RUNTIME_EXCEPTION, "Encountered unknown exception in Load()");
LOGS(*session_logger_, ERROR) << "Unknown exception";
status = Status(common::ONNXRUNTIME, common::RUNTIME_EXCEPTION,
"Encountered unknown exception in LoadWithLoader()");
}
if (session_profiler_.IsEnabled()) {
@ -677,9 +668,8 @@ common::Status InferenceSession::Load(std::function<common::Status(std::shared_p
return status;
}
template <typename T>
common::Status InferenceSession::Load(const std::basic_string<T>& model_uri) {
model_location_ = ToWideString(model_uri);
common::Status InferenceSession::LoadOnnxModel(const PathString& model_uri) {
model_location_ = model_uri;
auto loader = [this](std::shared_ptr<onnxruntime::Model>& model) {
#ifdef ENABLE_LANGUAGE_INTEROP_OPS
LoadInterOp(model_location_, interop_domains_, [&](const char* msg) { LOGS(*session_logger_, WARNING) << msg; });
@ -694,7 +684,7 @@ common::Status InferenceSession::Load(const std::basic_string<T>& model_uri) {
ModelOptions(true, strict_shape_type_inference));
};
common::Status st = Load(loader, "model_loading_uri");
common::Status st = LoadWithLoader(loader, "model_loading_uri");
if (!st.IsOK()) {
std::ostringstream oss;
oss << "Load model from " << ToUTF8String(model_uri) << " failed:" << st.ErrorMessage();
@ -712,7 +702,7 @@ common::Status InferenceSession::FilterEnabledOptimizers(InlinedHashSet<std::str
}
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
common::Status InferenceSession::Load(const std::string& model_uri) {
common::Status InferenceSession::Load(const PathString& model_uri) {
std::string model_type = session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigLoadModelFormat, "");
bool has_explicit_type = !model_type.empty();
@ -728,33 +718,15 @@ common::Status InferenceSession::Load(const std::string& model_uri) {
"Invoke Load().");
}
return Load<char>(model_uri);
return LoadOnnxModel(model_uri);
#else
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "ONNX format model is not supported in this build.");
#endif
}
#ifdef _WIN32
common::Status InferenceSession::Load(const std::wstring& model_uri) {
std::string model_type = session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigLoadModelFormat, "");
bool has_explicit_type = !model_type.empty();
if ((has_explicit_type && model_type == "ORT") ||
(!has_explicit_type && fbs::utils::IsOrtFormatModel(model_uri))) {
return LoadOrtModel(model_uri);
}
#if !defined(ORT_MINIMAL_BUILD)
if (is_model_proto_parsed_) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"ModelProto corresponding to the model to be loaded has already been parsed. "
"Invoke Load().");
}
return Load<PATH_CHAR_TYPE>(model_uri);
#else
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "ONNX format model is not supported in this build.");
#endif
common::Status InferenceSession::Load(const std::string& model_uri) {
return Load(ToPathString(model_uri));
}
#endif
@ -797,7 +769,7 @@ common::Status InferenceSession::Load(const void* model_data, int model_data_len
ModelOptions(true, strict_shape_type_inference));
};
return Load(loader, "model_loading_array");
return LoadWithLoader(loader, "model_loading_array");
#else
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "ONNX format model is not supported in this build.");
#endif
@ -805,7 +777,7 @@ common::Status InferenceSession::Load(const void* model_data, int model_data_len
#if !defined(ORT_MINIMAL_BUILD)
common::Status InferenceSession::Load(const ModelProto& model_proto) {
common::Status InferenceSession::LoadOnnxModel(ModelProto model_proto) {
if (is_model_proto_parsed_) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"ModelProto corresponding to the model to be loaded has already been parsed. "
@ -821,37 +793,17 @@ common::Status InferenceSession::Load(const ModelProto& model_proto) {
#endif
const bool strict_shape_type_inference = session_options_.config_options.GetConfigOrDefault(
kOrtSessionOptionsConfigStrictShapeTypeInference, "0") == "1";
// This call will create a copy of model_proto and the constructed model instance will own the copy thereafter
return onnxruntime::Model::Load(model_proto, PathString(), model,
// This call will move model_proto to the constructed model instance
return onnxruntime::Model::Load(std::move(model_proto), PathString(), model,
HasLocalSchema() ? &custom_schema_registries_ : nullptr, *session_logger_,
ModelOptions(true, strict_shape_type_inference));
};
return Load(loader, "model_loading_proto");
return LoadWithLoader(loader, "model_loading_proto");
}
common::Status InferenceSession::Load(std::unique_ptr<ModelProto> p_model_proto) {
if (is_model_proto_parsed_) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"ModelProto corresponding to the model to be loaded has already been parsed. "
"Invoke Load().");
}
auto loader = [this, &p_model_proto](std::shared_ptr<onnxruntime::Model>& model) {
#ifdef ENABLE_LANGUAGE_INTEROP_OPS
LoadInterOp(*p_model_proto, interop_domains_, [&](const char* msg) { LOGS(*session_logger_, WARNING) << msg; });
for (const auto& domain : interop_domains_) {
ORT_RETURN_IF_ERROR(AddCustomOpDomains({domain.get()}));
}
#endif
const bool strict_shape_type_inference = session_options_.config_options.GetConfigOrDefault(
kOrtSessionOptionsConfigStrictShapeTypeInference, "0") == "1";
return onnxruntime::Model::Load(std::move(*p_model_proto), PathString(), model,
HasLocalSchema() ? &custom_schema_registries_ : nullptr, *session_logger_,
ModelOptions(true, strict_shape_type_inference));
};
return Load(loader, "model_loading_proto");
common::Status InferenceSession::LoadOnnxModel(std::unique_ptr<ModelProto> p_model_proto) {
return LoadOnnxModel(std::move(*p_model_proto));
}
common::Status InferenceSession::Load(std::istream& model_istream, bool allow_released_opsets_only) {
@ -882,7 +834,7 @@ common::Status InferenceSession::Load(std::istream& model_istream, bool allow_re
*session_logger_, model_opts);
};
return Load(loader, "model_loading_istream");
return LoadWithLoader(loader, "model_loading_istream");
}
common::Status InferenceSession::Load() {
@ -907,7 +859,7 @@ common::Status InferenceSession::Load() {
ModelOptions(true, strict_shape_type_inference));
};
return Load(loader, "model_loading_from_saved_proto");
return LoadWithLoader(loader, "model_loading_from_saved_proto");
}
common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph,
@ -986,14 +938,11 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph,
}
#endif // !defined(ORT_MINIMAL_BUILD)
template <typename T>
static Status LoadOrtModelBytes(const std::basic_string<T>& model_uri,
std::basic_string<ORTCHAR_T>& model_location,
static Status LoadOrtModelBytes(const PathString& model_uri,
gsl::span<const uint8_t>& bytes,
std::vector<uint8_t>& bytes_data_holder) {
size_t num_bytes = 0;
model_location = ToWideString(model_uri);
ORT_RETURN_IF_ERROR(Env::Default().GetFileLength(model_location.c_str(), num_bytes));
ORT_RETURN_IF_ERROR(Env::Default().GetFileLength(model_uri.c_str(), num_bytes));
bytes_data_holder.resize(num_bytes);
@ -1011,30 +960,18 @@ static Status LoadOrtModelBytes(const std::basic_string<T>& model_uri,
return Status::OK();
}
Status InferenceSession::LoadOrtModel(const std::string& model_uri) {
return LoadOrtModel(
Status InferenceSession::LoadOrtModel(const PathString& model_uri) {
return LoadOrtModelWithLoader(
[&]() {
model_location_ = model_uri;
ORT_RETURN_IF_ERROR(
LoadOrtModelBytes(model_uri, model_location_,
ort_format_model_bytes_, ort_format_model_bytes_data_holder_));
LoadOrtModelBytes(model_location_, ort_format_model_bytes_, ort_format_model_bytes_data_holder_));
return Status::OK();
});
}
#ifdef WIN32
Status InferenceSession::LoadOrtModel(const std::wstring& model_uri) {
return LoadOrtModel(
[&]() {
ORT_RETURN_IF_ERROR(
LoadOrtModelBytes(model_uri, model_location_,
ort_format_model_bytes_, ort_format_model_bytes_data_holder_));
return Status::OK();
});
}
#endif
Status InferenceSession::LoadOrtModel(const void* model_data, int model_data_len) {
return LoadOrtModel([&]() {
return LoadOrtModelWithLoader([&]() {
const auto use_ort_model_bytes_directly =
GetSessionOptions().config_options.GetConfigOrDefault(kOrtSessionOptionsConfigUseORTModelBytesDirectly, "0");
if (use_ort_model_bytes_directly != "1") {
@ -1052,7 +989,7 @@ Status InferenceSession::LoadOrtModel(const void* model_data, int model_data_len
});
}
Status InferenceSession::LoadOrtModel(std::function<Status()> load_ort_format_model_bytes) {
Status InferenceSession::LoadOrtModelWithLoader(std::function<Status()> load_ort_format_model_bytes) {
static_assert(FLATBUFFERS_LITTLEENDIAN, "ORT format only supports little-endian machines");
std::lock_guard<onnxruntime::OrtMutex> l(session_mutex_);

View file

@ -9,6 +9,7 @@
#include "core/common/common.h"
#include "core/common/inlined_containers.h"
#include "core/common/logging/logging.h"
#include "core/common/path_string.h"
#include "core/common/profiler.h"
#include "core/common/status.h"
#include "core/framework/execution_providers.h"
@ -154,11 +155,11 @@ class InferenceSession {
*/
InferenceSession(const SessionOptions& session_options,
const Environment& session_env,
const std::string& model_uri);
const PathString& model_uri);
#ifdef _WIN32
InferenceSession(const SessionOptions& session_options,
const Environment& session_env,
const std::wstring& model_uri);
const std::string& model_uri);
#endif
/**
@ -256,9 +257,9 @@ class InferenceSession {
* @param model_uri absolute path of the model file.
* @return OK if success.
*/
common::Status Load(const std::string& model_uri) ORT_MUST_USE_RESULT;
common::Status Load(const PathString& model_uri) ORT_MUST_USE_RESULT;
#ifdef _WIN32
common::Status Load(const std::wstring& model_uri) ORT_MUST_USE_RESULT;
common::Status Load(const std::string& model_uri) ORT_MUST_USE_RESULT;
#endif
/**
* Load an ONNX or ORT format model.
@ -482,17 +483,17 @@ class InferenceSession {
* @param model_proto protobuf object corresponding to the model file. It is taken by value (copied unless the caller moves it in) and is then moved into the loaded model.
* @return OK if success.
*/
common::Status Load(const ONNX_NAMESPACE::ModelProto& model_proto) ORT_MUST_USE_RESULT;
common::Status LoadOnnxModel(ONNX_NAMESPACE::ModelProto model_proto) ORT_MUST_USE_RESULT;
/**
* Load an ONNX model.
* @param protobuf object corresponding to the model file. This is primarily to support large models.
* @return OK if success.
*/
common::Status Load(std::unique_ptr<ONNX_NAMESPACE::ModelProto> p_model_proto) ORT_MUST_USE_RESULT;
common::Status LoadOnnxModel(std::unique_ptr<ONNX_NAMESPACE::ModelProto> p_model_proto) ORT_MUST_USE_RESULT;
common::Status Load(std::function<common::Status(std::shared_ptr<Model>&)> loader,
const std::string& event_name) ORT_MUST_USE_RESULT;
common::Status LoadWithLoader(std::function<common::Status(std::shared_ptr<Model>&)> loader,
const std::string& event_name) ORT_MUST_USE_RESULT;
common::Status DoPostLoadProcessing(onnxruntime::Model& model) ORT_MUST_USE_RESULT;
@ -541,7 +542,7 @@ class InferenceSession {
std::unordered_set<std::string> model_output_names_;
// The file path of where the model was loaded. e.g. /tmp/test_squeezenet/model.onnx
std::basic_string<ORTCHAR_T> model_location_;
PathString model_location_;
// The list of execution providers.
ExecutionProviders execution_providers_;
@ -556,14 +557,13 @@ class InferenceSession {
#if !defined(ORT_MINIMAL_BUILD)
template <typename T>
common::Status Load(const std::basic_string<T>& model_uri) ORT_MUST_USE_RESULT;
common::Status LoadOnnxModel(const PathString& model_uri) ORT_MUST_USE_RESULT;
bool HasLocalSchema() const {
return !custom_schema_registries_.empty();
}
common::Status SaveToOrtFormat(const std::basic_string<ORTCHAR_T>& filepath) const;
common::Status SaveToOrtFormat(const PathString& filepath) const;
#endif
/**
@ -571,10 +571,7 @@ class InferenceSession {
* @param model_uri absolute path of the model file.
* @return OK if success.
*/
common::Status LoadOrtModel(const std::string& model_uri) ORT_MUST_USE_RESULT;
#ifdef _WIN32
common::Status LoadOrtModel(const std::wstring& model_uri) ORT_MUST_USE_RESULT;
#endif
common::Status LoadOrtModel(const PathString& model_uri) ORT_MUST_USE_RESULT;
/**
* Load an ORT format model.
@ -586,7 +583,7 @@ class InferenceSession {
*/
common::Status LoadOrtModel(const void* model_data, int model_data_len) ORT_MUST_USE_RESULT;
common::Status LoadOrtModel(std::function<Status()> load_ort_format_model_bytes) ORT_MUST_USE_RESULT;
common::Status LoadOrtModelWithLoader(std::function<Status()> load_ort_format_model_bytes) ORT_MUST_USE_RESULT;
// Create a Logger for a single execution if possible. Otherwise use the default logger.
// If a new logger is created, it will also be stored in new_run_logger,

View file

@ -42,7 +42,7 @@ class FenceCudaTestInferenceSession : public InferenceSession {
FenceCudaTestInferenceSession(const SessionOptions& so, const Environment& env) : InferenceSession(so, env) {}
Status LoadModel(onnxruntime::Model& model) {
auto model_proto = model.ToProto();
auto st = Load(model_proto);
auto st = LoadOnnxModel(std::move(model_proto));
return st;
}
};

View file

@ -30,7 +30,7 @@ class InferenceSessionProtectedLoadAccessor : public onnxruntime::InferenceSessi
public:
onnxruntime::common::Status
Load(std::unique_ptr<ONNX_NAMESPACE::ModelProto> p_model_proto) {
return onnxruntime::InferenceSession::Load(std::move(p_model_proto));
return onnxruntime::InferenceSession::LoadOnnxModel(std::move(p_model_proto));
}
const onnxruntime::SessionState& GetSessionState() {
return onnxruntime::InferenceSession::GetSessionState();
@ -296,4 +296,4 @@ ORT_API_STATUS_IMPL(winmla::SessionGetNamedDimensionsOverrides, _In_ OrtSession*
named_dimension_overrides = override_map.GetView();
return nullptr;
API_IMPL_END
}
}