Incrementally free initializers while saving to OrtValue instances (#12485)

* Free initializer TensorProto instances as they're converted to OrtValue to reduce peak memory usage.

Co-authored-by: Pranav Sharma <prs@microsoft.com>
This commit is contained in:
Scott McKay 2022-08-09 10:59:10 +10:00 committed by GitHub
parent 730240d2a5
commit 56bd96a3f5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 40 additions and 22 deletions

View file

@ -1468,8 +1468,13 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_
Env::Default(), graph_location, *graph_viewer_,
execution_providers_.GetDefaultCpuAllocator(),
ort_value_name_idx_map_, initializer_allocation_order, *tensor_allocator,
[this](int idx, const OrtValue& value, const OrtCallback& d, bool constant, bool sparse) -> Status {
return AddInitializedTensor(idx, value, &d, constant, sparse);
[this, remove_initializers](const std::string& name, int idx, const OrtValue& value, const OrtCallback& d,
bool constant, bool sparse) -> Status {
ORT_RETURN_IF_ERROR(AddInitializedTensor(idx, value, &d, constant, sparse));
if (remove_initializers) {
graph_.RemoveInitializedTensor(name);
}
return Status::OK();
},
logger_, data_transfer_mgr_, *p_seq_exec_plan_, session_options, memory_profile_func));

View file

