Add an option to use the input model bytes (ORT format only) directly without copy at session creation (#8502)

* Do not copy the model_data when session is started by CreateSessionFromArray

* Add config option for disabling copy model bytes

* Add one additional test

* Address CR comments
This commit is contained in:
Guoyu Wang 2021-07-27 09:11:42 -07:00 committed by GitHub
parent 1ae32655b3
commit 4c939e1cb7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 90 additions and 29 deletions

View file

@ -61,6 +61,14 @@ static const char* const kOrtSessionOptionsUseDeviceAllocatorForInitializers = "
static const char* const kOrtSessionOptionsConfigAllowInterOpSpinning = "session.inter_op.allow_spinning";
static const char* const kOrtSessionOptionsConfigAllowIntraOpSpinning = "session.intra_op.allow_spinning";
// Key for using the ORT format model bytes directly, without copying them.
// If a session is created from an input byte array containing ORT format model data,
// by default we copy the model bytes at session creation time to ensure the model bytes
// buffer remains valid.
// Setting this option to "1" disables copying the model bytes and uses them directly. The caller
// has to guarantee that the model bytes are valid until the ORT session using the model bytes is destroyed.
static const char* const kOrtSessionOptionsConfigUseORTModelBytesDirectly = "session.use_ort_model_bytes_directly";
// NNAPI EP keys begin
// Note: These options should be specified prior to appending the NNAPI EP to the session options object in order for
// them to take effect.
@ -71,5 +79,3 @@ static const char* const kOrtSessionOptionsConfigAllowIntraOpSpinning = "session
// If not specified, the default set of stop ops is used. To specify an empty stop ops types list and disable stop op
// exclusion, set the value to "".
static const char* const kOrtSessionOptionsConfigNnapiEpPartitioningStopOps = "ep.nnapi.partitioning_stop_ops";
// NNAPI EP keys end

View file

@ -959,15 +959,16 @@ Status InferenceSession::PartitionOrtFormatModel(onnxruntime::Graph& graph,
template <typename T>
static Status LoadOrtModelBytes(const std::basic_string<T>& model_uri,
std::basic_string<ORTCHAR_T>& model_location,
std::vector<uint8_t>& bytes) {
gsl::span<const uint8_t>& bytes,
std::vector<uint8_t>& bytes_data_holder) {
size_t num_bytes = 0;
model_location = ToWideString(model_uri);
ORT_RETURN_IF_ERROR(Env::Default().GetFileLength(model_location.c_str(), num_bytes));
bytes.resize(num_bytes);
bytes_data_holder.resize(num_bytes);
std::ifstream bytes_stream(model_uri, std::ifstream::in | std::ifstream::binary);
bytes_stream.read(reinterpret_cast<char*>(bytes.data()), num_bytes);
bytes_stream.read(reinterpret_cast<char*>(bytes_data_holder.data()), num_bytes);
if (!bytes_stream) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
@ -975,13 +976,17 @@ static Status LoadOrtModelBytes(const std::basic_string<T>& model_uri,
bytes_stream.gcount(), "/", num_bytes, " bytes were able to be read.");
}
bytes = gsl::span<const uint8_t>(bytes_data_holder.data(), num_bytes);
return Status::OK();
}
Status InferenceSession::LoadOrtModel(const std::string& model_uri) {
return LoadOrtModel(
[&]() {
ORT_RETURN_IF_ERROR(LoadOrtModelBytes(model_uri, model_location_, ort_format_model_bytes_));
ORT_RETURN_IF_ERROR(
LoadOrtModelBytes(model_uri, model_location_,
ort_format_model_bytes_, ort_format_model_bytes_data_holder_));
return Status::OK();
});
}
@ -990,7 +995,9 @@ Status InferenceSession::LoadOrtModel(const std::string& model_uri) {
Status InferenceSession::LoadOrtModel(const std::wstring& model_uri) {
return LoadOrtModel(
[&]() {
ORT_RETURN_IF_ERROR(LoadOrtModelBytes(model_uri, model_location_, ort_format_model_bytes_));
ORT_RETURN_IF_ERROR(
LoadOrtModelBytes(model_uri, model_location_,
ort_format_model_bytes_, ort_format_model_bytes_data_holder_));
return Status::OK();
});
}
@ -998,13 +1005,19 @@ Status InferenceSession::LoadOrtModel(const std::wstring& model_uri) {
Status InferenceSession::LoadOrtModel(const void* model_data, int model_data_len) {
return LoadOrtModel([&]() {
// copy bytes as we need them to be available when InferenceSession::Initialize is called later.
//
// TODO: Provide Load API where we can take ownership of memory to avoid the copy,
// and/or a combined Load+Initialize where we don't need this temporary copy.
ort_format_model_bytes_.resize(model_data_len);
std::copy_n(reinterpret_cast<const uint8_t*>(model_data), model_data_len, ort_format_model_bytes_.data());
const auto use_ort_model_bytes_directly =
GetSessionOptions().config_options.GetConfigOrDefault(kOrtSessionOptionsConfigUseORTModelBytesDirectly, "0");
if (use_ort_model_bytes_directly != "1") {
// copy bytes as we need them to be available when InferenceSession::Initialize is called later.
ort_format_model_bytes_data_holder_.resize(model_data_len);
std::copy_n(reinterpret_cast<const uint8_t*>(model_data), model_data_len,
ort_format_model_bytes_data_holder_.data());
ort_format_model_bytes_ = gsl::span<const uint8_t>(ort_format_model_bytes_data_holder_.data(), model_data_len);
} else {
// Use the model_data directly to reduce memory consumption
// This will require the model_data to be alive until the InferenceSession is initialized
ort_format_model_bytes_ = gsl::span<const uint8_t>(reinterpret_cast<const uint8_t*>(model_data), model_data_len);
}
return Status::OK();
});
}
@ -1318,7 +1331,9 @@ common::Status InferenceSession::Initialize() {
is_inited_ = true;
// we don't directly use the ORT format bytes currently, so free those now
std::vector<uint8_t>().swap(ort_format_model_bytes_);
// TODO, we may need to keep the bytes if we are using the offset directly in the initializers
ort_format_model_bytes_ = gsl::span<const uint8_t>();
std::vector<uint8_t>().swap(ort_format_model_bytes_data_holder_);
// and log telemetry
bool model_has_fp16_inputs = ModelHasFP16Inputs(graph);

View file

@ -184,10 +184,10 @@ class InferenceSession {
/**
* Filter the enabled optimizers (either transformer or rewrite rule) using optimizers_to_disable.
* For an optimizer to be enabled, it must be allowed at the current optimization level (as specified in
* session options), and NOT in optimizers_to_disable.
* session options), and NOT in optimizers_to_disable.
* This allows finer grained control of the enabled/disabled optimizations.
* Must be called before Initialize() to take effect.
*
*
* Calling this API is optional.
* @return OK if success.
*/
@ -217,12 +217,12 @@ class InferenceSession {
/**
* Load an ONNX or ORT format model.
*
* Set SessionOptions session config value ORT_SESSION_OPTIONS_CONFIG_LOAD_MODEL_FORMAT to 'ORT' or 'ONNX' to
* Set SessionOptions session config value ORT_SESSION_OPTIONS_CONFIG_LOAD_MODEL_FORMAT to 'ORT' or 'ONNX' to
* explicitly choose model format.
*
* If format is not explicitly specified and filename ends in '.ort' it will be inferred to be an ORT format model.
* All other files are assumed to be in ONNX format.
*
*
* @param model_uri absolute path of the model file.
* @return OK if success.
*/
@ -233,11 +233,11 @@ class InferenceSession {
/**
* Load an ONNX or ORT format model.
*
* Set SessionOptions session config value ORT_SESSION_OPTIONS_CONFIG_LOAD_MODEL_FORMAT to 'ORT' or 'ONNX' to
* Set SessionOptions session config value ORT_SESSION_OPTIONS_CONFIG_LOAD_MODEL_FORMAT to 'ORT' or 'ONNX' to
* explicitly choose model format.
*
* If format is not explicitly specified the model format will be inferred from the bytes, defaulting to ONNX.
*
*
* @param model_data Model data buffer
* @param model_data_len Model data buffer size
* @return OK if success.
@ -309,7 +309,7 @@ class InferenceSession {
#ifdef ENABLE_TRAINING
/**
* Partially run a pre-loaded and pre-initialized model.
* @param run_options run options.
* @param run_options run options.
* @param feeds inputs owned by client code and should not be changed during
* execution of this function.
* @param fetches outputs produced after the execution of this function.
@ -430,9 +430,9 @@ class InferenceSession {
}
/**
* Add a PrepackedWeightsContainer instance to the session so as to store the pre-packed weights
* Add a PrepackedWeightsContainer instance to the session so as to store the pre-packed weights
* of shared initializers to be shared across sessions.
* @param prepacked_weights_container PrepackedWeightsContainer instance
* @param prepacked_weights_container PrepackedWeightsContainer instance
*/
Status AddPrePackedWeightsContainer(PrepackedWeightsContainer* prepacked_weights_container);
@ -704,13 +704,31 @@ class InferenceSession {
bool is_model_proto_parsed_ = false;
const Environment& environment_;
// Bytes from an ORT format model.
// We store them currently to make the Load + Initialize behave the same way as for an ONNX model
// as we need some of the bytes for the Load (create the Model) and some for the Initialize (create SessionState).
// View of the bytes from an ORT format model.
// If the session is started with an input byte array containing model data, and the caller
// specifies that ORT should use the model bytes directly by setting the session config option
// "session.use_ort_model_bytes_directly" to "1",
// we use the byte array directly without copying it, to reduce peak memory usage.
// (Short term) This will require the user to guarantee the lifetime of the model data
// until the session is created.
// (Longer term) If we are going to use the memory offsets directly for initializers, the model data
// should be alive until the InferenceSession goes away.
// If the session is started with an input byte array containing model data and the caller does not
// specify that ORT should use the model bytes directly,
// or the session is started with a model_uri,
// we currently store the bytes in ort_format_model_bytes_data_holder_ to make the Load + Initialize
// behave the same way as for an ONNX model, as we need some of the bytes for the Load (create the Model)
// and some for the Initialize (create SessionState).
// Short term we free them after Initialize.
// Longer term we may want to directly refer to offsets in this buffer for initializers so we don't need to copy
// those into new OrtValue instances, at which point we won't free them until the InferenceSession goes away.
std::vector<uint8_t> ort_format_model_bytes_;
gsl::span<const uint8_t> ort_format_model_bytes_;
// This holds the actual model data.
// If the session is started with an input byte array containing model data, and the caller
// specifies that ORT should use the model bytes directly by setting the session config option
// "session.use_ort_model_bytes_directly" to "1", this will be empty.
std::vector<uint8_t> ort_format_model_bytes_data_holder_;
std::shared_ptr<onnxruntime::AllocatorManager> allocator_manager_;

View file

@ -37,13 +37,19 @@ struct OrtModelTestInfo {
std::function<void(const std::vector<OrtValue>&)> output_verifier;
std::vector<std::pair<std::string, std::string>> configs;
bool run_use_buffer{false};
bool disable_copy_ort_buffer{false};
};
static void RunOrtModel(const OrtModelTestInfo& test_info) {
SessionOptions so;
so.session_logid = test_info.logid;
for (const auto& config : test_info.configs)
for (const auto& config : test_info.configs) {
so.config_options.AddConfigEntry(config.first.c_str(), config.second.c_str());
}
if (test_info.disable_copy_ort_buffer) {
so.config_options.AddConfigEntry(kOrtSessionOptionsConfigUseORTModelBytesDirectly, "1");
}
std::vector<char> model_data;
InferenceSessionWrapper session_object{so, GetEnvironment()};
@ -448,6 +454,14 @@ TEST(OrtModelOnlyTests, LoadOrtFormatModelFromBuffer) {
RunOrtModel(test_info);
}
// Loads the ORT format model from an in-memory buffer rather than a file path, and
// requests that the session use that buffer directly instead of copying it at
// session creation.
TEST(OrtModelOnlyTests, LoadOrtFormatModelFromBufferNoCopy) {
  OrtModelTestInfo info = GetTestInfoForLoadOrtFormatModel();
  // Makes RunOrtModel set kOrtSessionOptionsConfigUseORTModelBytesDirectly to "1".
  info.disable_copy_ort_buffer = true;
  // Feed the session from a byte buffer rather than a model path.
  info.run_use_buffer = true;
  RunOrtModel(info);
}
#if !defined(DISABLE_ML_OPS)
// test that we can deserialize and run a previously saved ORT format model
// for a model with sequence and map outputs
@ -502,6 +516,14 @@ TEST(OrtModelOnlyTests, LoadOrtFormatModelMLOpsFromBuffer) {
RunOrtModel(test_info);
}
// Loads the ML-ops ORT format model from an in-memory buffer rather than a file path,
// and requests that the session use that buffer directly instead of copying it at
// session creation.
TEST(OrtModelOnlyTests, LoadOrtFormatModelMLOpsFromBufferNoCopy) {
  OrtModelTestInfo info = GetTestInfoForLoadOrtFormatModelMLOps();
  // Makes RunOrtModel set kOrtSessionOptionsConfigUseORTModelBytesDirectly to "1".
  info.disable_copy_ort_buffer = true;
  // Feed the session from a byte buffer rather than a model path.
  info.run_use_buffer = true;
  RunOrtModel(info);
}
#endif // !defined(DISABLE_ML_OPS)
} // namespace test