mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Zhijxu/cleanup cached tensors when oom (#19306)
In PyTorch, when an OOM happens during the backward pass, the user can decrease the batch size and rerun without restarting the process. In ORT, however, the intermediate tensors were kept even after an OOM, so rerunning with a smaller batch size still failed. In the PyTorch run we can see that after an OOM failure torch releases its tensors before the next step; in the ORT run we can see that ORT does not release its tensors after the OOM failure. With this PR, ORT releases the memory — **the remaining 4GB is not owned by ORT and will be released by torch at the end**.
This commit is contained in:
parent
0c4421cb78
commit
8fadc6c913
3 changed files with 35 additions and 10 deletions
|
|
@ -204,6 +204,14 @@ AllocatorPtr IExecutionFrame::GetAllocator(const OrtDevice& info) const {
|
|||
|
||||
// Releases the OrtValue at the given index. Thin public wrapper that
// forwards to ReleaseMLValueImpl, which performs the actual validation
// and release.
Status IExecutionFrame::ReleaseMLValue(int ort_value_idx) {
  return ReleaseMLValueImpl(ort_value_idx);
}
|
||||
|
||||
#ifdef ENABLE_TRAINING
// Drops every OrtValue cached by this frame by overwriting each slot with a
// default-constructed (empty) OrtValue, releasing the underlying buffers.
// Called after an OOM failure so the session can run the next batch without
// restarting the process.
void IExecutionFrame::ReleaseAllMLValues() {
  for (auto& value : all_values_) {
    value = OrtValue();
  }
}
#endif
|
||||
|
||||
Status IExecutionFrame::ReleaseMLValueImpl(int ort_value_idx) {
|
||||
if (ort_value_idx == NodeIndexInfo::kInvalidEntry || static_cast<size_t>(ort_value_idx) >= all_values_size_) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid index ", ort_value_idx);
|
||||
|
|
@ -831,7 +839,20 @@ AllocatorPtr ExecutionFrame::GetAllocatorImpl(const OrtDevice& info) const {
|
|||
// This method is not thread safe!
|
||||
// Return S_OK and nullptr if index map to a value that is an unused optional input/output
|
||||
Status ExecutionFrame::CreateNodeOutputMLValueImpl(OrtValue& ort_value, int ort_value_idx, const TensorShape* shape) {
|
||||
#ifdef ENABLE_TRAINING
|
||||
try {
|
||||
auto status = AllocateAsPerAllocationPlan(ort_value, ort_value_idx, shape);
|
||||
return status;
|
||||
} catch (const std::exception& e) {
|
||||
LOGS(session_state_.Logger(), WARNING)
|
||||
<< "Exception caught when allocating memory for ort_value with index: " << ort_value_idx
|
||||
<< "so clean up all OrtValues";
|
||||
ReleaseAllMLValues();
|
||||
return Status(ONNXRUNTIME, FAIL, e.what());
|
||||
}
|
||||
#else
|
||||
return AllocateAsPerAllocationPlan(ort_value, ort_value_idx, shape);
|
||||
#endif
|
||||
}
|
||||
|
||||
void ExecutionFrame::VerifyOutputSizes(int output_index, const Node& node, const TensorShape& output_shape) {
|
||||
|
|
|
|||
|
|
@ -67,6 +67,8 @@ class IExecutionFrame {
|
|||
|
||||
const std::unordered_map<int, OrtValue>& initializers);
|
||||
Status GetOutputs(gsl::span<const int> fetch_mlvalue_idxs, std::vector<OrtValue>& fetches);
|
||||
// If an OOM happens, release all OrtValues so the session can run the next batch.
|
||||
void ReleaseAllMLValues();
|
||||
#endif
|
||||
|
||||
// TO DO: make it thread safe
|
||||
|
|
|
|||
|
|
@ -196,18 +196,20 @@ class TrainingManager(GraphExecutionManager):
|
|||
|
||||
# Run and get results
|
||||
backward_outputs = C.OrtValueVector()
|
||||
self._execution_agent.run_backward(backward_inputs, backward_outputs, ctx.run_info.state)
|
||||
# Destroy the state immediately (as opposed to be at the mercy of garbage collector) so it does not
|
||||
# affect peak memory usage in a subsequent graph run.
|
||||
del ctx.run_info.state
|
||||
try:
|
||||
self._execution_agent.run_backward(backward_inputs, backward_outputs, ctx.run_info.state)
|
||||
# Destroy the state immediately (as opposed to be at the mercy of garbage collector) so it does not
|
||||
# affect peak memory usage in a subsequent graph run.
|
||||
|
||||
# Fast version: all backward_outputs are converted first.
|
||||
# This version only works if backward_outputs is an OrtValueVector.
|
||||
transferred_backward_outputs = _utils._ortvalues_to_torch_tensor(backward_outputs, self._device)
|
||||
# Fast version: all backward_outputs are converted first.
|
||||
# This version only works if backward_outputs is an OrtValueVector.
|
||||
transferred_backward_outputs = _utils._ortvalues_to_torch_tensor(backward_outputs, self._device)
|
||||
|
||||
self._runtime_inspector.memory_ob.inspect_memory(Phase.POST_BACKWARD)
|
||||
|
||||
return tuple(transferred_backward_outputs[idx] if idx != -1 else None for idx in self._gradient_map)
|
||||
self._runtime_inspector.memory_ob.inspect_memory(Phase.POST_BACKWARD)
|
||||
res = tuple(transferred_backward_outputs[idx] if idx != -1 else None for idx in self._gradient_map)
|
||||
return res
|
||||
finally:
|
||||
del ctx.run_info.state
|
||||
|
||||
return _ORTModuleFunction
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue