mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Add an option to use the input model bytes (ORT format only) directly without copy at session creation (#8502)
* Do not copy the model_data when session is started by CreateSessionFromArray * Add config option for disabling copy model bytes * Add one additional test * Address CR comments
This commit is contained in:
parent
1ae32655b3
commit
4c939e1cb7
4 changed files with 90 additions and 29 deletions
|
|
@ -61,6 +61,14 @@ static const char* const kOrtSessionOptionsUseDeviceAllocatorForInitializers = "
|
|||
static const char* const kOrtSessionOptionsConfigAllowInterOpSpinning = "session.inter_op.allow_spinning";
|
||||
static const char* const kOrtSessionOptionsConfigAllowIntraOpSpinning = "session.intra_op.allow_spinning";
|
||||
|
||||
// Key for using model bytes directly for ORT format
|
||||
// If a session is created using an input byte array that contains the ORT format model data,
|
||||
// By default we will copy the model bytes at the time of session creation to ensure the model bytes
|
||||
// buffer is valid.
|
||||
// Setting this option to "1" will disable copying the model bytes, and use the model bytes directly. The caller
|
||||
// has to guarantee that the model bytes are valid until the ORT session using the model bytes is destroyed.
|
||||
static const char* const kOrtSessionOptionsConfigUseORTModelBytesDirectly = "session.use_ort_model_bytes_directly";
|
||||
|
||||
// NNAPI EP keys begin
|
||||
// Note: These options should be specified prior to appending the NNAPI EP to the session options object in order for
|
||||
// them to take effect.
|
||||
|
|
@ -71,5 +79,3 @@ static const char* const kOrtSessionOptionsConfigAllowIntraOpSpinning = "session
|
|||
// If not specified, the default set of stop ops is used. To specify an empty stop ops types list and disable stop op
|
||||
// exclusion, set the value to "".
|
||||
static const char* const kOrtSessionOptionsConfigNnapiEpPartitioningStopOps = "ep.nnapi.partitioning_stop_ops";
|
||||
|
||||
// NNAPI EP keys end
|
||||
|
|
|
|||
|
|
@ -959,15 +959,16 @@ Status InferenceSession::PartitionOrtFormatModel(onnxruntime::Graph& graph,
|
|||
template <typename T>
|
||||
static Status LoadOrtModelBytes(const std::basic_string<T>& model_uri,
|
||||
std::basic_string<ORTCHAR_T>& model_location,
|
||||
std::vector<uint8_t>& bytes) {
|
||||
gsl::span<const uint8_t>& bytes,
|
||||
std::vector<uint8_t>& bytes_data_holder) {
|
||||
size_t num_bytes = 0;
|
||||
model_location = ToWideString(model_uri);
|
||||
ORT_RETURN_IF_ERROR(Env::Default().GetFileLength(model_location.c_str(), num_bytes));
|
||||
|
||||
bytes.resize(num_bytes);
|
||||
bytes_data_holder.resize(num_bytes);
|
||||
|
||||
std::ifstream bytes_stream(model_uri, std::ifstream::in | std::ifstream::binary);
|
||||
bytes_stream.read(reinterpret_cast<char*>(bytes.data()), num_bytes);
|
||||
bytes_stream.read(reinterpret_cast<char*>(bytes_data_holder.data()), num_bytes);
|
||||
|
||||
if (!bytes_stream) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
|
||||
|
|
@ -975,13 +976,17 @@ static Status LoadOrtModelBytes(const std::basic_string<T>& model_uri,
|
|||
bytes_stream.gcount(), "/", num_bytes, " bytes were able to be read.");
|
||||
}
|
||||
|
||||
bytes = gsl::span<const uint8_t>(bytes_data_holder.data(), num_bytes);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status InferenceSession::LoadOrtModel(const std::string& model_uri) {
|
||||
return LoadOrtModel(
|
||||
[&]() {
|
||||
ORT_RETURN_IF_ERROR(LoadOrtModelBytes(model_uri, model_location_, ort_format_model_bytes_));
|
||||
ORT_RETURN_IF_ERROR(
|
||||
LoadOrtModelBytes(model_uri, model_location_,
|
||||
ort_format_model_bytes_, ort_format_model_bytes_data_holder_));
|
||||
return Status::OK();
|
||||
});
|
||||
}
|
||||
|
|
@ -990,7 +995,9 @@ Status InferenceSession::LoadOrtModel(const std::string& model_uri) {
|
|||
Status InferenceSession::LoadOrtModel(const std::wstring& model_uri) {
|
||||
return LoadOrtModel(
|
||||
[&]() {
|
||||
ORT_RETURN_IF_ERROR(LoadOrtModelBytes(model_uri, model_location_, ort_format_model_bytes_));
|
||||
ORT_RETURN_IF_ERROR(
|
||||
LoadOrtModelBytes(model_uri, model_location_,
|
||||
ort_format_model_bytes_, ort_format_model_bytes_data_holder_));
|
||||
return Status::OK();
|
||||
});
|
||||
}
|
||||
|
|
@ -998,13 +1005,19 @@ Status InferenceSession::LoadOrtModel(const std::wstring& model_uri) {
|
|||
|
||||
Status InferenceSession::LoadOrtModel(const void* model_data, int model_data_len) {
|
||||
return LoadOrtModel([&]() {
|
||||
// copy bytes as we need them to be available when InferenceSession::Initialize is called later.
|
||||
//
|
||||
// TODO: Provide Load API where we can take ownership of memory to avoid the copy,
|
||||
// and/or a combined Load+Initialize where we don't need this temporary copy.
|
||||
ort_format_model_bytes_.resize(model_data_len);
|
||||
std::copy_n(reinterpret_cast<const uint8_t*>(model_data), model_data_len, ort_format_model_bytes_.data());
|
||||
|
||||
const auto use_ort_model_bytes_directly =
|
||||
GetSessionOptions().config_options.GetConfigOrDefault(kOrtSessionOptionsConfigUseORTModelBytesDirectly, "0");
|
||||
if (use_ort_model_bytes_directly != "1") {
|
||||
// copy bytes as we need them to be available when InferenceSession::Initialize is called later.
|
||||
ort_format_model_bytes_data_holder_.resize(model_data_len);
|
||||
std::copy_n(reinterpret_cast<const uint8_t*>(model_data), model_data_len,
|
||||
ort_format_model_bytes_data_holder_.data());
|
||||
ort_format_model_bytes_ = gsl::span<const uint8_t>(ort_format_model_bytes_data_holder_.data(), model_data_len);
|
||||
} else {
|
||||
// Use the model_data directly to reduce memory consumption
|
||||
// This will require the model_data to be alive until the InferenceSession is initialized
|
||||
ort_format_model_bytes_ = gsl::span<const uint8_t>(reinterpret_cast<const uint8_t*>(model_data), model_data_len);
|
||||
}
|
||||
return Status::OK();
|
||||
});
|
||||
}
|
||||
|
|
@ -1318,7 +1331,9 @@ common::Status InferenceSession::Initialize() {
|
|||
is_inited_ = true;
|
||||
|
||||
// we don't directly use the ORT format bytes currently, so free those now
|
||||
std::vector<uint8_t>().swap(ort_format_model_bytes_);
|
||||
// TODO, we may need to keep the bytes if we are using the offset directly in the initializers
|
||||
ort_format_model_bytes_ = gsl::span<const uint8_t>();
|
||||
std::vector<uint8_t>().swap(ort_format_model_bytes_data_holder_);
|
||||
|
||||
// and log telemetry
|
||||
bool model_has_fp16_inputs = ModelHasFP16Inputs(graph);
|
||||
|
|
|
|||
|
|
@ -184,10 +184,10 @@ class InferenceSession {
|
|||
/**
|
||||
* Filter the enabled optimizers (either transformer or rewrite rule) using optimizers_to_disable.
|
||||
* For an optimizer to be enabled, it must be allowed at the current optimization level (as specified in
|
||||
* session options), and NOT in optimizers_to_disable.
|
||||
* session options), and NOT in optimizers_to_disable.
|
||||
* This allows finer grained control of the enabled/disabled optimizations.
|
||||
* Must be called before Initialize() to take effect.
|
||||
*
|
||||
*
|
||||
* Calling this API is optional.
|
||||
* @return OK if success.
|
||||
*/
|
||||
|
|
@ -217,12 +217,12 @@ class InferenceSession {
|
|||
/**
|
||||
* Load an ONNX or ORT format model.
|
||||
*
|
||||
* Set SessionOptions session config value ORT_SESSION_OPTIONS_CONFIG_LOAD_MODEL_FORMAT to 'ORT' or 'ONNX' to
|
||||
* Set SessionOptions session config value ORT_SESSION_OPTIONS_CONFIG_LOAD_MODEL_FORMAT to 'ORT' or 'ONNX' to
|
||||
* explicitly choose model format.
|
||||
*
|
||||
* If format is not explicitly specified and filename ends in '.ort' it will be inferred to be an ORT format model.
|
||||
* All other files are assumed to be in ONNX format.
|
||||
*
|
||||
*
|
||||
* @param model_uri absolute path of the model file.
|
||||
* @return OK if success.
|
||||
*/
|
||||
|
|
@ -233,11 +233,11 @@ class InferenceSession {
|
|||
/**
|
||||
* Load an ONNX or ORT format model.
|
||||
*
|
||||
* Set SessionOptions session config value ORT_SESSION_OPTIONS_CONFIG_LOAD_MODEL_FORMAT to 'ORT' or 'ONNX' to
|
||||
* Set SessionOptions session config value ORT_SESSION_OPTIONS_CONFIG_LOAD_MODEL_FORMAT to 'ORT' or 'ONNX' to
|
||||
* explicitly choose model format.
|
||||
*
|
||||
* If format is not explicitly specified the model format will be inferred from the bytes, defaulting to ONNX.
|
||||
*
|
||||
*
|
||||
* @param model_data Model data buffer
|
||||
* @param model_data_len Model data buffer size
|
||||
* @return OK if success.
|
||||
|
|
@ -309,7 +309,7 @@ class InferenceSession {
|
|||
#ifdef ENABLE_TRAINING
|
||||
/**
|
||||
* Partially run a pre-loaded and pre-initialized model.
|
||||
* @param run_options run options.
|
||||
* @param run_options run options.
|
||||
* @param feeds inputs owned by client code and should not be changed during
|
||||
* execution of this function.
|
||||
* @param fetches outputs produced after the execution of this function.
|
||||
|
|
@ -430,9 +430,9 @@ class InferenceSession {
|
|||
}
|
||||
|
||||
/**
|
||||
* Add a PrepackedWeightsContainer instance to the session so as to store the pre-packed weights
|
||||
* Add a PrepackedWeightsContainer instance to the session so as to store the pre-packed weights
|
||||
* of shared initializers to be shared across sessions.
|
||||
* @param prepacked_weights_container PrepackedWeightsContainer instance
|
||||
* @param prepacked_weights_container PrepackedWeightsContainer instance
|
||||
*/
|
||||
Status AddPrePackedWeightsContainer(PrepackedWeightsContainer* prepacked_weights_container);
|
||||
|
||||
|
|
@ -704,13 +704,31 @@ class InferenceSession {
|
|||
bool is_model_proto_parsed_ = false;
|
||||
const Environment& environment_;
|
||||
|
||||
// Bytes from an ORT format model.
|
||||
// We store them currently to make the Load + Initialize behave the same way as for an ONNX model
|
||||
// as we need some of the bytes for the Load (create the Model) and some for the Initialize (create SessionState).
|
||||
// View of the bytes from an ORT format model.
|
||||
// If the session is started with an input byte array containing model data, and the caller
|
||||
// specifies that ORT should use the model bytes directly by setting the session config option
|
||||
// "session.use_ort_model_bytes_directly" to "1"
|
||||
// We use the byte array directly without copying to reduce peak memory usage
|
||||
// (Short term) This will require the user to guarantee the life time of the model data
|
||||
// until the session is created.
|
||||
// (Longer term) If we are going to use the memory offsets directly for initializers, the model data
|
||||
// should be alive until the InferenceSession goes away.
|
||||
// If the session is started with an input byte array containing model data, and the caller does not
|
||||
// specify ORT should use the model bytes directly
|
||||
// Or the session is started with a model_uri
|
||||
// We store them currently in the ort_format_model_bytes_data_holder_ to make the Load + Initialize
|
||||
// behave the same way as for an ONNX model, as we need some of the bytes for the Load (create the Model)
|
||||
// and some for the Initialize (create SessionState).
|
||||
// Short term we free them after Initialize.
|
||||
// Longer term we may want to directly refer to offsets in this buffer for initializers so we don't need to copy
|
||||
// those into new OrtValue instances, at which point we won't free them until the InferenceSession goes away.
|
||||
std::vector<uint8_t> ort_format_model_bytes_;
|
||||
gsl::span<const uint8_t> ort_format_model_bytes_;
|
||||
|
||||
// This holds the actual model data
|
||||
// In case the session is started with an input byte array containing model data, and the caller
|
||||
// specifies that ORT should use the model bytes directly by setting the session config option
|
||||
// "session.use_ort_model_bytes_directly" to "1", this will be empty
|
||||
std::vector<uint8_t> ort_format_model_bytes_data_holder_;
|
||||
|
||||
std::shared_ptr<onnxruntime::AllocatorManager> allocator_manager_;
|
||||
|
||||
|
|
|
|||
|
|
@ -37,13 +37,19 @@ struct OrtModelTestInfo {
|
|||
std::function<void(const std::vector<OrtValue>&)> output_verifier;
|
||||
std::vector<std::pair<std::string, std::string>> configs;
|
||||
bool run_use_buffer{false};
|
||||
bool disable_copy_ort_buffer{false};
|
||||
};
|
||||
|
||||
static void RunOrtModel(const OrtModelTestInfo& test_info) {
|
||||
SessionOptions so;
|
||||
so.session_logid = test_info.logid;
|
||||
for (const auto& config : test_info.configs)
|
||||
for (const auto& config : test_info.configs) {
|
||||
so.config_options.AddConfigEntry(config.first.c_str(), config.second.c_str());
|
||||
}
|
||||
|
||||
if (test_info.disable_copy_ort_buffer) {
|
||||
so.config_options.AddConfigEntry(kOrtSessionOptionsConfigUseORTModelBytesDirectly, "1");
|
||||
}
|
||||
|
||||
std::vector<char> model_data;
|
||||
InferenceSessionWrapper session_object{so, GetEnvironment()};
|
||||
|
|
@ -448,6 +454,14 @@ TEST(OrtModelOnlyTests, LoadOrtFormatModelFromBuffer) {
|
|||
RunOrtModel(test_info);
|
||||
}
|
||||
|
||||
// Load the model from a buffer instead of a file path, and not copy the buffer in session creation
|
||||
TEST(OrtModelOnlyTests, LoadOrtFormatModelFromBufferNoCopy) {
|
||||
OrtModelTestInfo test_info = GetTestInfoForLoadOrtFormatModel();
|
||||
test_info.run_use_buffer = true;
|
||||
test_info.disable_copy_ort_buffer = true;
|
||||
RunOrtModel(test_info);
|
||||
}
|
||||
|
||||
#if !defined(DISABLE_ML_OPS)
|
||||
// test that we can deserialize and run a previously saved ORT format model
|
||||
// for a model with sequence and map outputs
|
||||
|
|
@ -502,6 +516,14 @@ TEST(OrtModelOnlyTests, LoadOrtFormatModelMLOpsFromBuffer) {
|
|||
RunOrtModel(test_info);
|
||||
}
|
||||
|
||||
// Load the model from a buffer instead of a file path, and not copy the buffer in session creation
|
||||
TEST(OrtModelOnlyTests, LoadOrtFormatModelMLOpsFromBufferNoCopy) {
|
||||
OrtModelTestInfo test_info = GetTestInfoForLoadOrtFormatModelMLOps();
|
||||
test_info.run_use_buffer = true;
|
||||
test_info.disable_copy_ort_buffer = true;
|
||||
RunOrtModel(test_info);
|
||||
}
|
||||
|
||||
#endif // !defined(DISABLE_ML_OPS)
|
||||
|
||||
} // namespace test
|
||||
|
|
|
|||
Loading…
Reference in a new issue