@ -230,11 +230,13 @@ common::Status SaveInitializedTensors(
};
// 1. first plan the memory
const onnxruntime::InitializedTensorSet& initialized_tensor_set = graph.GetAllInitializedTensors();
const InitializedTensorSet& initialized_tensor_set = graph.GetAllInitializedTensors();
InlinedHashMap<int, const ONNX_NAMESPACE::TensorProto*> id_to_initialized_tensor;
id_to_initialized_tensor.reserve(initialized_tensor_set.size());
InlinedHashSet<int> user_supplied_initializer_ids; // set containing the ort value ids of all user supplied initializers
id_to_initialized_tensor.reserve(initialized_tensor_set.size());
user_supplied_initializer_ids.reserve(initialized_tensor_set.size());
for (const auto& entry : initialized_tensor_set) {
int ort_value_index;
ORT_RETURN_IF_ERROR(ort_value_name_idx_map.GetIdx(entry.first, ort_value_index));
@ -291,7 +293,13 @@ common::Status SaveInitializedTensors(
// 3. create weight tensors based on weights buffer
for (const auto& entry : id_to_initialized_tensor) {
int ort_value_index = entry.first;
const char* name = (entry.second->name().empty()) ? "" : entry.second->name().c_str();
const std::string& name = entry.second->name();
if (name.empty()) {
LOGS(logger, INFO) << "Skipping entry for missing optional value at idx " << ort_value_index;
continue;
}
OrtValue ort_value;
if (user_supplied_initializer_ids.find(entry.first) != user_supplied_initializer_ids.end()) {
@ -317,17 +325,19 @@ common::Status SaveInitializedTensors(
}
}
// 'name' is a reference to a string within the TensorProto that save_tensor_func may free
// so we need to output this message prior to calling save_tensor_func
VLOGS(logger, 1) << "Adding weight with name : " << name << " with index: " << ort_value_index;
// any outer scope value is shadowed by a local value and can't override it.
// due to that check_outer_scope is false
const bool constant = graph.IsConstantInitializer(name, /* check_outer_scope */ false);
#if !defined(DISABLE_SPARSE_TENSORS)
const bool sparse = graph.GetGraph().IsSparseInitializer(name);
ORT_RETURN_IF_ERROR(save_tensor_func(ort_value_index, ort_value, deleter, constant, sparse));
ORT_RETURN_IF_ERROR(save_tensor_func(name, ort_value_index, ort_value, deleter, constant, sparse));
#else
ORT_RETURN_IF_ERROR(save_tensor_func(ort_value_index, ort_value, deleter, constant, false));
ORT_RETURN_IF_ERROR(save_tensor_func(name, ort_value_index, ort_value, deleter, constant, false));
#endif
VLOGS(logger, 1) << "Added weight with name : " << name << " with index: " << ort_value_index;
}
LOGS(logger, INFO) << "Done saving initialized tensors";

View file

@ -30,9 +30,10 @@ class Logger;
}
namespace session_state_utils {
using SaveTensorFunction = std::function<Status(int idx, const OrtValue& value, const OrtCallback& d,
bool constant, bool sparse)>;
using SaveTensorFunction = std::function<Status(const std::string& name, int idx, const OrtValue& value,
const OrtCallback& d, bool constant, bool sparse)>;
using MemoryProfileFunction = std::function<void(ITensorAllocator& planner)>;
common::Status SaveInitializedTensors(
const Env& env, const std::basic_string<PATH_CHAR_TYPE>& graph_loc,
const GraphViewer& graph, const AllocatorPtr& default_cpu_memory_info,
@ -44,6 +45,7 @@ common::Status SaveInitializedTensors(
const ExecutionPlanBase& exec_plan,
const SessionOptions& session_options,
const MemoryProfileFunction& memory_profile_func);
common::Status SaveInputOutputNamesToNodeMapping(const GraphViewer& graph,
SessionState& session_state,
gsl::span<const NodeArg* const> implicit_inputs);

View file

@ -9,12 +9,12 @@ common::Status SimpleTensorAllocator::Trace(int /*id*/, const ONNX_NAMESPACE::Te
return Status::OK();
}
common::Status SimpleTensorAllocator::GetPreallocatedBuffer(int ort_value_index, const char* /*name*/,
common::Status SimpleTensorAllocator::GetPreallocatedBuffer(int ort_value_index, const std::string& /*name*/,
std::optional<MemBuffer>& /*buf_out*/,
AllocatorPtr& alloc_out) {
const struct OrtMemoryInfo& location = seq_plan_.GetLocation(ort_value_index);
// just return allocator and let others handle it.
alloc_out = GetAllocator(location);
return Status::OK();
// just return allocator and let others handle it.
alloc_out = GetAllocator(location);
return Status::OK();
}
} // namespace onnxruntime

View file

@ -30,7 +30,8 @@ class SimpleTensorAllocator : public ITensorAllocator {
planned_memory_sizes_in_byte.clear();
return Status::OK();
}
common::Status GetPreallocatedBuffer(int ort_value_index, const char* name, std::optional<MemBuffer>& buf_out, AllocatorPtr& alloc_out) override;
common::Status GetPreallocatedBuffer(int ort_value_index, const std::string& name, std::optional<MemBuffer>& buf_out,
AllocatorPtr& alloc_out) override;
common::Status Trace(int id, const ONNX_NAMESPACE::TensorProto* value) override;
const MemoryPatternGroup& GetMemPatterns() override {
return mem_patterns_;

View file

@ -41,14 +41,14 @@ class ITensorAllocator {
* or, in the case of not reserved tensor, returns an allocator so that
* the caller can take care of the dynamic buffer allocation.
* buf_out and alloc_out, one and only one can be non-null
*
* @param ort_value_index [In] int id of the tensor
*
* @param ort_value_index [In] int id of the tensor
* @param name [In] name of the tensor
* @param buf_out [Out] pre reserved buffer, if not null
* @param alloc_out [Out] allocator based on tensor's location, if not null
* @return
*/
virtual common::Status GetPreallocatedBuffer(int ort_value_index, const char* name,
* @return
*/
virtual common::Status GetPreallocatedBuffer(int ort_value_index, const std::string& name,
std::optional<MemBuffer>& buf_out,
AllocatorPtr& alloc_out) = 0;

View file

@ -74,7 +74,7 @@ class TensorAllocatorWithMemPattern : public ITensorAllocator {
return Status::OK();
}
common::Status GetPreallocatedBuffer(int ort_value_index, const char* name,
common::Status GetPreallocatedBuffer(int ort_value_index, const std::string& name,
std::optional<MemBuffer>& buf_out, AllocatorPtr& alloc_out) override {
if (!is_sealed_) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Internal error.");