mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-22 22:01:08 +00:00
Flag for tensor memory re-use in allocation planner. (#7359)
This commit is contained in:
parent
96cdc65d57
commit
6dda1e0681
11 changed files with 38 additions and 7 deletions
|
|
@ -389,6 +389,9 @@ class PlannerImpl {
|
|||
|
||||
// Find if freelist contains a buffer of the same size as output_arg
|
||||
bool FindReusableTensor(const onnxruntime::NodeArg& output_arg, OrtValueIndex* reusable_tensor) {
|
||||
if(!context_.GetEnableMemoryReuse()) {
|
||||
return false;
|
||||
}
|
||||
auto p_required_buffer_shape = context_.GetShape(output_arg);
|
||||
if (nullptr == p_required_buffer_shape || p_required_buffer_shape->dim_size() == 0) return false;
|
||||
auto& required_memory_info = AllocPlan(output_arg.Name()).location;
|
||||
|
|
|
|||
|
|
@ -30,13 +30,16 @@ class ISequentialPlannerContext {
|
|||
virtual bool IsParallelExecutionEnabled() const { return false; }
|
||||
|
||||
virtual ExecutionOrder GetExecutionOrder() const { return ExecutionOrder::DEFAULT; }
|
||||
|
||||
virtual bool GetEnableMemoryReuse() const { return true; }
|
||||
};
|
||||
|
||||
class SequentialPlannerContext : public ISequentialPlannerContext {
|
||||
public:
|
||||
SequentialPlannerContext(ExecutionMode execution_mode, ExecutionOrder execution_order)
|
||||
SequentialPlannerContext(ExecutionMode execution_mode, ExecutionOrder execution_order, bool enable_memory_reuse)
|
||||
: execution_mode_(execution_mode),
|
||||
exection_order_(execution_order) {
|
||||
exection_order_(execution_order),
|
||||
enable_memory_reuse_(enable_memory_reuse) {
|
||||
}
|
||||
|
||||
const ONNX_NAMESPACE::TensorShapeProto* GetShape(const onnxruntime::NodeArg& arg) const override {
|
||||
|
|
@ -47,9 +50,12 @@ class SequentialPlannerContext : public ISequentialPlannerContext {
|
|||
|
||||
ExecutionOrder GetExecutionOrder() const override { return exection_order_; }
|
||||
|
||||
bool GetEnableMemoryReuse() const override { return enable_memory_reuse_; }
|
||||
|
||||
private:
|
||||
ExecutionMode execution_mode_ = ExecutionMode::ORT_SEQUENTIAL;
|
||||
ExecutionOrder exection_order_ = ExecutionOrder::DEFAULT;
|
||||
bool enable_memory_reuse_ = true;
|
||||
};
|
||||
|
||||
class SequentialPlanner {
|
||||
|
|
|
|||
|
|
@ -64,6 +64,11 @@ struct SessionOptions {
|
|||
// See class 'OrtValuePatternPlanner'.
|
||||
bool enable_mem_pattern = true;
|
||||
|
||||
// Enable memory resue in memory planning. Allows to reuse tensor buffer between tensors if they are of
|
||||
// the same size. The issue with this is it can lead to memory being held for longer than needed and
|
||||
// can impact peak memory consumption.
|
||||
bool enable_mem_reuse = true;
|
||||
|
||||
// enable the memory arena on CPU
|
||||
// Arena may pre-allocate memory for future usage.
|
||||
// set this option to false if you don't want it.
|
||||
|
|
|
|||
|
|
@ -580,6 +580,8 @@ Status SessionState::UpdateMemoryPatternGroupCache(const std::vector<std::refere
|
|||
|
||||
bool SessionState::GetEnableMemoryPattern() const { return enable_mem_pattern_; }
|
||||
|
||||
bool SessionState::GetEnableMemoryReuse() const { return enable_mem_reuse_; }
|
||||
|
||||
common::Status SessionState::AddInputNameToNodeInfoMapping(const std::string& input_name, const NodeInfo& node_info) {
|
||||
// Graph partitioning should ensure an input is only consumed from one device. Copy nodes should have been inserted
|
||||
// to handle a scenario where an input is required on different devices by different nodes. Validate that.
|
||||
|
|
@ -998,7 +1000,7 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_
|
|||
});
|
||||
}
|
||||
|
||||
SequentialPlannerContext context(session_options.execution_mode, session_options.execution_order);
|
||||
SequentialPlannerContext context(session_options.execution_mode, session_options.execution_order, session_options.enable_mem_reuse);
|
||||
ORT_RETURN_IF_ERROR(SequentialPlanner::CreatePlan(parent_node, *graph_viewer_, valid_outer_scope_node_args,
|
||||
execution_providers_, kernel_create_info_map_,
|
||||
ort_value_name_idx_map_, context, p_seq_exec_plan_));
|
||||
|
|
|
|||
|
|
@ -86,7 +86,8 @@ class SessionState {
|
|||
const DataTransferManager& data_transfer_mgr,
|
||||
const logging::Logger& logger,
|
||||
profiling::Profiler& profiler,
|
||||
bool use_deterministic_compute = false)
|
||||
bool use_deterministic_compute = false,
|
||||
bool enable_mem_reuse = true)
|
||||
: graph_(graph),
|
||||
execution_providers_(execution_providers),
|
||||
logger_(logger),
|
||||
|
|
@ -95,7 +96,8 @@ class SessionState {
|
|||
thread_pool_(thread_pool),
|
||||
inter_op_thread_pool_(inter_op_thread_pool),
|
||||
data_transfer_mgr_(data_transfer_mgr),
|
||||
use_deterministic_compute_(use_deterministic_compute) {
|
||||
use_deterministic_compute_(use_deterministic_compute),
|
||||
enable_mem_reuse_(enable_mem_reuse) {
|
||||
SetupAllocators();
|
||||
}
|
||||
|
||||
|
|
@ -212,6 +214,12 @@ class SessionState {
|
|||
*/
|
||||
bool GetEnableMemoryPattern() const;
|
||||
|
||||
/**
|
||||
Get enable memory re-use flag.
|
||||
*/
|
||||
|
||||
bool GetEnableMemoryReuse() const;
|
||||
|
||||
/**
|
||||
Update enable_mem_pattern_ flag according to the presence of graph inputs' shape
|
||||
If any one of the graph input is shapeless, enable_mem_pattern_ will be set to false
|
||||
|
|
@ -438,7 +446,7 @@ class SessionState {
|
|||
const DataTransferManager& data_transfer_mgr_;
|
||||
|
||||
bool use_deterministic_compute_;
|
||||
|
||||
bool enable_mem_reuse_;
|
||||
std::unique_ptr<NodeIndexInfo> node_index_info_;
|
||||
std::multimap<int, std::unique_ptr<FeedsFetchesManager>> cached_feeds_fetches_managers_;
|
||||
|
||||
|
|
|
|||
|
|
@ -1186,7 +1186,8 @@ common::Status InferenceSession::Initialize() {
|
|||
data_transfer_mgr_,
|
||||
*session_logger_,
|
||||
session_profiler_,
|
||||
session_options_.use_deterministic_compute);
|
||||
session_options_.use_deterministic_compute,
|
||||
session_options_.enable_mem_reuse);
|
||||
|
||||
onnxruntime::Graph& graph = model_->MainGraph();
|
||||
|
||||
|
|
|
|||
|
|
@ -1613,6 +1613,8 @@ Serialized model format will default to ONNX unless:
|
|||
)pbdoc")
|
||||
.def_readwrite("enable_mem_pattern", &PySessionOptions::enable_mem_pattern,
|
||||
R"pbdoc(Enable the memory pattern optimization. Default is true.)pbdoc")
|
||||
.def_readwrite("enable_mem_reuse", &PySessionOptions::enable_mem_reuse,
|
||||
R"pbdoc(Enable the memory reuse optimization. Default is true.)pbdoc")
|
||||
.def_readwrite("logid", &PySessionOptions::session_logid,
|
||||
R"pbdoc(Logger id to use for session output.)pbdoc")
|
||||
.def_readwrite("log_severity_level", &PySessionOptions::session_log_severity_level,
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ static SessionOptions session_options = {
|
|||
false, //enable_profiling
|
||||
ORT_TSTR(""), //optimized_model_filepath
|
||||
true, //enable_mem_pattern
|
||||
true, //enable_mem_reuse
|
||||
true, //enable_cpu_mem_arena
|
||||
ORT_TSTR("onnxruntime_profile_"), //profile_file_prefix
|
||||
"", //session_logid
|
||||
|
|
|
|||
|
|
@ -90,6 +90,7 @@ int main(int argc, char* argv[]) {
|
|||
false, //enable_profiling
|
||||
ORT_TSTR(""), //optimized_model_filepath
|
||||
true, //enable_mem_pattern
|
||||
true, //enable_mem_reuse
|
||||
true, //enable_cpu_mem_arena
|
||||
ORT_TSTR("onnxruntime_profile_"), //profile_file_prefix
|
||||
"", //session_logid
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ static SessionOptions SESSION_OPTION = {
|
|||
false, //enable_profiling
|
||||
ORT_TSTR(""), //optimized_model_filepath
|
||||
true, //enable_mem_pattern
|
||||
true, //enable_mem_reuse
|
||||
true, //enable_cpu_mem_arena
|
||||
ORT_TSTR("onnxruntime_profile_"), //profile_file_prefix
|
||||
"", //session_logid
|
||||
|
|
|
|||
|
|
@ -168,6 +168,7 @@ class GraphExecutionManager(ABC):
|
|||
|
||||
session_options = onnxruntime.SessionOptions()
|
||||
session_options.enable_mem_pattern = False
|
||||
session_options.enable_mem_reuse = False
|
||||
session_options.use_deterministic_compute = False
|
||||
# default to PRIORITY_BASED execution order
|
||||
session_options.execution_order = onnxruntime.ExecutionOrder.PRIORITY_BASED
|
||||
|
|
|
|||
Loading…
Reference in a new issue