Zhijxu/cleanup cached tensors when oom (#19306)

In PyTorch, when an OOM (out-of-memory) error happens during the backward pass, the user can
decrease the batch size and rerun without restarting the process.

In ORT, however, the intermediate tensors are kept even after an OOM, so rerunning
with a decreased batch size still fails.


This is a torch run; we can see that after an OOM failure, torch releases its
tensors before the next step.

![image](https://github.com/microsoft/onnxruntime/assets/43435212/92b8a2e3-454b-448a-a223-17cb91d463c2)

This is from ORT; we can see that ORT does not release its tensors after an OOM
failure.

![image](https://github.com/microsoft/onnxruntime/assets/43435212/bb6a3882-8e14-4f37-8079-e7f70fc2546b)

This is ORT with the PR; we can see the memory is released. **The 4 GB of memory is not
owned by ORT, and will be released by torch at the end.**

![image](https://github.com/microsoft/onnxruntime/assets/43435212/7f39d711-4e36-47d5-aecf-3805433a6d01)
This commit is contained in:
zhijiang 2024-02-21 10:41:42 +08:00 committed by GitHub
parent 0c4421cb78
commit 8fadc6c913
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 35 additions and 10 deletions

View file

@ -204,6 +204,14 @@ AllocatorPtr IExecutionFrame::GetAllocator(const OrtDevice& info) const {
Status IExecutionFrame::ReleaseMLValue(int ort_value_idx) { return ReleaseMLValueImpl(ort_value_idx); }
#ifdef ENABLE_TRAINING
// Drop every OrtValue held by this frame so the memory backing the cached
// intermediate tensors is returned to the allocator (used for OOM recovery).
void IExecutionFrame::ReleaseAllMLValues() {
  for (auto& value : all_values_) {
    value = OrtValue();  // replace with an empty OrtValue, releasing the old buffer
  }
}
#endif
Status IExecutionFrame::ReleaseMLValueImpl(int ort_value_idx) {
if (ort_value_idx == NodeIndexInfo::kInvalidEntry || static_cast<size_t>(ort_value_idx) >= all_values_size_) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid index ", ort_value_idx);
@ -831,7 +839,20 @@ AllocatorPtr ExecutionFrame::GetAllocatorImpl(const OrtDevice& info) const {
// This method is not thread safe!
// Return S_OK and nullptr if index map to a value that is an unused optional input/output
Status ExecutionFrame::CreateNodeOutputMLValueImpl(OrtValue& ort_value, int ort_value_idx, const TensorShape* shape) {
#ifdef ENABLE_TRAINING
  try {
    auto status = AllocateAsPerAllocationPlan(ort_value, ort_value_idx, shape);
    return status;
  } catch (const std::exception& e) {
    // Allocation failed (typically OOM during training). Release every cached
    // OrtValue so a subsequent run — e.g. with a smaller batch size — can
    // succeed without restarting the process.
    // NOTE: the original message streamed the index directly into the next
    // clause ("index: 5so clean up..."); a separator has been added.
    LOGS(session_state_.Logger(), WARNING)
        << "Exception caught when allocating memory for ort_value with index: " << ort_value_idx
        << ", so clean up all OrtValues";
    ReleaseAllMLValues();
    return Status(ONNXRUNTIME, FAIL, e.what());
  }
#else
  return AllocateAsPerAllocationPlan(ort_value, ort_value_idx, shape);
#endif
}
void ExecutionFrame::VerifyOutputSizes(int output_index, const Node& node, const TensorShape& output_shape) {

View file

@ -67,6 +67,8 @@ class IExecutionFrame {
const std::unordered_map<int, OrtValue>& initializers);
Status GetOutputs(gsl::span<const int> fetch_mlvalue_idxs, std::vector<OrtValue>& fetches);
// If an OOM happens, release all OrtValues so the session can run the next batch.
void ReleaseAllMLValues();
#endif
// TO DO: make it thread safe

View file

@ -196,18 +196,20 @@ class TrainingManager(GraphExecutionManager):
# Run and get results
backward_outputs = C.OrtValueVector()
self._execution_agent.run_backward(backward_inputs, backward_outputs, ctx.run_info.state)
# Destroy the state immediately (as opposed to be at the mercy of garbage collector) so it does not
# affect peak memory usage in a subsequent graph run.
del ctx.run_info.state
try:
self._execution_agent.run_backward(backward_inputs, backward_outputs, ctx.run_info.state)
# Destroy the state immediately (as opposed to be at the mercy of garbage collector) so it does not
# affect peak memory usage in a subsequent graph run.
# Fast version: all backward_outputs are converted first.
# This version only works if backward_outputs is an OrtValueVector.
transferred_backward_outputs = _utils._ortvalues_to_torch_tensor(backward_outputs, self._device)
# Fast version: all backward_outputs are converted first.
# This version only works if backward_outputs is an OrtValueVector.
transferred_backward_outputs = _utils._ortvalues_to_torch_tensor(backward_outputs, self._device)
self._runtime_inspector.memory_ob.inspect_memory(Phase.POST_BACKWARD)
return tuple(transferred_backward_outputs[idx] if idx != -1 else None for idx in self._gradient_map)
self._runtime_inspector.memory_ob.inspect_memory(Phase.POST_BACKWARD)
res = tuple(transferred_backward_outputs[idx] if idx != -1 else None for idx in self._gradient_map)
return res
finally:
del ctx.run_info.state
return _ORTModuleFunction