Fix cuda graph capture (#15005)

Fix two issues related to cuda graph capture:
https://github.com/microsoft/onnxruntime/issues/14942 and
https://github.com/microsoft/onnxruntime/issues/15002

Issue 1: Previously, graph capture starts at the second run. However,
memory pattern optimization will allocate memory from the second run,
and cudamalloc is not allowed during graph capture. In this PR, the
graph capture will start graph capture after 2 runs to avoid the issue.

Issue 2: https://github.com/microsoft/onnxruntime/pull/13495 introduced
multiple stream support. But stream cleanup will call
cudaStreamSyncronize which is not allowed in cuda graph capture. In this
PR, we move stream cleanup after cuda graph capture.

Update the squeeze net test model with dynamic axis so that we can test
with larger batch size. Add a test that could reproduce the bug (when
changing min runs from 2 back to 1).
This commit is contained in:
Tianlei Wu 2023-06-14 18:10:20 -07:00 committed by GitHub
parent 8a3de16d14
commit 9be133231f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 195 additions and 76 deletions

View file

@ -139,5 +139,16 @@ Stream* DeviceStreamCollection::GetRootStream() const {
return impl_->GetRootStream();
}
DeviceStreamCollectionHolder::DeviceStreamCollectionHolder(const SessionState* session_state)
: session_state_(session_state),
p_(session_state->AcquireDeviceStreamCollection()) {
}
DeviceStreamCollectionHolder::~DeviceStreamCollectionHolder() {
if (p_) {
session_state_->RecycleDeviceStreamCollection(std::move(p_));
}
}
} // namespace onnxruntime
#endif

View file

@ -45,5 +45,17 @@ class DeviceStreamCollection {
private:
std::unique_ptr<DeviceStreamCollectionImpl> impl_;
};
struct DeviceStreamCollectionHolder {
DeviceStreamCollectionHolder(const SessionState* session_state);
DeviceStreamCollectionHolder() = delete;
DeviceStreamCollectionHolder(const DeviceStreamCollectionHolder&) = delete;
~DeviceStreamCollectionHolder();
const SessionState* session_state_;
std::unique_ptr<DeviceStreamCollection> p_;
};
} // namespace onnxruntime
#endif

View file

