diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index b7aacdc315..5f25d23225 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -1468,8 +1468,13 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string Status { - return AddInitializedTensor(idx, value, &d, constant, sparse); + [this, remove_initializers](const std::string& name, int idx, const OrtValue& value, const OrtCallback& d, + bool constant, bool sparse) -> Status { + ORT_RETURN_IF_ERROR(AddInitializedTensor(idx, value, &d, constant, sparse)); + if (remove_initializers) { + graph_.RemoveInitializedTensor(name); + } + return Status::OK(); }, logger_, data_transfer_mgr_, *p_seq_exec_plan_, session_options, memory_profile_func)); diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc index d305d8463d..3bb41ad251 100644 --- a/onnxruntime/core/framework/session_state_utils.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -230,11 +230,13 @@ common::Status SaveInitializedTensors( }; // 1. first plan the memory - const onnxruntime::InitializedTensorSet& initialized_tensor_set = graph.GetAllInitializedTensors(); + const InitializedTensorSet& initialized_tensor_set = graph.GetAllInitializedTensors(); InlinedHashMap id_to_initialized_tensor; - id_to_initialized_tensor.reserve(initialized_tensor_set.size()); InlinedHashSet user_supplied_initializer_ids; // set containing the ort value ids of all user supplied initializers + + id_to_initialized_tensor.reserve(initialized_tensor_set.size()); user_supplied_initializer_ids.reserve(initialized_tensor_set.size()); + for (const auto& entry : initialized_tensor_set) { int ort_value_index; ORT_RETURN_IF_ERROR(ort_value_name_idx_map.GetIdx(entry.first, ort_value_index)); @@ -291,7 +293,13 @@ common::Status SaveInitializedTensors( // 3. create weight tensors based on weights buffer for (const auto& entry : id_to_initialized_tensor) { int ort_value_index = entry.first; - const char* name = (entry.second->name().empty()) ? "" : entry.second->name().c_str(); + const std::string& name = entry.second->name(); + + if (name.empty()) { + LOGS(logger, INFO) << "Skipping entry for missing optional value at idx " << ort_value_index; + continue; + } + OrtValue ort_value; if (user_supplied_initializer_ids.find(entry.first) != user_supplied_initializer_ids.end()) { @@ -317,17 +325,19 @@ common::Status SaveInitializedTensors( } } + // 'name' is a reference to a string within the TensorProto that save_tensor_func may free + // so we need to output this message prior to calling save_tensor_func + VLOGS(logger, 1) << "Adding weight with name : " << name << " with index: " << ort_value_index; + // any outer scope value is shadowed by a local value and can't override it. // due to that check_outer_scope is false const bool constant = graph.IsConstantInitializer(name, /* check_outer_scope */ false); #if !defined(DISABLE_SPARSE_TENSORS) const bool sparse = graph.GetGraph().IsSparseInitializer(name); - ORT_RETURN_IF_ERROR(save_tensor_func(ort_value_index, ort_value, deleter, constant, sparse)); + ORT_RETURN_IF_ERROR(save_tensor_func(name, ort_value_index, ort_value, deleter, constant, sparse)); #else - ORT_RETURN_IF_ERROR(save_tensor_func(ort_value_index, ort_value, deleter, constant, false)); + ORT_RETURN_IF_ERROR(save_tensor_func(name, ort_value_index, ort_value, deleter, constant, false)); #endif - - VLOGS(logger, 1) << "Added weight with name : " << name << " with index: " << ort_value_index; } LOGS(logger, INFO) << "Done saving initialized tensors"; diff --git a/onnxruntime/core/framework/session_state_utils.h b/onnxruntime/core/framework/session_state_utils.h index a30dc9585c..7dbf91de44 100644 --- a/onnxruntime/core/framework/session_state_utils.h +++ b/onnxruntime/core/framework/session_state_utils.h @@ -30,9 +30,10 @@ class Logger; } namespace session_state_utils { -using SaveTensorFunction = std::function; +using SaveTensorFunction = std::function; using MemoryProfileFunction = std::function; + common::Status SaveInitializedTensors( const Env& env, const std::basic_string& graph_loc, const GraphViewer& graph, const AllocatorPtr& default_cpu_memory_info, @@ -44,6 +45,7 @@ common::Status SaveInitializedTensors( const ExecutionPlanBase& exec_plan, const SessionOptions& session_options, const MemoryProfileFunction& memory_profile_func); + common::Status SaveInputOutputNamesToNodeMapping(const GraphViewer& graph, SessionState& session_state, gsl::span implicit_inputs); diff --git a/onnxruntime/core/framework/simple_tensor_allocator.cc b/onnxruntime/core/framework/simple_tensor_allocator.cc index f76d4aaec2..3d8a935c3b 100644 --- a/onnxruntime/core/framework/simple_tensor_allocator.cc +++ b/onnxruntime/core/framework/simple_tensor_allocator.cc @@ -9,12 +9,12 @@ common::Status SimpleTensorAllocator::Trace(int /*id*/, const ONNX_NAMESPACE::Te return Status::OK(); } -common::Status SimpleTensorAllocator::GetPreallocatedBuffer(int ort_value_index, const char* /*name*/, +common::Status SimpleTensorAllocator::GetPreallocatedBuffer(int ort_value_index, const std::string& /*name*/, std::optional& /*buf_out*/, AllocatorPtr& alloc_out) { const struct OrtMemoryInfo& location = seq_plan_.GetLocation(ort_value_index); - // just return allocator and let others handle it. - alloc_out = GetAllocator(location); - return Status::OK(); + // just return allocator and let others handle it. + alloc_out = GetAllocator(location); + return Status::OK(); } } // namespace onnxruntime diff --git a/onnxruntime/core/framework/simple_tensor_allocator.h b/onnxruntime/core/framework/simple_tensor_allocator.h index d2d61fa4b0..ecdee8a12f 100644 --- a/onnxruntime/core/framework/simple_tensor_allocator.h +++ b/onnxruntime/core/framework/simple_tensor_allocator.h @@ -30,7 +30,8 @@ class SimpleTensorAllocator : public ITensorAllocator { planned_memory_sizes_in_byte.clear(); return Status::OK(); } - common::Status GetPreallocatedBuffer(int ort_value_index, const char* name, std::optional& buf_out, AllocatorPtr& alloc_out) override; + common::Status GetPreallocatedBuffer(int ort_value_index, const std::string& name, std::optional& buf_out, + AllocatorPtr& alloc_out) override; common::Status Trace(int id, const ONNX_NAMESPACE::TensorProto* value) override; const MemoryPatternGroup& GetMemPatterns() override { return mem_patterns_; diff --git a/onnxruntime/core/framework/tensor_allocator.h b/onnxruntime/core/framework/tensor_allocator.h index cd5610b2a2..2ab683bae7 100644 --- a/onnxruntime/core/framework/tensor_allocator.h +++ b/onnxruntime/core/framework/tensor_allocator.h @@ -41,14 +41,14 @@ class ITensorAllocator { * or, in the case of not reserved tensor, returns an allocator so that * the caller can take care of the dynamic buffer allocation. * buf_out and alloc_out, one and only one can be non-null - * - * @param ort_value_index [In] int id of the tensor + * + * @param ort_value_index [In] int id of the tensor * @param name [In] name of the tensor * @param buf_out [Out] pre reserved buffer, if not null * @param alloc_out [Out] allocator based on tensor's location, if not null - * @return - */ - virtual common::Status GetPreallocatedBuffer(int ort_value_index, const char* name, + * @return + */ + virtual common::Status GetPreallocatedBuffer(int ort_value_index, const std::string& name, std::optional& buf_out, AllocatorPtr& alloc_out) = 0; diff --git a/onnxruntime/core/framework/tensor_allocator_with_mem_pattern.h b/onnxruntime/core/framework/tensor_allocator_with_mem_pattern.h index 1bc2a1469d..5764a02d9e 100644 --- a/onnxruntime/core/framework/tensor_allocator_with_mem_pattern.h +++ b/onnxruntime/core/framework/tensor_allocator_with_mem_pattern.h @@ -74,7 +74,7 @@ class TensorAllocatorWithMemPattern : public ITensorAllocator { return Status::OK(); } - common::Status GetPreallocatedBuffer(int ort_value_index, const char* name, + common::Status GetPreallocatedBuffer(int ort_value_index, const std::string& name, std::optional& buf_out, AllocatorPtr& alloc_out) override { if (!is_sealed_) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Internal error.");