diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index d414f55641..455434628c 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -389,6 +389,9 @@ class PlannerImpl { // Find if freelist contains a buffer of the same size as output_arg bool FindReusableTensor(const onnxruntime::NodeArg& output_arg, OrtValueIndex* reusable_tensor) { + if(!context_.GetEnableMemoryReuse()) { + return false; + } auto p_required_buffer_shape = context_.GetShape(output_arg); if (nullptr == p_required_buffer_shape || p_required_buffer_shape->dim_size() == 0) return false; auto& required_memory_info = AllocPlan(output_arg.Name()).location; diff --git a/onnxruntime/core/framework/allocation_planner.h b/onnxruntime/core/framework/allocation_planner.h index 3cfe12b930..3bafc05a9d 100644 --- a/onnxruntime/core/framework/allocation_planner.h +++ b/onnxruntime/core/framework/allocation_planner.h @@ -30,13 +30,16 @@ class ISequentialPlannerContext { virtual bool IsParallelExecutionEnabled() const { return false; } virtual ExecutionOrder GetExecutionOrder() const { return ExecutionOrder::DEFAULT; } + + virtual bool GetEnableMemoryReuse() const { return true; } }; class SequentialPlannerContext : public ISequentialPlannerContext { public: - SequentialPlannerContext(ExecutionMode execution_mode, ExecutionOrder execution_order) + SequentialPlannerContext(ExecutionMode execution_mode, ExecutionOrder execution_order, bool enable_memory_reuse) : execution_mode_(execution_mode), - exection_order_(execution_order) { + exection_order_(execution_order), + enable_memory_reuse_(enable_memory_reuse) { } const ONNX_NAMESPACE::TensorShapeProto* GetShape(const onnxruntime::NodeArg& arg) const override { @@ -47,9 +50,12 @@ class SequentialPlannerContext : public ISequentialPlannerContext { ExecutionOrder GetExecutionOrder() const override { return exection_order_; } + bool 
GetEnableMemoryReuse() const override { return enable_memory_reuse_; } + private: ExecutionMode execution_mode_ = ExecutionMode::ORT_SEQUENTIAL; ExecutionOrder exection_order_ = ExecutionOrder::DEFAULT; + bool enable_memory_reuse_ = true; }; class SequentialPlanner { diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h index 055f2f6906..9e7592e844 100644 --- a/onnxruntime/core/framework/session_options.h +++ b/onnxruntime/core/framework/session_options.h @@ -64,6 +64,11 @@ struct SessionOptions { // See class 'OrtValuePatternPlanner'. bool enable_mem_pattern = true; + // Enable memory reuse in memory planning. Allows a tensor buffer to be reused between tensors if they are of + // the same size. The downside is that it can lead to memory being held for longer than needed and + // can impact peak memory consumption. + bool enable_mem_reuse = true; + // enable the memory arena on CPU // Arena may pre-allocate memory for future usage. // set this option to false if you don't want it. 
diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 69cb0377b4..6022e2968a 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -580,6 +580,8 @@ Status SessionState::UpdateMemoryPatternGroupCache(const std::vector node_index_info_; std::multimap> cached_feeds_fetches_managers_; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index a3a2f93f1b..683e028b13 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1186,7 +1186,8 @@ common::Status InferenceSession::Initialize() { data_transfer_mgr_, *session_logger_, session_profiler_, - session_options_.use_deterministic_compute); + session_options_.use_deterministic_compute, + session_options_.enable_mem_reuse); onnxruntime::Graph& graph = model_->MainGraph(); diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 64071a89bc..ebd296747a 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1613,6 +1613,8 @@ Serialized model format will default to ONNX unless: )pbdoc") .def_readwrite("enable_mem_pattern", &PySessionOptions::enable_mem_pattern, R"pbdoc(Enable the memory pattern optimization. Default is true.)pbdoc") + .def_readwrite("enable_mem_reuse", &PySessionOptions::enable_mem_reuse, + R"pbdoc(Enable the memory reuse optimization. 
Default is true.)pbdoc") .def_readwrite("logid", &PySessionOptions::session_logid, R"pbdoc(Logger id to use for session output.)pbdoc") .def_readwrite("log_severity_level", &PySessionOptions::session_log_severity_level, diff --git a/orttraining/orttraining/models/bert/main.cc b/orttraining/orttraining/models/bert/main.cc index dc4d3044c5..f639a43bb2 100644 --- a/orttraining/orttraining/models/bert/main.cc +++ b/orttraining/orttraining/models/bert/main.cc @@ -39,6 +39,7 @@ static SessionOptions session_options = { false, //enable_profiling ORT_TSTR(""), //optimized_model_filepath true, //enable_mem_pattern + true, //enable_mem_reuse true, //enable_cpu_mem_arena ORT_TSTR("onnxruntime_profile_"), //profile_file_prefix "", //session_logid diff --git a/orttraining/orttraining/models/pipeline_poc/main.cc b/orttraining/orttraining/models/pipeline_poc/main.cc index 04a459f171..6d44578671 100644 --- a/orttraining/orttraining/models/pipeline_poc/main.cc +++ b/orttraining/orttraining/models/pipeline_poc/main.cc @@ -90,6 +90,7 @@ int main(int argc, char* argv[]) { false, //enable_profiling ORT_TSTR(""), //optimized_model_filepath true, //enable_mem_pattern + true, //enable_mem_reuse true, //enable_cpu_mem_arena ORT_TSTR("onnxruntime_profile_"), //profile_file_prefix "", //session_logid diff --git a/orttraining/orttraining/models/runner/training_runner.cc b/orttraining/orttraining/models/runner/training_runner.cc index c908490066..090153b69e 100644 --- a/orttraining/orttraining/models/runner/training_runner.cc +++ b/orttraining/orttraining/models/runner/training_runner.cc @@ -39,6 +39,7 @@ static SessionOptions SESSION_OPTION = { false, //enable_profiling ORT_TSTR(""), //optimized_model_filepath true, //enable_mem_pattern + true, //enable_mem_reuse true, //enable_cpu_mem_arena ORT_TSTR("onnxruntime_profile_"), //profile_file_prefix "", //session_logid diff --git a/orttraining/orttraining/python/training/_ortmodule_graph_execution_manager.py 
b/orttraining/orttraining/python/training/_ortmodule_graph_execution_manager.py index 2b71d39491..0623c092ad 100644 --- a/orttraining/orttraining/python/training/_ortmodule_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/_ortmodule_graph_execution_manager.py @@ -168,6 +168,7 @@ class GraphExecutionManager(ABC): session_options = onnxruntime.SessionOptions() session_options.enable_mem_pattern = False + session_options.enable_mem_reuse = False session_options.use_deterministic_compute = False # default to PRIORITY_BASED execution order session_options.execution_order = onnxruntime.ExecutionOrder.PRIORITY_BASED