mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-02 03:55:34 +00:00
Partial graph execution perf improvements. (#7438)
* Partial graph execution perf improvements.
* PR feedback.
* Decrement reference count of tensors in ORTModule.
* PR feedback.
* PR feedback.
* PR feedback.
This commit is contained in:
parent
0702a14ee7
commit
82108b18e3
5 changed files with 19 additions and 16 deletions
|
|
@ -147,17 +147,9 @@ Status PartialExecutor::Execute(const SessionState& session_state, const std::ve
|
|||
tp = session_state.Profiler().Now();
|
||||
}
|
||||
|
||||
if (state_.GetExecutionFrame() == nullptr) {
|
||||
auto frame = onnxruntime::make_unique<ExecutionFrame>(feed_mlvalue_idxs, feeds, fetch_mlvalue_idxs,
|
||||
fetches, fetch_allocators, session_state);
|
||||
ExecutionFrame& frame = state_.GetExecutionFrame(feed_mlvalue_idxs, feeds, fetch_mlvalue_idxs, fetches,
|
||||
fetch_allocators, session_state);
|
||||
|
||||
state_.SetExecutionFrame(std::move(frame));
|
||||
} else {
|
||||
state_.GetExecutionFrame()->UpdateFeeds(feed_mlvalue_idxs, feeds);
|
||||
state_.GetExecutionFrame()->UpdateFetches(fetch_mlvalue_idxs, fetches, session_state.GetInitializedTensors());
|
||||
}
|
||||
|
||||
ExecutionFrame& frame = *(state_.GetExecutionFrame());
|
||||
LOGS(logger, INFO) << "Begin execution";
|
||||
const SequentialExecutionPlan& seq_exec_plan = *session_state.GetExecutionPlan();
|
||||
const auto& exec_plan_vec = seq_exec_plan.execution_plan;
|
||||
|
|
|
|||
|
|
@ -22,11 +22,20 @@ struct PartialGraphExecutionState {
|
|||
size_t GetProgramCounterStart() { return program_counter_start_; }
|
||||
size_t GetProgramCounterEnd() { return program_counter_end_; }
|
||||
|
||||
void SetExecutionFrame(std::unique_ptr<ExecutionFrame> frame) {
|
||||
execution_frame_ = std::move(frame);
|
||||
}
|
||||
ExecutionFrame& GetExecutionFrame(const std::vector<int>& feed_mlvalue_idxs, const std::vector<OrtValue>& feeds,
|
||||
const std::vector<int>& fetch_mlvalue_idxs, const std::vector<OrtValue>& fetches,
|
||||
const std::unordered_map<size_t, IExecutor::CustomAllocator>& fetch_allocators,
|
||||
const SessionState& session_state) {
|
||||
if (execution_frame_ == nullptr) {
|
||||
execution_frame_ = onnxruntime::make_unique<ExecutionFrame>(feed_mlvalue_idxs, feeds, fetch_mlvalue_idxs, fetches,
|
||||
fetch_allocators, session_state);
|
||||
} else {
|
||||
execution_frame_->UpdateFeeds(feed_mlvalue_idxs, feeds);
|
||||
execution_frame_->UpdateFetches(fetch_mlvalue_idxs, fetches, session_state.GetInitializedTensors());
|
||||
}
|
||||
|
||||
const std::unique_ptr<ExecutionFrame>& GetExecutionFrame() const { return execution_frame_; }
|
||||
return *execution_frame_;
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<ExecutionFrame> execution_frame_;
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ TrainingAgent::TrainingAgent(InferenceSession& session,
|
|||
bw_program_counter_end_ = exec_plan_vec.size();
|
||||
}
|
||||
|
||||
TrainingAgent::~TrainingAgent(){};
|
||||
TrainingAgent::~TrainingAgent() = default;
|
||||
|
||||
common::Status TrainingAgent::RunForward(const std::vector<OrtValue>& feeds, std::vector<OrtValue>& fetches,
|
||||
PartialGraphExecutionState& state) {
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
namespace training {
|
||||
class IOBinding;
|
||||
|
||||
class TrainingAgent {
|
||||
public:
|
||||
|
|
|
|||
|
|
@ -151,6 +151,9 @@ class TrainingManager(GraphExecutionManager):
|
|||
|
||||
backward_outputs = C.OrtValueVector()
|
||||
self._execution_agent.run_backward(backward_inputs, backward_outputs, ctx.run_info.state)
|
||||
# Destroy the state immediately (as opposed to being at the mercy of the garbage collector) so it does not
|
||||
# affect peak memory usage in a subsequent graph run.
|
||||
del ctx.run_info.state
|
||||
# Return input and initializer gradients
|
||||
num_user_input_grads = len(self._input_info.require_grad_names)
|
||||
results = []
|
||||
|
|
|
|||
Loading…
Reference in a new issue