From 9be133231fe7585ed6e3337389058bd5ca8db89e Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 14 Jun 2023 18:10:20 -0700 Subject: [PATCH] Fix cuda graph capture (#15005) Fix two issues related to cuda graph capture: https://github.com/microsoft/onnxruntime/issues/14942 and https://github.com/microsoft/onnxruntime/issues/15002 Issue 1: Previously, graph capture started at the second run. However, memory pattern optimization will allocate memory from the second run, and cudaMalloc is not allowed during graph capture. In this PR, graph capture starts after 2 regular runs to avoid the issue. Issue 2: https://github.com/microsoft/onnxruntime/pull/13495 introduced multiple stream support. But stream cleanup will call cudaStreamSynchronize, which is not allowed during cuda graph capture. In this PR, we move stream cleanup after cuda graph capture. Update the SqueezeNet test model with a dynamic axis so that we can test with a larger batch size. Add a test that could reproduce the bug (when changing min runs from 2 back to 1). 
--- .../framework/device_stream_collection.cc | 11 ++ .../core/framework/device_stream_collection.h | 12 +++ onnxruntime/core/framework/utils.cc | 37 +++---- onnxruntime/core/framework/utils.h | 10 +- .../providers/cuda/cuda_execution_provider.h | 7 +- onnxruntime/core/session/inference_session.cc | 97 ++++++++++-------- .../onnxruntime_test_python_cudagraph.py | 97 ++++++++++++++++-- .../test/testdata/squeezenet/model.onnx | Bin 4952222 -> 4952232 bytes 8 files changed, 195 insertions(+), 76 deletions(-) diff --git a/onnxruntime/core/framework/device_stream_collection.cc b/onnxruntime/core/framework/device_stream_collection.cc index 669fb0bc79..3c102f2679 100644 --- a/onnxruntime/core/framework/device_stream_collection.cc +++ b/onnxruntime/core/framework/device_stream_collection.cc @@ -139,5 +139,16 @@ Stream* DeviceStreamCollection::GetRootStream() const { return impl_->GetRootStream(); } +DeviceStreamCollectionHolder::DeviceStreamCollectionHolder(const SessionState* session_state) + : session_state_(session_state), + p_(session_state->AcquireDeviceStreamCollection()) { +} + +DeviceStreamCollectionHolder::~DeviceStreamCollectionHolder() { + if (p_) { + session_state_->RecycleDeviceStreamCollection(std::move(p_)); + } +} + } // namespace onnxruntime #endif diff --git a/onnxruntime/core/framework/device_stream_collection.h b/onnxruntime/core/framework/device_stream_collection.h index 8a1ed8a41e..8a5f784845 100644 --- a/onnxruntime/core/framework/device_stream_collection.h +++ b/onnxruntime/core/framework/device_stream_collection.h @@ -45,5 +45,17 @@ class DeviceStreamCollection { private: std::unique_ptr impl_; }; + +struct DeviceStreamCollectionHolder { + DeviceStreamCollectionHolder(const SessionState* session_state); + DeviceStreamCollectionHolder() = delete; + DeviceStreamCollectionHolder(const DeviceStreamCollectionHolder&) = delete; + + ~DeviceStreamCollectionHolder(); + + const SessionState* session_state_; + std::unique_ptr p_; +}; + } // namespace 
onnxruntime #endif diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index 74c1f19580..fcb73825f1 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -489,22 +489,6 @@ static common::Status CopyInputsAcrossDevices(const SessionState& session_state, } #ifdef ORT_ENABLE_STREAM -struct DeviceStreamCollectionHolder { - DeviceStreamCollectionHolder( - const SessionState& session_state) : session_state_(session_state), - p_(session_state.AcquireDeviceStreamCollection()) { - } - - ~DeviceStreamCollectionHolder() { - if (p_) { - session_state_.RecycleDeviceStreamCollection(std::move(p_)); - } - } - - const SessionState& session_state_; - std::unique_ptr p_; -}; - static void UpdateWithParentStream(DeviceStreamCollection& device_stream_collection, Stream* parent_stream) { if (parent_stream) { @@ -551,7 +535,7 @@ common::Status CopyOneInputAcrossDevices(const SessionState& session_state, cons Stream* device_stream = nullptr; #ifdef ORT_ENABLE_STREAM - DeviceStreamCollectionHolder device_stream_collection_holder(session_state); + DeviceStreamCollectionHolder device_stream_collection_holder(&session_state); if (device_stream_collection_holder.p_ != nullptr) { DeviceStreamCollection* device_stream_collection = device_stream_collection_holder.p_.get(); size_t num_streams = device_stream_collection->NumStreams(); @@ -750,7 +734,10 @@ common::Status ExecuteGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager, gsl::span feeds, std::vector& fetches, ExecutionMode execution_mode, const bool& terminate_flag, - const logging::Logger& logger, bool sync_execution_provider, + const logging::Logger& logger, +#ifdef ORT_ENABLE_STREAM + DeviceStreamCollectionHolder& device_stream_collection_holder, +#endif bool only_execute_path_to_fetches, Stream* parent_stream) { ORT_RETURN_IF_ERROR(utils::InitializeFeedFetchCopyInfo(session_state, feeds_fetches_manager)); @@ -758,18 +745,14 @@ 
common::Status ExecuteGraph(const SessionState& session_state, // finalize the copy info using the provided feeds and fetches. will update device_copy_checks in the background FinalizeFeedFetchCopyInfo(feeds_fetches_manager, feeds, fetches); #ifdef ORT_ENABLE_STREAM - DeviceStreamCollectionHolder device_stream_collection_holder(session_state); DeviceStreamCollection* device_stream_collection = device_stream_collection_holder.p_.get(); auto retval = ExecuteGraphImpl(session_state, feeds_fetches_manager, feeds, fetches, {}, execution_mode, terminate_flag, logger, device_stream_collection, only_execute_path_to_fetches, parent_stream); - if (device_stream_collection) - ORT_CHECK_AND_SET_RETVAL(device_stream_collection->CleanUp(sync_execution_provider)); return retval; #else - ORT_UNUSED_PARAMETER(sync_execution_provider); return ExecuteGraphImpl(session_state, feeds_fetches_manager, feeds, fetches, {}, execution_mode, terminate_flag, logger, only_execute_path_to_fetches, @@ -781,6 +764,9 @@ common::Status ExecuteGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager, gsl::span feeds, std::vector& fetches, ExecutionMode execution_mode, const RunOptions& run_options, +#ifdef ORT_ENABLE_STREAM + DeviceStreamCollectionHolder& device_stream_collection_holder, +#endif const logging::Logger& logger) { #ifdef USE_AZURE const auto iter = run_options.config_options.configurations.find("use_azure"); @@ -793,14 +779,15 @@ common::Status ExecuteGraph(const SessionState& session_state, logger); } #endif - bool synchronize_execution_providers = run_options.config_options.GetConfigOrDefault(kOrtRunOptionsConfigDisableSynchronizeExecutionProviders, "0") == "0"; return ExecuteGraph(session_state, feeds_fetches_manager, feeds, fetches, execution_mode, run_options.terminate, logger, - synchronize_execution_providers, +#ifdef ORT_ENABLE_STREAM + device_stream_collection_holder, +#endif run_options.only_execute_path_to_fetches); } @@ -946,7 +933,7 @@ 
common::Status ExecuteSubgraph(const SessionState& session_state, const FeedsFet Stream* parent_stream, bool sync_subgraph_fetches) { #ifdef ORT_ENABLE_STREAM - DeviceStreamCollectionHolder device_stream_collection_holder(session_state); + DeviceStreamCollectionHolder device_stream_collection_holder(&session_state); DeviceStreamCollection* device_stream_collection = device_stream_collection_holder.p_.get(); auto retval = ExecuteGraphImpl(session_state, feeds_fetches_manager, feeds, fetches, fetch_allocators, diff --git a/onnxruntime/core/framework/utils.h b/onnxruntime/core/framework/utils.h index 3ca9fef62a..56f41154b7 100644 --- a/onnxruntime/core/framework/utils.h +++ b/onnxruntime/core/framework/utils.h @@ -84,13 +84,19 @@ void FinalizeFeedFetchCopyInfo(FeedsFetchesManager& feeds_fetches_manager, common::Status ExecuteGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager, gsl::span feeds, std::vector& fetches, ExecutionMode execution_mode, const bool& terminate_flag, const logging::Logger& logger, - bool sync_execution_provider, +#ifdef ORT_ENABLE_STREAM + DeviceStreamCollectionHolder& device_stream_collection_holder, +#endif bool only_execute_path_to_fetches = false, Stream* parent_stream = nullptr); common::Status ExecuteGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager, gsl::span feeds, std::vector& fetches, - ExecutionMode execution_mode, const RunOptions& run_options, const logging::Logger& logger); + ExecutionMode execution_mode, const RunOptions& run_options, +#ifdef ORT_ENABLE_STREAM + DeviceStreamCollectionHolder& device_stream_collection_holder, +#endif + const logging::Logger& logger); #ifdef ENABLE_TRAINING common::Status ExecutePartialGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager, diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h index cb12fc1563..89a5fb83be 100644 --- 
a/onnxruntime/core/providers/cuda/cuda_execution_provider.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h @@ -221,7 +221,12 @@ class CUDAExecutionProvider : public IExecutionProvider { CUDAGraph cuda_graph_; bool is_graph_captured_ = false; int regular_run_count_before_graph_capture_ = 0; - const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations. + + // There is chance that the second regular run allocates GPU memory for causes like: + // (1) memory pattern is enabled. (2) arena allocation for stream. + // Since no GPU memory allocation is allowed during graph capturing, we need at least two regular runs + // to allocate enough memory in Arena before graph capturing. + const int min_num_runs_before_cuda_graph_capture_ = 2; // required min regular runs before graph capture for the necessary memory allocations. }; using PerThreadContextMap = std::unordered_map>; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 50ffb73087..4d0f0ccde7 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1533,32 +1533,30 @@ common::Status InferenceSession::Initialize() { // Then the CUDA EP is cached for triggering a ReplayGraph() in Run(). 
auto* cuda_ep = execution_providers_.Get(onnxruntime::kCudaExecutionProvider); if (cuda_ep && cuda_ep->IsGraphCaptureEnabled()) { - if (cuda_ep->IsGraphCaptureEnabled()) { - if (HasControlflowNodes(graph)) { - LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user " - << " as the model has control flow nodes which can't be supported by CUDA Graphs."; + if (HasControlflowNodes(graph)) { + LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user " + << " as the model has control flow nodes which can't be supported by CUDA Graphs."; - // Return error status as we don't want the session initialization to complete successfully - // if the user has requested usage of CUDA Graph feature and we cannot honor that. - ORT_RETURN_IF_ERROR_SESSIONID_( - ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "This session cannot use the CUDA Graph feature as requested by the user " - " as the model has control flow nodes which can't be supported by CUDA Graphs.")); - } else if (!AreAllNodesInMainGraphAssignedToOneEp(graph, onnxruntime::kCudaExecutionProvider)) { - LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user " - << " as all the graph nodes have not been partitioned to the CUDA EP."; + // Return error status as we don't want the session initialization to complete successfully + // if the user has requested usage of CUDA Graph feature and we cannot honor that. 
+ ORT_RETURN_IF_ERROR_SESSIONID_( + ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "This session cannot use the CUDA Graph feature as requested by the user " + " as the model has control flow nodes which can't be supported by CUDA Graphs.")); + } else if (!AreAllNodesInMainGraphAssignedToOneEp(graph, onnxruntime::kCudaExecutionProvider)) { + LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user " + << " as all the graph nodes have not been partitioned to the CUDA EP."; - // Return error status as we don't want the session initialization to complete successfully - // if the user has requested usage of CUDA Graph feature and we cannot honor that. - ORT_RETURN_IF_ERROR_SESSIONID_( - ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "This session cannot use the CUDA Graph feature as requested by the user " - " as all the graph nodes have not been partitioned to the CUDA EP.")); + // Return error status as we don't want the session initialization to complete successfully + // if the user has requested usage of CUDA Graph feature and we cannot honor that. 
+ ORT_RETURN_IF_ERROR_SESSIONID_( + ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "This session cannot use the CUDA Graph feature as requested by the user " + " as all the graph nodes have not been partitioned to the CUDA EP.")); - } else { - LOGS(*session_logger_, INFO) << "This session will use the CUDA Graph feature as requested by the user."; - cached_execution_provider_for_graph_replay_.SetExecutionProvider(cuda_ep); - } + } else { + LOGS(*session_logger_, INFO) << "This session will use the CUDA Graph feature as requested by the user."; + cached_execution_provider_for_graph_replay_.SetExecutionProvider(cuda_ep); } } @@ -2141,9 +2139,37 @@ Status InferenceSession::Run(const RunOptions& run_options, session_state_->IncrementGraphExecutionCounter(); #endif - ORT_CHECK_AND_SET_RETVAL(utils::ExecuteGraph(*session_state_, feeds_fetches_manager, feeds, *p_fetches, - session_options_.execution_mode, - run_options, run_logger)); +#ifdef ORT_ENABLE_STREAM + DeviceStreamCollectionHolder device_stream_collection_holder(session_state_.get()); +#endif + + if (retval.IsOK()) { + retval = utils::ExecuteGraph(*session_state_, feeds_fetches_manager, feeds, *p_fetches, + session_options_.execution_mode, + run_options, +#ifdef ORT_ENABLE_STREAM + device_stream_collection_holder, +#endif + run_logger); + } + + // info all execution providers InferenceSession:Run ended + for (auto* xp : exec_providers_to_stop) { + bool synchronize_execution_providers = run_options.config_options.GetConfigOrDefault(kOrtRunOptionsConfigDisableSynchronizeExecutionProviders, "0") == "0"; + auto status = xp->OnRunEnd(synchronize_execution_providers); + ORT_CHECK_AND_SET_RETVAL(status); + } + + // Move stream cleanup from ExecuteGraph to here for cuda graph capture. + // Cleanup will call cudaStreamSyncronize, which is not allowed for graph capture. + // Note that graph capture ends when we call xp->OnRunEnd() in the above code so it is safe here. 
+#ifdef ORT_ENABLE_STREAM + DeviceStreamCollection* device_stream_collection = device_stream_collection_holder.p_.get(); + if (device_stream_collection) { + bool sync_execution_provider = run_options.config_options.GetConfigOrDefault(kOrtRunOptionsConfigDisableSynchronizeExecutionProviders, "0") == "0"; + ORT_CHECK_AND_SET_RETVAL(device_stream_collection->CleanUp(sync_execution_provider)); + } +#endif } ORT_CATCH(const std::exception& e) { ORT_HANDLE_EXCEPTION([&]() { @@ -2154,13 +2180,6 @@ Status InferenceSession::Run(const RunOptions& run_options, retval = Status(common::ONNXRUNTIME, common::RUNTIME_EXCEPTION, "Encountered unknown exception in Run()"); } - // info all execution providers InferenceSession:Run ended - for (auto* xp : exec_providers_to_stop) { - bool synchronize_execution_providers = run_options.config_options.GetConfigOrDefault(kOrtRunOptionsConfigDisableSynchronizeExecutionProviders, "0") == "0"; - auto status = xp->OnRunEnd(synchronize_execution_providers); - ORT_CHECK_AND_SET_RETVAL(status); - } - if (!arenas_to_shrink.empty()) { ShrinkMemoryArenas(arenas_to_shrink); } @@ -2192,15 +2211,13 @@ Status InferenceSession::Run(const RunOptions& run_options, TraceLoggingWriteStop(ortrun_activity, "OrtRun"); #endif - // As two inference runs (one for memory allocation and one for graph capturing) - // are needed before replaying the captured graph, here run the inference again - // to capture the graph, so that users just need one session run to capture - // the graph. + // As N+1 inference runs (N for memory allocation and 1 for graph capturing) + // are needed before replaying the captured graph, here run N inference runs recursively until graph captured, + // so that users just need one session run to capture the graph. + // N is defined in min_num_runs_before_cuda_graph_capture_ for CUDA EP, and the value could be different for other EP. 
if (retval.IsOK() && cached_execution_provider_for_graph_replay_.IsGraphCaptureEnabled() && !cached_execution_provider_for_graph_replay_.IsGraphCaptured()) { - LOGS(*session_logger_, INFO) << "Start the second Run() to capture the graph. " - "The first one is for necessary memory allocation;" - "The second one is for capturing the graph."; + LOGS(*session_logger_, INFO) << "Start another run for necessary memory allocation or graph capture."; ORT_RETURN_IF_ERROR(Run(run_options, feed_names, feeds, output_names, p_fetches, p_fetches_device_info)); } return retval; diff --git a/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py b/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py index 5dd927a566..30e299863f 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py +++ b/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py @@ -1,20 +1,63 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import gc # noqa: F401 -import os # noqa: F401 -import sys # noqa: F401 -import threading # noqa: F401 -import time # noqa: F401 - -# -*- coding: UTF-8 -*- import unittest +from typing import Dict, List import numpy as np from helper import get_name import onnxruntime as onnxrt -from onnxruntime.capi.onnxruntime_pybind11_state import Fail # noqa: F401 + + +class CudaGraphHelper: + def __init__( + self, + ort_session: onnxrt.InferenceSession, + input_and_output_shape: Dict[str, List[int]], + device_id: int = 0, + ): + self.input_names = [input.name for input in ort_session.get_inputs()] + self.output_names = [output.name for output in ort_session.get_outputs()] + + self.input_and_output_shape = input_and_output_shape + self.io_numpy_type = self.get_io_numpy_type_map(ort_session) + self.io_binding = ort_session.io_binding() + self.io_ort_value = {} + + for name in self.input_names + self.output_names: + ort_value = onnxrt.OrtValue.ortvalue_from_shape_and_type( + input_and_output_shape[name], 
self.io_numpy_type[name], "cuda", device_id + ) + self.io_ort_value[name] = ort_value + if name in self.input_names: + self.io_binding.bind_ortvalue_input(name, ort_value) + else: + self.io_binding.bind_ortvalue_output(name, ort_value) + + def get_io_numpy_type_map(self, ort_session: onnxrt.InferenceSession): + ort_type_to_numpy_type = { + "tensor(int64)": np.longlong, + "tensor(int32)": np.intc, + "tensor(float)": np.float32, + "tensor(float16)": np.float16, + } + + name_to_numpy_type = {} + for _input in ort_session.get_inputs(): + name_to_numpy_type[_input.name] = ort_type_to_numpy_type[_input.type] + + for output in ort_session.get_outputs(): + name_to_numpy_type[output.name] = ort_type_to_numpy_type[output.type] + + return name_to_numpy_type + + def update_inputs(self, inputs: Dict[str, np.ndarray]): + for input_name in self.input_names: + self.io_ort_value[input_name].update_inplace(inputs[input_name]) + + def get_output(self, output_name: str): + return self.io_ort_value[output_name].numpy() class TestInferenceSessionWithCudaGraph(unittest.TestCase): @@ -74,6 +117,44 @@ class TestInferenceSessionWithCudaGraph(unittest.TestCase): atol=1e-05, ) + def testArenaWithCudaGraph(self): # noqa: N802 + if "CUDAExecutionProvider" in onnxrt.get_available_providers(): + # To test cuda graph catpure, we set Arena extend strategy to be SameAsRequested so as to detect any + # potential memory allocation after the first run. + providers = [ + ("CUDAExecutionProvider", {"enable_cuda_graph": True, "arena_extend_strategy": "kSameAsRequested"}) + ] + test_model_path = get_name("squeezenet/model.onnx") + + input_and_output_shape = { + "data_0": [16, 3, 224, 224], + "softmaxout_1": [16, 1000, 1, 1], + } + + session_options = onnxrt.SessionOptions() + # It is optional to disable memory pattern since min_num_runs_before_cuda_graph_capture_ = 2. 
+ session_options.enable_mem_pattern = False + session = onnxrt.InferenceSession(test_model_path, session_options, providers=providers) + + cuda_graph_helper = CudaGraphHelper(session, input_and_output_shape) + io_binding = cuda_graph_helper.io_binding + + # Create a random input for testing. + np.random.seed(0) + inputs = {"data_0": np.random.randint(0, 256, size=input_and_output_shape["data_0"]).astype(np.float32)} + + # One regular run for the necessary memory allocation and cuda graph capturing + cuda_graph_helper.update_inputs(inputs) + session.run_with_iobinding(io_binding) + expected_output = cuda_graph_helper.get_output("softmaxout_1") + + # After capturing, CUDA graph replay happens from this Run onwards + cuda_graph_helper.update_inputs(inputs) + session.run_with_iobinding(io_binding) + output = cuda_graph_helper.get_output("softmaxout_1") + + np.testing.assert_allclose(expected_output, output, rtol=1e-02, atol=1e-02) + if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/testdata/squeezenet/model.onnx b/onnxruntime/test/testdata/squeezenet/model.onnx index b8e1dfce26d9947cb4062583f31ff7c6f1eef879..24de98fd4bc1e38babb38137ca83c1b324879cdb 100644 GIT binary patch delta 348 zcmYL=yHZn8002pDLav8b_NqaWfII|55(HFG9zJ*}@(avlE@+0vq#(|urP!9?3!I^& zqVs<&{R|?F#cs9x>-si((}Orl2HLGww|<&7o8MPAE-#RbWF?ZEYxJd{WBSV}S~Ph?EWQjw~R%Y;0YNtu#qnUQBQD|0e0&t*Yg$fDGwE=%%K zmSsg&<(0gaH?k({@>br-hHT1~Y|DH3ARpzE?8vU{$-W%OXZa$B^7S`In^(v4q5oUj zNgK<_c&NmgWHc0#{7Kq5Jqy`53Q_zgbFY(H7(8z`J3rEH`)8-Il2k*vmzVzMQT(^i Ndousn$%VdR;SQHvQsDpq delta 338 zcmXYoH%aBY~tOlya$%N=Zv3RT4|J)JUz=Nxd{kX-T9}nxt7; zq*dCaT{@&wx};lrq*wZ+Uj}4QhGbYqBqO6TCgU<8lQJdKG9$AxC-br(i?SrkvLdUp zChM{xo3bU_vLm~)C;M_BhjJvxaw4a4Cg)#VY`k6$#Fe+%V)kk_X^UH;AW7m>6qE?# qF!~Pube)N79tyeQefFvFSiG7`TI1$Y$zK^pKk5G^*e{RE8q&Y#Do%R<