Flag for tensor memory re-use in allocation planner. (#7359)

This commit is contained in:
M. Zeeshan Siddiqui 2021-04-16 17:53:25 -07:00 committed by GitHub
parent 96cdc65d57
commit 6dda1e0681
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 38 additions and 7 deletions

View file

@ -389,6 +389,9 @@ class PlannerImpl {
// Find if freelist contains a buffer of the same size as output_arg
bool FindReusableTensor(const onnxruntime::NodeArg& output_arg, OrtValueIndex* reusable_tensor) {
if(!context_.GetEnableMemoryReuse()) {
return false;
}
auto p_required_buffer_shape = context_.GetShape(output_arg);
if (nullptr == p_required_buffer_shape || p_required_buffer_shape->dim_size() == 0) return false;
auto& required_memory_info = AllocPlan(output_arg.Name()).location;

View file

@ -30,13 +30,16 @@ class ISequentialPlannerContext {
virtual bool IsParallelExecutionEnabled() const { return false; }
virtual ExecutionOrder GetExecutionOrder() const { return ExecutionOrder::DEFAULT; }
virtual bool GetEnableMemoryReuse() const { return true; }
};
class SequentialPlannerContext : public ISequentialPlannerContext {
public:
SequentialPlannerContext(ExecutionMode execution_mode, ExecutionOrder execution_order)
SequentialPlannerContext(ExecutionMode execution_mode, ExecutionOrder execution_order, bool enable_memory_reuse)
: execution_mode_(execution_mode),
exection_order_(execution_order) {
exection_order_(execution_order),
enable_memory_reuse_(enable_memory_reuse) {
}
const ONNX_NAMESPACE::TensorShapeProto* GetShape(const onnxruntime::NodeArg& arg) const override {
@ -47,9 +50,12 @@ class SequentialPlannerContext : public ISequentialPlannerContext {
ExecutionOrder GetExecutionOrder() const override { return exection_order_; }
bool GetEnableMemoryReuse() const override { return enable_memory_reuse_; }
private:
ExecutionMode execution_mode_ = ExecutionMode::ORT_SEQUENTIAL;
ExecutionOrder exection_order_ = ExecutionOrder::DEFAULT;
bool enable_memory_reuse_ = true;
};
class SequentialPlanner {

View file

@ -64,6 +64,11 @@ struct SessionOptions {
// See class 'OrtValuePatternPlanner'.
bool enable_mem_pattern = true;
// Enable memory resue in memory planning. Allows to reuse tensor buffer between tensors if they are of
// the same size. The issue with this is it can lead to memory being held for longer than needed and
// can impact peak memory consumption.
bool enable_mem_reuse = true;
// enable the memory arena on CPU
// Arena may pre-allocate memory for future usage.
// set this option to false if you don't want it.

View file

@ -580,6 +580,8 @@ Status SessionState::UpdateMemoryPatternGroupCache(const std::vector<std::refere
bool SessionState::GetEnableMemoryPattern() const { return enable_mem_pattern_; }
bool SessionState::GetEnableMemoryReuse() const { return enable_mem_reuse_; }
common::Status SessionState::AddInputNameToNodeInfoMapping(const std::string& input_name, const NodeInfo& node_info) {
// Graph partitioning should ensure an input is only consumed from one device. Copy nodes should have been inserted
// to handle a scenario where an input is required on different devices by different nodes. Validate that.
@ -998,7 +1000,7 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string<PATH_CHAR_
});
}
SequentialPlannerContext context(session_options.execution_mode, session_options.execution_order);
SequentialPlannerContext context(session_options.execution_mode, session_options.execution_order, session_options.enable_mem_reuse);
ORT_RETURN_IF_ERROR(SequentialPlanner::CreatePlan(parent_node, *graph_viewer_, valid_outer_scope_node_args,
execution_providers_, kernel_create_info_map_,
ort_value_name_idx_map_, context, p_seq_exec_plan_));

View file

@ -86,7 +86,8 @@ class SessionState {
const DataTransferManager& data_transfer_mgr,
const logging::Logger& logger,
profiling::Profiler& profiler,
bool use_deterministic_compute = false)
bool use_deterministic_compute = false,
bool enable_mem_reuse = true)
: graph_(graph),
execution_providers_(execution_providers),
logger_(logger),
@ -95,7 +96,8 @@ class SessionState {
thread_pool_(thread_pool),
inter_op_thread_pool_(inter_op_thread_pool),
data_transfer_mgr_(data_transfer_mgr),
use_deterministic_compute_(use_deterministic_compute) {
use_deterministic_compute_(use_deterministic_compute),
enable_mem_reuse_(enable_mem_reuse) {
SetupAllocators();
}
@ -212,6 +214,12 @@ class SessionState {
*/
bool GetEnableMemoryPattern() const;
/**
Get enable memory re-use flag.
*/
bool GetEnableMemoryReuse() const;
/**
Update enable_mem_pattern_ flag according to the presence of graph inputs' shape
If any one of the graph input is shapeless, enable_mem_pattern_ will be set to false
@ -438,7 +446,7 @@ class SessionState {
const DataTransferManager& data_transfer_mgr_;
bool use_deterministic_compute_;
bool enable_mem_reuse_;
std::unique_ptr<NodeIndexInfo> node_index_info_;
std::multimap<int, std::unique_ptr<FeedsFetchesManager>> cached_feeds_fetches_managers_;

View file

@ -1186,7 +1186,8 @@ common::Status InferenceSession::Initialize() {
data_transfer_mgr_,
*session_logger_,
session_profiler_,
session_options_.use_deterministic_compute);
session_options_.use_deterministic_compute,
session_options_.enable_mem_reuse);
onnxruntime::Graph& graph = model_->MainGraph();

View file

@ -1613,6 +1613,8 @@ Serialized model format will default to ONNX unless:
)pbdoc")
.def_readwrite("enable_mem_pattern", &PySessionOptions::enable_mem_pattern,
R"pbdoc(Enable the memory pattern optimization. Default is true.)pbdoc")
.def_readwrite("enable_mem_reuse", &PySessionOptions::enable_mem_reuse,
R"pbdoc(Enable the memory reuse optimization. Default is true.)pbdoc")
.def_readwrite("logid", &PySessionOptions::session_logid,
R"pbdoc(Logger id to use for session output.)pbdoc")
.def_readwrite("log_severity_level", &PySessionOptions::session_log_severity_level,

View file

@ -39,6 +39,7 @@ static SessionOptions session_options = {
false, //enable_profiling
ORT_TSTR(""), //optimized_model_filepath
true, //enable_mem_pattern
true, //enable_mem_reuse
true, //enable_cpu_mem_arena
ORT_TSTR("onnxruntime_profile_"), //profile_file_prefix
"", //session_logid

View file

@ -90,6 +90,7 @@ int main(int argc, char* argv[]) {
false, //enable_profiling
ORT_TSTR(""), //optimized_model_filepath
true, //enable_mem_pattern
true, //enable_mem_reuse
true, //enable_cpu_mem_arena
ORT_TSTR("onnxruntime_profile_"), //profile_file_prefix
"", //session_logid

View file

@ -39,6 +39,7 @@ static SessionOptions SESSION_OPTION = {
false, //enable_profiling
ORT_TSTR(""), //optimized_model_filepath
true, //enable_mem_pattern
true, //enable_mem_reuse
true, //enable_cpu_mem_arena
ORT_TSTR("onnxruntime_profile_"), //profile_file_prefix
"", //session_logid

View file

@ -168,6 +168,7 @@ class GraphExecutionManager(ABC):
session_options = onnxruntime.SessionOptions()
session_options.enable_mem_pattern = False
session_options.enable_mem_reuse = False
session_options.use_deterministic_compute = False
# default to PRIORITY_BASED execution order
session_options.execution_order = onnxruntime.ExecutionOrder.PRIORITY_BASED