@ -489,22 +489,6 @@ static common::Status CopyInputsAcrossDevices(const SessionState& session_state,
}
#ifdef ORT_ENABLE_STREAM
struct DeviceStreamCollectionHolder {
DeviceStreamCollectionHolder(
const SessionState& session_state) : session_state_(session_state),
p_(session_state.AcquireDeviceStreamCollection()) {
}
~DeviceStreamCollectionHolder() {
if (p_) {
session_state_.RecycleDeviceStreamCollection(std::move(p_));
}
}
const SessionState& session_state_;
std::unique_ptr<DeviceStreamCollection> p_;
};
static void UpdateWithParentStream(DeviceStreamCollection& device_stream_collection,
Stream* parent_stream) {
if (parent_stream) {
@ -551,7 +535,7 @@ common::Status CopyOneInputAcrossDevices(const SessionState& session_state, cons
Stream* device_stream = nullptr;
#ifdef ORT_ENABLE_STREAM
DeviceStreamCollectionHolder device_stream_collection_holder(session_state);
DeviceStreamCollectionHolder device_stream_collection_holder(&session_state);
if (device_stream_collection_holder.p_ != nullptr) {
DeviceStreamCollection* device_stream_collection = device_stream_collection_holder.p_.get();
size_t num_streams = device_stream_collection->NumStreams();
@ -750,7 +734,10 @@ common::Status ExecuteGraph(const SessionState& session_state,
FeedsFetchesManager& feeds_fetches_manager,
gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
ExecutionMode execution_mode, const bool& terminate_flag,
const logging::Logger& logger, bool sync_execution_provider,
const logging::Logger& logger,
#ifdef ORT_ENABLE_STREAM
DeviceStreamCollectionHolder& device_stream_collection_holder,
#endif
bool only_execute_path_to_fetches,
Stream* parent_stream) {
ORT_RETURN_IF_ERROR(utils::InitializeFeedFetchCopyInfo(session_state, feeds_fetches_manager));
@ -758,18 +745,14 @@ common::Status ExecuteGraph(const SessionState& session_state,
// finalize the copy info using the provided feeds and fetches. will update device_copy_checks in the background
FinalizeFeedFetchCopyInfo(feeds_fetches_manager, feeds, fetches);
#ifdef ORT_ENABLE_STREAM
DeviceStreamCollectionHolder device_stream_collection_holder(session_state);
DeviceStreamCollection* device_stream_collection = device_stream_collection_holder.p_.get();
auto retval = ExecuteGraphImpl(session_state, feeds_fetches_manager, feeds, fetches, {},
execution_mode, terminate_flag, logger,
device_stream_collection,
only_execute_path_to_fetches,
parent_stream);
if (device_stream_collection)
ORT_CHECK_AND_SET_RETVAL(device_stream_collection->CleanUp(sync_execution_provider));
return retval;
#else
ORT_UNUSED_PARAMETER(sync_execution_provider);
return ExecuteGraphImpl(session_state, feeds_fetches_manager, feeds, fetches, {},
execution_mode, terminate_flag, logger,
only_execute_path_to_fetches,
@ -781,6 +764,9 @@ common::Status ExecuteGraph(const SessionState& session_state,
FeedsFetchesManager& feeds_fetches_manager,
gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
ExecutionMode execution_mode, const RunOptions& run_options,
#ifdef ORT_ENABLE_STREAM
DeviceStreamCollectionHolder& device_stream_collection_holder,
#endif
const logging::Logger& logger) {
#ifdef USE_AZURE
const auto iter = run_options.config_options.configurations.find("use_azure");
@ -793,14 +779,15 @@ common::Status ExecuteGraph(const SessionState& session_state,
logger);
}
#endif
bool synchronize_execution_providers = run_options.config_options.GetConfigOrDefault(kOrtRunOptionsConfigDisableSynchronizeExecutionProviders, "0") == "0";
return ExecuteGraph(session_state,
feeds_fetches_manager,
feeds, fetches,
execution_mode,
run_options.terminate,
logger,
synchronize_execution_providers,
#ifdef ORT_ENABLE_STREAM
device_stream_collection_holder,
#endif
run_options.only_execute_path_to_fetches);
}
@ -946,7 +933,7 @@ common::Status ExecuteSubgraph(const SessionState& session_state, const FeedsFet
Stream* parent_stream,
bool sync_subgraph_fetches) {
#ifdef ORT_ENABLE_STREAM
DeviceStreamCollectionHolder device_stream_collection_holder(session_state);
DeviceStreamCollectionHolder device_stream_collection_holder(&session_state);
DeviceStreamCollection* device_stream_collection = device_stream_collection_holder.p_.get();
auto retval = ExecuteGraphImpl(session_state, feeds_fetches_manager, feeds, fetches, fetch_allocators,

View file

@ -84,13 +84,19 @@ void FinalizeFeedFetchCopyInfo(FeedsFetchesManager& feeds_fetches_manager,
common::Status ExecuteGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager,
gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
ExecutionMode execution_mode, const bool& terminate_flag, const logging::Logger& logger,
bool sync_execution_provider,
#ifdef ORT_ENABLE_STREAM
DeviceStreamCollectionHolder& device_stream_collection_holder,
#endif
bool only_execute_path_to_fetches = false,
Stream* parent_stream = nullptr);
common::Status ExecuteGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager,
gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
ExecutionMode execution_mode, const RunOptions& run_options, const logging::Logger& logger);
ExecutionMode execution_mode, const RunOptions& run_options,
#ifdef ORT_ENABLE_STREAM
DeviceStreamCollectionHolder& device_stream_collection_holder,
#endif
const logging::Logger& logger);
#ifdef ENABLE_TRAINING
common::Status ExecutePartialGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager,

View file

@ -221,7 +221,12 @@ class CUDAExecutionProvider : public IExecutionProvider {
CUDAGraph cuda_graph_;
bool is_graph_captured_ = false;
int regular_run_count_before_graph_capture_ = 0;
const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations.
// There is chance that the second regular run allocates GPU memory for causes like:
// (1) memory pattern is enabled. (2) arena allocation for stream.
// Since no GPU memory allocation is allowed during graph capturing, we need at least two regular runs
// to allocate enough memory in Arena before graph capturing.
const int min_num_runs_before_cuda_graph_capture_ = 2; // required min regular runs before graph capture for the necessary memory allocations.
};
using PerThreadContextMap = std::unordered_map<const CUDAExecutionProvider*, std::weak_ptr<PerThreadContext>>;

View file

@ -1533,32 +1533,30 @@ common::Status InferenceSession::Initialize() {
// Then the CUDA EP is cached for triggering a ReplayGraph() in Run().
auto* cuda_ep = execution_providers_.Get(onnxruntime::kCudaExecutionProvider);
if (cuda_ep && cuda_ep->IsGraphCaptureEnabled()) {
if (cuda_ep->IsGraphCaptureEnabled()) {
if (HasControlflowNodes(graph)) {
LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user "
<< " as the model has control flow nodes which can't be supported by CUDA Graphs.";
if (HasControlflowNodes(graph)) {
LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user "
<< " as the model has control flow nodes which can't be supported by CUDA Graphs.";
// Return error status as we don't want the session initialization to complete successfully
// if the user has requested usage of CUDA Graph feature and we cannot honor that.
ORT_RETURN_IF_ERROR_SESSIONID_(
ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"This session cannot use the CUDA Graph feature as requested by the user "
" as the model has control flow nodes which can't be supported by CUDA Graphs."));
} else if (!AreAllNodesInMainGraphAssignedToOneEp(graph, onnxruntime::kCudaExecutionProvider)) {
LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user "
<< " as all the graph nodes have not been partitioned to the CUDA EP.";
// Return error status as we don't want the session initialization to complete successfully
// if the user has requested usage of CUDA Graph feature and we cannot honor that.
ORT_RETURN_IF_ERROR_SESSIONID_(
ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"This session cannot use the CUDA Graph feature as requested by the user "
" as the model has control flow nodes which can't be supported by CUDA Graphs."));
} else if (!AreAllNodesInMainGraphAssignedToOneEp(graph, onnxruntime::kCudaExecutionProvider)) {
LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user "
<< " as all the graph nodes have not been partitioned to the CUDA EP.";
// Return error status as we don't want the session initialization to complete successfully
// if the user has requested usage of CUDA Graph feature and we cannot honor that.
ORT_RETURN_IF_ERROR_SESSIONID_(
ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"This session cannot use the CUDA Graph feature as requested by the user "
" as all the graph nodes have not been partitioned to the CUDA EP."));
// Return error status as we don't want the session initialization to complete successfully
// if the user has requested usage of CUDA Graph feature and we cannot honor that.
ORT_RETURN_IF_ERROR_SESSIONID_(
ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"This session cannot use the CUDA Graph feature as requested by the user "
" as all the graph nodes have not been partitioned to the CUDA EP."));
} else {
LOGS(*session_logger_, INFO) << "This session will use the CUDA Graph feature as requested by the user.";
cached_execution_provider_for_graph_replay_.SetExecutionProvider(cuda_ep);
}
} else {
LOGS(*session_logger_, INFO) << "This session will use the CUDA Graph feature as requested by the user.";
cached_execution_provider_for_graph_replay_.SetExecutionProvider(cuda_ep);
}
}
@ -2141,9 +2139,37 @@ Status InferenceSession::Run(const RunOptions& run_options,
session_state_->IncrementGraphExecutionCounter();
#endif
ORT_CHECK_AND_SET_RETVAL(utils::ExecuteGraph(*session_state_, feeds_fetches_manager, feeds, *p_fetches,
session_options_.execution_mode,
run_options, run_logger));
#ifdef ORT_ENABLE_STREAM
DeviceStreamCollectionHolder device_stream_collection_holder(session_state_.get());
#endif
if (retval.IsOK()) {
retval = utils::ExecuteGraph(*session_state_, feeds_fetches_manager, feeds, *p_fetches,
session_options_.execution_mode,
run_options,
#ifdef ORT_ENABLE_STREAM
device_stream_collection_holder,
#endif
run_logger);
}
// info all execution providers InferenceSession:Run ended
for (auto* xp : exec_providers_to_stop) {
bool synchronize_execution_providers = run_options.config_options.GetConfigOrDefault(kOrtRunOptionsConfigDisableSynchronizeExecutionProviders, "0") == "0";
auto status = xp->OnRunEnd(synchronize_execution_providers);
ORT_CHECK_AND_SET_RETVAL(status);
}
// Move stream cleanup from ExecuteGraph to here for cuda graph capture.
// Cleanup will call cudaStreamSyncronize, which is not allowed for graph capture.
// Note that graph capture ends when we call xp->OnRunEnd() in the above code so it is safe here.
#ifdef ORT_ENABLE_STREAM
DeviceStreamCollection* device_stream_collection = device_stream_collection_holder.p_.get();
if (device_stream_collection) {
bool sync_execution_provider = run_options.config_options.GetConfigOrDefault(kOrtRunOptionsConfigDisableSynchronizeExecutionProviders, "0") == "0";
ORT_CHECK_AND_SET_RETVAL(device_stream_collection->CleanUp(sync_execution_provider));
}
#endif
}
ORT_CATCH(const std::exception& e) {
ORT_HANDLE_EXCEPTION([&]() {
@ -2154,13 +2180,6 @@ Status InferenceSession::Run(const RunOptions& run_options,
retval = Status(common::ONNXRUNTIME, common::RUNTIME_EXCEPTION, "Encountered unknown exception in Run()");
}
// info all execution providers InferenceSession:Run ended
for (auto* xp : exec_providers_to_stop) {
bool synchronize_execution_providers = run_options.config_options.GetConfigOrDefault(kOrtRunOptionsConfigDisableSynchronizeExecutionProviders, "0") == "0";
auto status = xp->OnRunEnd(synchronize_execution_providers);
ORT_CHECK_AND_SET_RETVAL(status);
}
if (!arenas_to_shrink.empty()) {
ShrinkMemoryArenas(arenas_to_shrink);
}
@ -2192,15 +2211,13 @@ Status InferenceSession::Run(const RunOptions& run_options,
TraceLoggingWriteStop(ortrun_activity, "OrtRun");
#endif
// As two inference runs (one for memory allocation and one for graph capturing)
// are needed before replaying the captured graph, here run the inference again
// to capture the graph, so that users just need one session run to capture
// the graph.
// As N+1 inference runs (N for memory allocation and 1 for graph capturing)
// are needed before replaying the captured graph, here run N inference runs recursively until graph captured,
// so that users just need one session run to capture the graph.
// N is defined in min_num_runs_before_cuda_graph_capture_ for CUDA EP, and the value could be different for other EP.
if (retval.IsOK() && cached_execution_provider_for_graph_replay_.IsGraphCaptureEnabled() &&
!cached_execution_provider_for_graph_replay_.IsGraphCaptured()) {
LOGS(*session_logger_, INFO) << "Start the second Run() to capture the graph. "
"The first one is for necessary memory allocation;"
"The second one is for capturing the graph.";
LOGS(*session_logger_, INFO) << "Start another run for necessary memory allocation or graph capture.";
ORT_RETURN_IF_ERROR(Run(run_options, feed_names, feeds, output_names, p_fetches, p_fetches_device_info));
}
return retval;

View file

@ -1,20 +1,63 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import gc # noqa: F401
import os # noqa: F401
import sys # noqa: F401
import threading # noqa: F401
import time # noqa: F401
# -*- coding: UTF-8 -*-
import unittest
from typing import Dict, List
import numpy as np
from helper import get_name
import onnxruntime as onnxrt
from onnxruntime.capi.onnxruntime_pybind11_state import Fail # noqa: F401
class CudaGraphHelper:
def __init__(
self,
ort_session: onnxrt.InferenceSession,
input_and_output_shape: Dict[str, List[int]],
device_id: int = 0,
):
self.input_names = [input.name for input in ort_session.get_inputs()]
self.output_names = [output.name for output in ort_session.get_outputs()]
self.input_and_output_shape = input_and_output_shape
self.io_numpy_type = self.get_io_numpy_type_map(ort_session)
self.io_binding = ort_session.io_binding()
self.io_ort_value = {}
for name in self.input_names + self.output_names:
ort_value = onnxrt.OrtValue.ortvalue_from_shape_and_type(
input_and_output_shape[name], self.io_numpy_type[name], "cuda", device_id
)
self.io_ort_value[name] = ort_value
if name in self.input_names:
self.io_binding.bind_ortvalue_input(name, ort_value)
else:
self.io_binding.bind_ortvalue_output(name, ort_value)
def get_io_numpy_type_map(self, ort_session: onnxrt.InferenceSession):
ort_type_to_numpy_type = {
"tensor(int64)": np.longlong,
"tensor(int32)": np.intc,
"tensor(float)": np.float32,
"tensor(float16)": np.float16,
}
name_to_numpy_type = {}
for _input in ort_session.get_inputs():
name_to_numpy_type[_input.name] = ort_type_to_numpy_type[_input.type]
for output in ort_session.get_outputs():
name_to_numpy_type[output.name] = ort_type_to_numpy_type[output.type]
return name_to_numpy_type
def update_inputs(self, inputs: Dict[str, np.ndarray]):
for input_name in self.input_names:
self.io_ort_value[input_name].update_inplace(inputs[input_name])
def get_output(self, output_name: str):
return self.io_ort_value[output_name].numpy()
class TestInferenceSessionWithCudaGraph(unittest.TestCase):
@ -74,6 +117,44 @@ class TestInferenceSessionWithCudaGraph(unittest.TestCase):
atol=1e-05,
)
def testArenaWithCudaGraph(self): # noqa: N802
if "CUDAExecutionProvider" in onnxrt.get_available_providers():
# To test cuda graph catpure, we set Arena extend strategy to be SameAsRequested so as to detect any
# potential memory allocation after the first run.
providers = [
("CUDAExecutionProvider", {"enable_cuda_graph": True, "arena_extend_strategy": "kSameAsRequested"})
]
test_model_path = get_name("squeezenet/model.onnx")
input_and_output_shape = {
"data_0": [16, 3, 224, 224],
"softmaxout_1": [16, 1000, 1, 1],
}
session_options = onnxrt.SessionOptions()
# It is optional to disable memory pattern since min_num_runs_before_cuda_graph_capture_ = 2.
session_options.enable_mem_pattern = False
session = onnxrt.InferenceSession(test_model_path, session_options, providers=providers)
cuda_graph_helper = CudaGraphHelper(session, input_and_output_shape)
io_binding = cuda_graph_helper.io_binding
# Create a random input for testing.
np.random.seed(0)
inputs = {"data_0": np.random.randint(0, 256, size=input_and_output_shape["data_0"]).astype(np.float32)}
# One regular run for the necessary memory allocation and cuda graph capturing
cuda_graph_helper.update_inputs(inputs)
session.run_with_iobinding(io_binding)
expected_output = cuda_graph_helper.get_output("softmaxout_1")
# After capturing, CUDA graph replay happens from this Run onwards
cuda_graph_helper.update_inputs(inputs)
session.run_with_iobinding(io_binding)
output = cuda_graph_helper.get_output("softmaxout_1")
np.testing.assert_allclose(expected_output, output, rtol=1e-02, atol=1e-02)
if __name__ == "__main__":
unittest.main()

Binary file not shown.