mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
Incrementally free initializers while saving to OrtValue instances (#12485)
* Free initializer TensorProto instances as they're converted to OrtValue to reduce peak memory usage. Co-authored-by: Pranav Sharma <prs@microsoft.com>
This commit is contained in:
parent
730240d2a5
commit
56bd96a3f5
7 changed files with 40 additions and 22 deletions
|
|
@ -1468,8 +1468,13 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_
|
|||
Env::Default(), graph_location, *graph_viewer_,
|
||||
execution_providers_.GetDefaultCpuAllocator(),
|
||||
ort_value_name_idx_map_, initializer_allocation_order, *tensor_allocator,
|
||||
[this](int idx, const OrtValue& value, const OrtCallback& d, bool constant, bool sparse) -> Status {
|
||||
return AddInitializedTensor(idx, value, &d, constant, sparse);
|
||||
[this, remove_initializers](const std::string& name, int idx, const OrtValue& value, const OrtCallback& d,
|
||||
bool constant, bool sparse) -> Status {
|
||||
ORT_RETURN_IF_ERROR(AddInitializedTensor(idx, value, &d, constant, sparse));
|
||||
if (remove_initializers) {
|
||||
graph_.RemoveInitializedTensor(name);
|
||||
}
|
||||
return Status::OK();
|
||||
},
|
||||
logger_, data_transfer_mgr_, *p_seq_exec_plan_, session_options, memory_profile_func));
|
||||
|
||||
|
|
|
|||
|
|
@ -230,11 +230,13 @@ common::Status SaveInitializedTensors(
|
|||
};
|
||||
|
||||
// 1. first plan the memory
|
||||
const onnxruntime::InitializedTensorSet& initialized_tensor_set = graph.GetAllInitializedTensors();
|
||||
const InitializedTensorSet& initialized_tensor_set = graph.GetAllInitializedTensors();
|
||||
InlinedHashMap<int, const ONNX_NAMESPACE::TensorProto*> id_to_initialized_tensor;
|
||||
id_to_initialized_tensor.reserve(initialized_tensor_set.size());
|
||||
InlinedHashSet<int> user_supplied_initializer_ids; // set containing the ort value ids of all user supplied initializers
|
||||
|
||||
id_to_initialized_tensor.reserve(initialized_tensor_set.size());
|
||||
user_supplied_initializer_ids.reserve(initialized_tensor_set.size());
|
||||
|
||||
for (const auto& entry : initialized_tensor_set) {
|
||||
int ort_value_index;
|
||||
ORT_RETURN_IF_ERROR(ort_value_name_idx_map.GetIdx(entry.first, ort_value_index));
|
||||
|
|
@ -291,7 +293,13 @@ common::Status SaveInitializedTensors(
|
|||
// 3. create weight tensors based on weights buffer
|
||||
for (const auto& entry : id_to_initialized_tensor) {
|
||||
int ort_value_index = entry.first;
|
||||
const char* name = (entry.second->name().empty()) ? "" : entry.second->name().c_str();
|
||||
const std::string& name = entry.second->name();
|
||||
|
||||
if (name.empty()) {
|
||||
LOGS(logger, INFO) << "Skipping entry for missing optional value at idx " << ort_value_index;
|
||||
continue;
|
||||
}
|
||||
|
||||
OrtValue ort_value;
|
||||
|
||||
if (user_supplied_initializer_ids.find(entry.first) != user_supplied_initializer_ids.end()) {
|
||||
|
|
@ -317,17 +325,19 @@ common::Status SaveInitializedTensors(
|
|||
}
|
||||
}
|
||||
|
||||
// 'name' is a reference to a string within the TensorProto that save_tensor_func may free
|
||||
// so we need to output this message prior to calling save_tensor_func
|
||||
VLOGS(logger, 1) << "Adding weight with name : " << name << " with index: " << ort_value_index;
|
||||
|
||||
// any outer scope value is shadowed by a local value and can't override it.
|
||||
// due to that check_outer_scope is false
|
||||
const bool constant = graph.IsConstantInitializer(name, /* check_outer_scope */ false);
|
||||
#if !defined(DISABLE_SPARSE_TENSORS)
|
||||
const bool sparse = graph.GetGraph().IsSparseInitializer(name);
|
||||
ORT_RETURN_IF_ERROR(save_tensor_func(ort_value_index, ort_value, deleter, constant, sparse));
|
||||
ORT_RETURN_IF_ERROR(save_tensor_func(name, ort_value_index, ort_value, deleter, constant, sparse));
|
||||
#else
|
||||
ORT_RETURN_IF_ERROR(save_tensor_func(ort_value_index, ort_value, deleter, constant, false));
|
||||
ORT_RETURN_IF_ERROR(save_tensor_func(name, ort_value_index, ort_value, deleter, constant, false));
|
||||
#endif
|
||||
|
||||
VLOGS(logger, 1) << "Added weight with name : " << name << " with index: " << ort_value_index;
|
||||
}
|
||||
|
||||
LOGS(logger, INFO) << "Done saving initialized tensors";
|
||||
|
|
|
|||
|
|
@ -30,9 +30,10 @@ class Logger;
|
|||
}
|
||||
|
||||
namespace session_state_utils {
|
||||
using SaveTensorFunction = std::function<Status(int idx, const OrtValue& value, const OrtCallback& d,
|
||||
bool constant, bool sparse)>;
|
||||
using SaveTensorFunction = std::function<Status(const std::string& name, int idx, const OrtValue& value,
|
||||
const OrtCallback& d, bool constant, bool sparse)>;
|
||||
using MemoryProfileFunction = std::function<void(ITensorAllocator& planner)>;
|
||||
|
||||
common::Status SaveInitializedTensors(
|
||||
const Env& env, const std::basic_string<PATH_CHAR_TYPE>& graph_loc,
|
||||
const GraphViewer& graph, const AllocatorPtr& default_cpu_memory_info,
|
||||
|
|
@ -44,6 +45,7 @@ common::Status SaveInitializedTensors(
|
|||
const ExecutionPlanBase& exec_plan,
|
||||
const SessionOptions& session_options,
|
||||
const MemoryProfileFunction& memory_profile_func);
|
||||
|
||||
common::Status SaveInputOutputNamesToNodeMapping(const GraphViewer& graph,
|
||||
SessionState& session_state,
|
||||
gsl::span<const NodeArg* const> implicit_inputs);
|
||||
|
|
|
|||
|
|
@ -9,12 +9,12 @@ common::Status SimpleTensorAllocator::Trace(int /*id*/, const ONNX_NAMESPACE::Te
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
common::Status SimpleTensorAllocator::GetPreallocatedBuffer(int ort_value_index, const char* /*name*/,
|
||||
common::Status SimpleTensorAllocator::GetPreallocatedBuffer(int ort_value_index, const std::string& /*name*/,
|
||||
std::optional<MemBuffer>& /*buf_out*/,
|
||||
AllocatorPtr& alloc_out) {
|
||||
const struct OrtMemoryInfo& location = seq_plan_.GetLocation(ort_value_index);
|
||||
// just return allocator and let others handle it.
|
||||
alloc_out = GetAllocator(location);
|
||||
return Status::OK();
|
||||
// just return allocator and let others handle it.
|
||||
alloc_out = GetAllocator(location);
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -30,7 +30,8 @@ class SimpleTensorAllocator : public ITensorAllocator {
|
|||
planned_memory_sizes_in_byte.clear();
|
||||
return Status::OK();
|
||||
}
|
||||
common::Status GetPreallocatedBuffer(int ort_value_index, const char* name, std::optional<MemBuffer>& buf_out, AllocatorPtr& alloc_out) override;
|
||||
common::Status GetPreallocatedBuffer(int ort_value_index, const std::string& name, std::optional<MemBuffer>& buf_out,
|
||||
AllocatorPtr& alloc_out) override;
|
||||
common::Status Trace(int id, const ONNX_NAMESPACE::TensorProto* value) override;
|
||||
const MemoryPatternGroup& GetMemPatterns() override {
|
||||
return mem_patterns_;
|
||||
|
|
|
|||
|
|
@ -41,14 +41,14 @@ class ITensorAllocator {
|
|||
* or, in the case of not reserved tensor, returns an allocator so that
|
||||
* the caller can take care of the dynamic buffer allocation.
|
||||
* buf_out and alloc_out, one and only one can be non-null
|
||||
*
|
||||
* @param ort_value_index [In] int id of the tensor
|
||||
*
|
||||
* @param ort_value_index [In] int id of the tensor
|
||||
* @param name [In] name of the tensor
|
||||
* @param buf_out [Out] pre reserved buffer, if not null
|
||||
* @param alloc_out [Out] allocator based on tensor's location, if not null
|
||||
* @return
|
||||
*/
|
||||
virtual common::Status GetPreallocatedBuffer(int ort_value_index, const char* name,
|
||||
* @return
|
||||
*/
|
||||
virtual common::Status GetPreallocatedBuffer(int ort_value_index, const std::string& name,
|
||||
std::optional<MemBuffer>& buf_out,
|
||||
AllocatorPtr& alloc_out) = 0;
|
||||
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ class TensorAllocatorWithMemPattern : public ITensorAllocator {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
common::Status GetPreallocatedBuffer(int ort_value_index, const char* name,
|
||||
common::Status GetPreallocatedBuffer(int ort_value_index, const std::string& name,
|
||||
std::optional<MemBuffer>& buf_out, AllocatorPtr& alloc_out) override {
|
||||
if (!is_sealed_) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Internal error.");
|
||||
|
|
|
|||
Loading…
Reference in a new issue