diff --git a/onnxruntime/core/framework/orttraining_partial_executor.cc b/onnxruntime/core/framework/orttraining_partial_executor.cc index c38aaf9210..0966f3281d 100644 --- a/onnxruntime/core/framework/orttraining_partial_executor.cc +++ b/onnxruntime/core/framework/orttraining_partial_executor.cc @@ -147,17 +147,9 @@ Status PartialExecutor::Execute(const SessionState& session_state, const std::ve tp = session_state.Profiler().Now(); } - if (state_.GetExecutionFrame() == nullptr) { - auto frame = onnxruntime::make_unique(feed_mlvalue_idxs, feeds, fetch_mlvalue_idxs, - fetches, fetch_allocators, session_state); + ExecutionFrame& frame = state_.GetExecutionFrame(feed_mlvalue_idxs, feeds, fetch_mlvalue_idxs, fetches, + fetch_allocators, session_state); - state_.SetExecutionFrame(std::move(frame)); - } else { - state_.GetExecutionFrame()->UpdateFeeds(feed_mlvalue_idxs, feeds); - state_.GetExecutionFrame()->UpdateFetches(fetch_mlvalue_idxs, fetches, session_state.GetInitializedTensors()); - } - - ExecutionFrame& frame = *(state_.GetExecutionFrame()); LOGS(logger, INFO) << "Begin execution"; const SequentialExecutionPlan& seq_exec_plan = *session_state.GetExecutionPlan(); const auto& exec_plan_vec = seq_exec_plan.execution_plan; diff --git a/onnxruntime/core/framework/partial_graph_execution_state.h b/onnxruntime/core/framework/partial_graph_execution_state.h index 9c4d781a09..ac5568e104 100644 --- a/onnxruntime/core/framework/partial_graph_execution_state.h +++ b/onnxruntime/core/framework/partial_graph_execution_state.h @@ -22,11 +22,20 @@ struct PartialGraphExecutionState { size_t GetProgramCounterStart() { return program_counter_start_; } size_t GetProgramCounterEnd() { return program_counter_end_; } - void SetExecutionFrame(std::unique_ptr frame) { - execution_frame_ = std::move(frame); - } + ExecutionFrame& GetExecutionFrame(const std::vector& feed_mlvalue_idxs, const std::vector& feeds, + const std::vector& fetch_mlvalue_idxs, const std::vector& fetches, + const 
std::unordered_map& fetch_allocators, + const SessionState& session_state) { + if (execution_frame_ == nullptr) { + execution_frame_ = onnxruntime::make_unique(feed_mlvalue_idxs, feeds, fetch_mlvalue_idxs, fetches, + fetch_allocators, session_state); + } else { + execution_frame_->UpdateFeeds(feed_mlvalue_idxs, feeds); + execution_frame_->UpdateFetches(fetch_mlvalue_idxs, fetches, session_state.GetInitializedTensors()); + } - const std::unique_ptr& GetExecutionFrame() const { return execution_frame_; } + return *execution_frame_; + } private: std::unique_ptr execution_frame_; diff --git a/orttraining/orttraining/core/agent/training_agent.cc b/orttraining/orttraining/core/agent/training_agent.cc index 1ff7b42d5c..ce00baeee3 100644 --- a/orttraining/orttraining/core/agent/training_agent.cc +++ b/orttraining/orttraining/core/agent/training_agent.cc @@ -38,7 +38,7 @@ TrainingAgent::TrainingAgent(InferenceSession& session, bw_program_counter_end_ = exec_plan_vec.size(); } -TrainingAgent::~TrainingAgent(){}; +TrainingAgent::~TrainingAgent() = default; common::Status TrainingAgent::RunForward(const std::vector& feeds, std::vector& fetches, PartialGraphExecutionState& state) { diff --git a/orttraining/orttraining/core/agent/training_agent.h b/orttraining/orttraining/core/agent/training_agent.h index ab32690623..ec391d1b6f 100644 --- a/orttraining/orttraining/core/agent/training_agent.h +++ b/orttraining/orttraining/core/agent/training_agent.h @@ -14,7 +14,6 @@ namespace onnxruntime { namespace training { -class IOBinding; class TrainingAgent { public: diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index f0cedd6aa9..9034bab74e 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -151,6 +151,9 @@ class TrainingManager(GraphExecutionManager): backward_outputs 
= C.OrtValueVector() self._execution_agent.run_backward(backward_inputs, backward_outputs, ctx.run_info.state) + # Destroy the state immediately (as opposed to being at the mercy of the garbage collector) so it does not + # affect peak memory usage in a subsequent graph run. + del ctx.run_info.state # Return input and initializer gradients num_user_input_grads = len(self._input_info.require_grad_names) results = []