From 4c939e1cb7b24fab71602e0cd63fcf62bbcb5b92 Mon Sep 17 00:00:00 2001 From: Guoyu Wang <62914304+gwang-msft@users.noreply.github.com> Date: Tue, 27 Jul 2021 09:11:42 -0700 Subject: [PATCH] Add an option to use the input model bytes (ORT format only) directly without copy at session creation (#8502) * Do not copy the model_data when session is started by CreateSessionFromArray * Add config option for disabling copy model bytes * Add one additional test * Address CR comments --- .../onnxruntime_session_options_config_keys.h | 10 ++++- onnxruntime/core/session/inference_session.cc | 41 +++++++++++------ onnxruntime/core/session/inference_session.h | 44 +++++++++++++------ .../test/framework/ort_model_only_test.cc | 24 +++++++++- 4 files changed, 90 insertions(+), 29 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 5ef05e38ad..57ff73fde9 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -61,6 +61,14 @@ static const char* const kOrtSessionOptionsUseDeviceAllocatorForInitializers = " static const char* const kOrtSessionOptionsConfigAllowInterOpSpinning = "session.inter_op.allow_spinning"; static const char* const kOrtSessionOptionsConfigAllowIntraOpSpinning = "session.intra_op.allow_spinning"; +// Key for using model bytes directly for ORT format +// If a session is created using an input byte array contains the ORT format model data, +// By default we will copy the model bytes at the time of session creation to ensure the model bytes +// buffer is valid. +// Setting this option to "1" will disable copy the model bytes, and use the model bytes directly. The caller +// has to guarantee that the model bytes are valid until the ORT session using the model bytes is destroyed. +static const char* const kOrtSessionOptionsConfigUseORTModelBytesDirectly = "session.use_ort_model_bytes_directly"; + // NNAPI EP keys begin // Note: These options should be specified prior to appending the NNAPI EP to the session options object in order for // them to take effect. @@ -71,5 +79,3 @@ static const char* const kOrtSessionOptionsConfigAllowIntraOpSpinning = "session // If not specified, the default set of stop ops is used. To specify an empty stop ops types list and disable stop op // exclusion, set the value to "". static const char* const kOrtSessionOptionsConfigNnapiEpPartitioningStopOps = "ep.nnapi.partitioning_stop_ops"; - -// NNAPI EP keys end diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 9762be044d..51c3b51f97 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -959,15 +959,16 @@ Status InferenceSession::PartitionOrtFormatModel(onnxruntime::Graph& graph, template static Status LoadOrtModelBytes(const std::basic_string& model_uri, std::basic_string& model_location, - std::vector& bytes) { + gsl::span& bytes, + std::vector& bytes_data_holder) { size_t num_bytes = 0; model_location = ToWideString(model_uri); ORT_RETURN_IF_ERROR(Env::Default().GetFileLength(model_location.c_str(), num_bytes)); - bytes.resize(num_bytes); + bytes_data_holder.resize(num_bytes); std::ifstream bytes_stream(model_uri, std::ifstream::in | std::ifstream::binary); - bytes_stream.read(reinterpret_cast(bytes.data()), num_bytes); + bytes_stream.read(reinterpret_cast(bytes_data_holder.data()), num_bytes); if (!bytes_stream) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, @@ -975,13 +976,17 @@ static Status LoadOrtModelBytes(const std::basic_string& model_uri, bytes_stream.gcount(), "/", num_bytes, " bytes were able to be read."); } + bytes = gsl::span(bytes_data_holder.data(), num_bytes); + return Status::OK(); } Status InferenceSession::LoadOrtModel(const std::string& model_uri) { return LoadOrtModel( [&]() { - ORT_RETURN_IF_ERROR(LoadOrtModelBytes(model_uri, model_location_, ort_format_model_bytes_)); + ORT_RETURN_IF_ERROR( + LoadOrtModelBytes(model_uri, model_location_, + ort_format_model_bytes_, ort_format_model_bytes_data_holder_)); return Status::OK(); }); } @@ -990,7 +995,9 @@ Status InferenceSession::LoadOrtModel(const std::string& model_uri) { Status InferenceSession::LoadOrtModel(const std::wstring& model_uri) { return LoadOrtModel( [&]() { - ORT_RETURN_IF_ERROR(LoadOrtModelBytes(model_uri, model_location_, ort_format_model_bytes_)); + ORT_RETURN_IF_ERROR( + LoadOrtModelBytes(model_uri, model_location_, + ort_format_model_bytes_, ort_format_model_bytes_data_holder_)); return Status::OK(); }); } @@ -998,13 +1005,19 @@ Status InferenceSession::LoadOrtModel(const std::wstring& model_uri) { Status InferenceSession::LoadOrtModel(const void* model_data, int model_data_len) { return LoadOrtModel([&]() { - // copy bytes as we need them to be available when InferenceSession::Initialize is called later. - // - // TODO: Provide Load API where we can take ownership of memory to avoid the copy, - // and/or a combined Load+Initialize where we don't need this temporary copy. - ort_format_model_bytes_.resize(model_data_len); - std::copy_n(reinterpret_cast(model_data), model_data_len, ort_format_model_bytes_.data()); - + const auto use_ort_model_bytes_directly = + GetSessionOptions().config_options.GetConfigOrDefault(kOrtSessionOptionsConfigUseORTModelBytesDirectly, "0"); + if (use_ort_model_bytes_directly != "1") { + // copy bytes as we need them to be available when InferenceSession::Initialize is called later. + ort_format_model_bytes_data_holder_.resize(model_data_len); + std::copy_n(reinterpret_cast(model_data), model_data_len, + ort_format_model_bytes_data_holder_.data()); + ort_format_model_bytes_ = gsl::span(ort_format_model_bytes_data_holder_.data(), model_data_len); + } else { + // Use the model_data directly to reduce memory consumption + // This will require the model_data to be alive until the InferenceSession is initialized + ort_format_model_bytes_ = gsl::span(reinterpret_cast(model_data), model_data_len); + } return Status::OK(); }); } @@ -1318,7 +1331,9 @@ common::Status InferenceSession::Initialize() { is_inited_ = true; // we don't directly use the ORT format bytes currently, so free those now - std::vector().swap(ort_format_model_bytes_); + // TODO, we may need to keep the bytes if we are using the offset directly in the initializers + ort_format_model_bytes_ = gsl::span(); + std::vector().swap(ort_format_model_bytes_data_holder_); // and log telemetry bool model_has_fp16_inputs = ModelHasFP16Inputs(graph); diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index cfc45b120f..32abe60b98 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -184,10 +184,10 @@ class InferenceSession { /** * Filter the enabled optimizers (either transformer or rewrite rule) using optimizers_to_disable. * For an optimizer to be enabled, it must be allowed at the current optimization level (as specified in - * session options), and NOT in optimizers_to_disable. + * session options), and NOT in optimizers_to_disable. * This allows finer grained control of the enabled/disabled optimizations. * Must be called before Initialize() to take effect. - * + * * Calling this API is optional. * @return OK if success. */ @@ -217,12 +217,12 @@ class InferenceSession { /** * Load an ONNX or ORT format model. * - * Set SessionOptions session config value ORT_SESSION_OPTIONS_CONFIG_LOAD_MODEL_FORMAT to 'ORT' or 'ONNX' to + * Set SessionOptions session config value ORT_SESSION_OPTIONS_CONFIG_LOAD_MODEL_FORMAT to 'ORT' or 'ONNX' to * explicitly choose model format. * * If format is not explicitly specified and filename ends in '.ort' it will be inferred to be an ORT format model. * All other files are assumed to be in ONNX format. - * + * * @param model_uri absolute path of the model file. * @return OK if success. */ @@ -233,11 +233,11 @@ class InferenceSession { /** * Load an ONNX or ORT format model. * - * Set SessionOptions session config value ORT_SESSION_OPTIONS_CONFIG_LOAD_MODEL_FORMAT to 'ORT' or 'ONNX' to + * Set SessionOptions session config value ORT_SESSION_OPTIONS_CONFIG_LOAD_MODEL_FORMAT to 'ORT' or 'ONNX' to * explicitly choose model format. * * If format is not explicitly specified the model format will be inferred from the bytes, defaulting to ONNX. - * + * * @param model_data Model data buffer * @param model_data_len Model data buffer size * @return OK if success. @@ -309,7 +309,7 @@ class InferenceSession { #ifdef ENABLE_TRAINING /** * Partially run a pre-loaded and pre-intialized model. - * @param run_options run options. + * @param run_options run options. * @param feeds inputs owned by client code and should not be changed during * execution of this function. * @param fetches outputs produced after the executin of this function. @@ -430,9 +430,9 @@ class InferenceSession { } /** - * Add a PrepackedWeightsContainer instance to the session so as to store the pre-packed weights + * Add a PrepackedWeightsContainer instance to the session so as to store the pre-packed weights * of shared initializers to be shared across sessions. - * @param prepacked_weights_container PrepackedWeightsContainer instance + * @param prepacked_weights_container PrepackedWeightsContainer instance */ Status AddPrePackedWeightsContainer(PrepackedWeightsContainer* prepacked_weights_container); @@ -704,13 +704,31 @@ class InferenceSession { bool is_model_proto_parsed_ = false; const Environment& environment_; - // Bytes from an ORT format model. - // We store them currently to make the Load + Initialize behave the same way as for an ONNX model - // as we need some of the bytes for the Load (create the Model) and some for the Initialize (create SessionState). + // View of the bytes from an ORT format model. + // If the session is started with an input byte array contains model data, and the caller + // specifies that ORT should use the model bytes directly by setting the session config option + // "session.use_ort_model_bytes_directly" to "1" + // We use the the byte array directly without copy to reduce peak memory usage + // (Short term) This will require the user to guarantee the life time of the model data + // until the session is created. + // (Longer term) If we are going to use the memory offsets directly for initializers, the model data + // should be alive until the InferenceSession goes away. + // If the session is started with an input byte array contains model data, and the caller does not + // specify ORT should use the model bytes directly + // Or the session is started with a model_uri + // We store them currently in the ort_format_model_bytes_data_holder_ to make the Load + Initialize + // behave the same way as for an ONNX model, as we need some of the bytes for the Load (create the Model) + // and some for the Initialize (create SessionState). // Short term we free them after Initialize. // Longer term we may want to directly refer to offsets in this buffer for initializers so we don't need to copy // those into new OrtValue instances, at which point we won't free them until the InferenceSession goes away. - std::vector ort_format_model_bytes_; + gsl::span ort_format_model_bytes_; + + // This holds the actual model data + // In case if the session is started with an input byte array contains model data, and the caller + // specifies that ORT should use the model bytes directly by setting the session config option + // "session.use_ort_model_bytes_directly" to "1", this will be empty + std::vector ort_format_model_bytes_data_holder_; std::shared_ptr allocator_manager_; diff --git a/onnxruntime/test/framework/ort_model_only_test.cc b/onnxruntime/test/framework/ort_model_only_test.cc index f79a38afb5..3f14ad39b6 100644 --- a/onnxruntime/test/framework/ort_model_only_test.cc +++ b/onnxruntime/test/framework/ort_model_only_test.cc @@ -37,13 +37,19 @@ struct OrtModelTestInfo { std::function&)> output_verifier; std::vector> configs; bool run_use_buffer{false}; + bool disable_copy_ort_buffer{false}; }; static void RunOrtModel(const OrtModelTestInfo& test_info) { SessionOptions so; so.session_logid = test_info.logid; - for (const auto& config : test_info.configs) + for (const auto& config : test_info.configs) { so.config_options.AddConfigEntry(config.first.c_str(), config.second.c_str()); + } + + if (test_info.disable_copy_ort_buffer) { + so.config_options.AddConfigEntry(kOrtSessionOptionsConfigUseORTModelBytesDirectly, "1"); + } std::vector model_data; InferenceSessionWrapper session_object{so, GetEnvironment()}; @@ -448,6 +454,14 @@ TEST(OrtModelOnlyTests, LoadOrtFormatModelFromBuffer) { RunOrtModel(test_info); } +// Load the model from a buffer instead of a file path, and not copy the buffer in session creation +TEST(OrtModelOnlyTests, LoadOrtFormatModelFromBufferNoCopy) { + OrtModelTestInfo test_info = GetTestInfoForLoadOrtFormatModel(); + test_info.run_use_buffer = true; + test_info.disable_copy_ort_buffer = true; + RunOrtModel(test_info); +} + #if !defined(DISABLE_ML_OPS) // test that we can deserialize and run a previously saved ORT format model // for a model with sequence and map outputs @@ -502,6 +516,14 @@ TEST(OrtModelOnlyTests, LoadOrtFormatModelMLOpsFromBuffer) { RunOrtModel(test_info); } +// Load the model from a buffer instead of a file path, and not copy the buffer in session creation +TEST(OrtModelOnlyTests, LoadOrtFormatModelMLOpsFromBufferNoCopy) { + OrtModelTestInfo test_info = GetTestInfoForLoadOrtFormatModelMLOps(); + test_info.run_use_buffer = true; + test_info.disable_copy_ort_buffer = true; + RunOrtModel(test_info); +} + #endif // !defined(DISABLE_ML_OPS) } // namespace test