Partial graph execution perf improvements. (#7438)

* Partial graph execution perf improvements.

* PR feedback.

* Decrement reference count of tensors in ORTModule.

* PR feedback.

* PR feedback.

* PR feedback.
This commit is contained in:
M. Zeeshan Siddiqui 2021-04-26 17:13:55 -07:00 committed by GitHub
parent 0702a14ee7
commit 82108b18e3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 19 additions and 16 deletions

View file

@ -147,17 +147,9 @@ Status PartialExecutor::Execute(const SessionState& session_state, const std::ve
tp = session_state.Profiler().Now();
}
if (state_.GetExecutionFrame() == nullptr) {
auto frame = onnxruntime::make_unique<ExecutionFrame>(feed_mlvalue_idxs, feeds, fetch_mlvalue_idxs,
fetches, fetch_allocators, session_state);
ExecutionFrame& frame = state_.GetExecutionFrame(feed_mlvalue_idxs, feeds, fetch_mlvalue_idxs, fetches,
fetch_allocators, session_state);
state_.SetExecutionFrame(std::move(frame));
} else {
state_.GetExecutionFrame()->UpdateFeeds(feed_mlvalue_idxs, feeds);
state_.GetExecutionFrame()->UpdateFetches(fetch_mlvalue_idxs, fetches, session_state.GetInitializedTensors());
}
ExecutionFrame& frame = *(state_.GetExecutionFrame());
LOGS(logger, INFO) << "Begin execution";
const SequentialExecutionPlan& seq_exec_plan = *session_state.GetExecutionPlan();
const auto& exec_plan_vec = seq_exec_plan.execution_plan;

View file

@ -22,11 +22,20 @@ struct PartialGraphExecutionState {
size_t GetProgramCounterStart() { return program_counter_start_; }
size_t GetProgramCounterEnd() { return program_counter_end_; }
void SetExecutionFrame(std::unique_ptr<ExecutionFrame> frame) {
execution_frame_ = std::move(frame);
}
ExecutionFrame& GetExecutionFrame(const std::vector<int>& feed_mlvalue_idxs, const std::vector<OrtValue>& feeds,
const std::vector<int>& fetch_mlvalue_idxs, const std::vector<OrtValue>& fetches,
const std::unordered_map<size_t, IExecutor::CustomAllocator>& fetch_allocators,
const SessionState& session_state) {
if (execution_frame_ == nullptr) {
execution_frame_ = onnxruntime::make_unique<ExecutionFrame>(feed_mlvalue_idxs, feeds, fetch_mlvalue_idxs, fetches,
fetch_allocators, session_state);
} else {
execution_frame_->UpdateFeeds(feed_mlvalue_idxs, feeds);
execution_frame_->UpdateFetches(fetch_mlvalue_idxs, fetches, session_state.GetInitializedTensors());
}
const std::unique_ptr<ExecutionFrame>& GetExecutionFrame() const { return execution_frame_; }
return *execution_frame_;
}
private:
std::unique_ptr<ExecutionFrame> execution_frame_;

View file

@ -38,7 +38,7 @@ TrainingAgent::TrainingAgent(InferenceSession& session,
bw_program_counter_end_ = exec_plan_vec.size();
}
TrainingAgent::~TrainingAgent(){};
TrainingAgent::~TrainingAgent() = default;
common::Status TrainingAgent::RunForward(const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
PartialGraphExecutionState& state) {

View file

@ -14,7 +14,6 @@
namespace onnxruntime {
namespace training {
class IOBinding;
class TrainingAgent {
public:

View file

@ -151,6 +151,9 @@ class TrainingManager(GraphExecutionManager):
backward_outputs = C.OrtValueVector()
self._execution_agent.run_backward(backward_inputs, backward_outputs, ctx.run_info.state)
# Destroy the state immediately (as opposed to be at the mercy of garbage collector) so it does not
# affect peak memory usage in a subsequent graph run.
del ctx.run_info.state
# Return input and initializer gradients
num_user_input_grads = len(self._input_info.require_grad_names)
results = []