From 8fadc6c913bc30edff2e89756da515b9bd75d256 Mon Sep 17 00:00:00 2001 From: zhijiang <43435212+zhijxu-MS@users.noreply.github.com> Date: Wed, 21 Feb 2024 10:41:42 +0800 Subject: [PATCH] Zhijxu/cleanup cached tensors when oom (#19306) In PyTorch, when OOM happens during the backward pass, the user can decrease the batch size and rerun without restarting the process. In ORT, however, the intermediate tensors are kept even after OOM, so decreasing the batch size still fails. This is a torch run; we can see that after the OOM failure, torch releases its tensors before the next step. ![image](https://github.com/microsoft/onnxruntime/assets/43435212/92b8a2e3-454b-448a-a223-17cb91d463c2) This is from ORT; we can see that ORT does not release its tensors after the OOM failure. ![image](https://github.com/microsoft/onnxruntime/assets/43435212/bb6a3882-8e14-4f37-8079-e7f70fc2546b) ORT with this PR; we can see the memory is released — **the 4GB of memory is not owned by ORT, and will be released by torch at the end**. ![image](https://github.com/microsoft/onnxruntime/assets/43435212/7f39d711-4e36-47d5-aecf-3805433a6d01) --- onnxruntime/core/framework/execution_frame.cc | 21 ++++++++++++++++++ onnxruntime/core/framework/execution_frame.h | 2 ++ .../training/ortmodule/_training_manager.py | 22 ++++++++++--------- 3 files changed, 35 insertions(+), 10 deletions(-) diff --git a/onnxruntime/core/framework/execution_frame.cc b/onnxruntime/core/framework/execution_frame.cc index 8c08152986..32a5f749af 100644 --- a/onnxruntime/core/framework/execution_frame.cc +++ b/onnxruntime/core/framework/execution_frame.cc @@ -204,6 +204,14 @@ AllocatorPtr IExecutionFrame::GetAllocator(const OrtDevice& info) const { Status IExecutionFrame::ReleaseMLValue(int ort_value_idx) { return ReleaseMLValueImpl(ort_value_idx); } +#ifdef ENABLE_TRAINING +void IExecutionFrame::ReleaseAllMLValues() { + for (size_t ort_value_idx = 0; ort_value_idx < all_values_.size(); ort_value_idx++) { + all_values_[ort_value_idx] = OrtValue(); + } +} +#endif + Status 
IExecutionFrame::ReleaseMLValueImpl(int ort_value_idx) { if (ort_value_idx == NodeIndexInfo::kInvalidEntry || static_cast<size_t>(ort_value_idx) >= all_values_size_) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid index ", ort_value_idx); @@ -831,7 +839,20 @@ AllocatorPtr ExecutionFrame::GetAllocatorImpl(const OrtDevice& info) const { // This method is not thread safe! // Return S_OK and nullptr if index map to a value that is an unused optional input/output Status ExecutionFrame::CreateNodeOutputMLValueImpl(OrtValue& ort_value, int ort_value_idx, const TensorShape* shape) { +#ifdef ENABLE_TRAINING + try { + auto status = AllocateAsPerAllocationPlan(ort_value, ort_value_idx, shape); + return status; + } catch (const std::exception& e) { + LOGS(session_state_.Logger(), WARNING) + << "Exception caught when allocating memory for ort_value with index: " << ort_value_idx + << ", so clean up all OrtValues"; + ReleaseAllMLValues(); + return Status(ONNXRUNTIME, FAIL, e.what()); + } +#else return AllocateAsPerAllocationPlan(ort_value, ort_value_idx, shape); +#endif } void ExecutionFrame::VerifyOutputSizes(int output_index, const Node& node, const TensorShape& output_shape) { diff --git a/onnxruntime/core/framework/execution_frame.h b/onnxruntime/core/framework/execution_frame.h index 1576c16684..18d210ffd4 100644 --- a/onnxruntime/core/framework/execution_frame.h +++ b/onnxruntime/core/framework/execution_frame.h @@ -67,6 +67,8 @@ class IExecutionFrame { const std::unordered_map<int, OrtValue>& initializers); Status GetOutputs(gsl::span<const int> fetch_mlvalue_idxs, std::vector<OrtValue>& fetches); + // If OOM happens, release all values so the session can run the next batch. 
+ void ReleaseAllMLValues(); #endif // TO DO: make it thread safe diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index cc533e549d..73c32a2f51 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -196,18 +196,20 @@ class TrainingManager(GraphExecutionManager): # Run and get results backward_outputs = C.OrtValueVector() - self._execution_agent.run_backward(backward_inputs, backward_outputs, ctx.run_info.state) - # Destroy the state immediately (as opposed to be at the mercy of garbage collector) so it does not - # affect peak memory usage in a subsequent graph run. - del ctx.run_info.state + try: + self._execution_agent.run_backward(backward_inputs, backward_outputs, ctx.run_info.state) + # Destroy the state immediately (as opposed to be at the mercy of garbage collector) so it does not + # affect peak memory usage in a subsequent graph run. - # Fast version: all backward_outputs are converted first. - # This version only works if backward_outputs is an OrtValueVector. - transferred_backward_outputs = _utils._ortvalues_to_torch_tensor(backward_outputs, self._device) + # Fast version: all backward_outputs are converted first. + # This version only works if backward_outputs is an OrtValueVector. + transferred_backward_outputs = _utils._ortvalues_to_torch_tensor(backward_outputs, self._device) - self._runtime_inspector.memory_ob.inspect_memory(Phase.POST_BACKWARD) - - return tuple(transferred_backward_outputs[idx] if idx != -1 else None for idx in self._gradient_map) + self._runtime_inspector.memory_ob.inspect_memory(Phase.POST_BACKWARD) + res = tuple(transferred_backward_outputs[idx] if idx != -1 else None for idx in self._gradient_map) + return res + finally: + del ctx.run_info.state return _ORTModuleFunction