mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Fix cuda graph capture (#15005)
Fix two issues related to cuda graph capture: https://github.com/microsoft/onnxruntime/issues/14942 and https://github.com/microsoft/onnxruntime/issues/15002 Issue 1: Previously, graph capture starts at the second run. However, memory pattern optimization will allocate memory from the second run, and cudamalloc is not allowed during graph capture. In this PR, the graph capture will start graph capture after 2 runs to avoid the issue. Issue 2: https://github.com/microsoft/onnxruntime/pull/13495 introduced multiple stream support. But stream cleanup will call cudaStreamSyncronize which is not allowed in cuda graph capture. In this PR, we move stream cleanup after cuda graph capture. Update the squeeze net test model with dynamic axis so that we can test with larger batch size. Add a test that could reproduce the bug (when changing min runs from 2 back to 1).
This commit is contained in:
parent
8a3de16d14
commit
9be133231f
8 changed files with 195 additions and 76 deletions
|
|
@ -139,5 +139,16 @@ Stream* DeviceStreamCollection::GetRootStream() const {
|
|||
return impl_->GetRootStream();
|
||||
}
|
||||
|
||||
DeviceStreamCollectionHolder::DeviceStreamCollectionHolder(const SessionState* session_state)
|
||||
: session_state_(session_state),
|
||||
p_(session_state->AcquireDeviceStreamCollection()) {
|
||||
}
|
||||
|
||||
DeviceStreamCollectionHolder::~DeviceStreamCollectionHolder() {
|
||||
if (p_) {
|
||||
session_state_->RecycleDeviceStreamCollection(std::move(p_));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -45,5 +45,17 @@ class DeviceStreamCollection {
|
|||
private:
|
||||
std::unique_ptr<DeviceStreamCollectionImpl> impl_;
|
||||
};
|
||||
|
||||
struct DeviceStreamCollectionHolder {
|
||||
DeviceStreamCollectionHolder(const SessionState* session_state);
|
||||
DeviceStreamCollectionHolder() = delete;
|
||||
DeviceStreamCollectionHolder(const DeviceStreamCollectionHolder&) = delete;
|
||||
|
||||
~DeviceStreamCollectionHolder();
|
||||
|
||||
const SessionState* session_state_;
|
||||
std::unique_ptr<DeviceStreamCollection> p_;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -489,22 +489,6 @@ static common::Status CopyInputsAcrossDevices(const SessionState& session_state,
|
|||
}
|
||||
|
||||
#ifdef ORT_ENABLE_STREAM
|
||||
struct DeviceStreamCollectionHolder {
|
||||
DeviceStreamCollectionHolder(
|
||||
const SessionState& session_state) : session_state_(session_state),
|
||||
p_(session_state.AcquireDeviceStreamCollection()) {
|
||||
}
|
||||
|
||||
~DeviceStreamCollectionHolder() {
|
||||
if (p_) {
|
||||
session_state_.RecycleDeviceStreamCollection(std::move(p_));
|
||||
}
|
||||
}
|
||||
|
||||
const SessionState& session_state_;
|
||||
std::unique_ptr<DeviceStreamCollection> p_;
|
||||
};
|
||||
|
||||
static void UpdateWithParentStream(DeviceStreamCollection& device_stream_collection,
|
||||
Stream* parent_stream) {
|
||||
if (parent_stream) {
|
||||
|
|
@ -551,7 +535,7 @@ common::Status CopyOneInputAcrossDevices(const SessionState& session_state, cons
|
|||
|
||||
Stream* device_stream = nullptr;
|
||||
#ifdef ORT_ENABLE_STREAM
|
||||
DeviceStreamCollectionHolder device_stream_collection_holder(session_state);
|
||||
DeviceStreamCollectionHolder device_stream_collection_holder(&session_state);
|
||||
if (device_stream_collection_holder.p_ != nullptr) {
|
||||
DeviceStreamCollection* device_stream_collection = device_stream_collection_holder.p_.get();
|
||||
size_t num_streams = device_stream_collection->NumStreams();
|
||||
|
|
@ -750,7 +734,10 @@ common::Status ExecuteGraph(const SessionState& session_state,
|
|||
FeedsFetchesManager& feeds_fetches_manager,
|
||||
gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
|
||||
ExecutionMode execution_mode, const bool& terminate_flag,
|
||||
const logging::Logger& logger, bool sync_execution_provider,
|
||||
const logging::Logger& logger,
|
||||
#ifdef ORT_ENABLE_STREAM
|
||||
DeviceStreamCollectionHolder& device_stream_collection_holder,
|
||||
#endif
|
||||
bool only_execute_path_to_fetches,
|
||||
Stream* parent_stream) {
|
||||
ORT_RETURN_IF_ERROR(utils::InitializeFeedFetchCopyInfo(session_state, feeds_fetches_manager));
|
||||
|
|
@ -758,18 +745,14 @@ common::Status ExecuteGraph(const SessionState& session_state,
|
|||
// finalize the copy info using the provided feeds and fetches. will update device_copy_checks in the background
|
||||
FinalizeFeedFetchCopyInfo(feeds_fetches_manager, feeds, fetches);
|
||||
#ifdef ORT_ENABLE_STREAM
|
||||
DeviceStreamCollectionHolder device_stream_collection_holder(session_state);
|
||||
DeviceStreamCollection* device_stream_collection = device_stream_collection_holder.p_.get();
|
||||
auto retval = ExecuteGraphImpl(session_state, feeds_fetches_manager, feeds, fetches, {},
|
||||
execution_mode, terminate_flag, logger,
|
||||
device_stream_collection,
|
||||
only_execute_path_to_fetches,
|
||||
parent_stream);
|
||||
if (device_stream_collection)
|
||||
ORT_CHECK_AND_SET_RETVAL(device_stream_collection->CleanUp(sync_execution_provider));
|
||||
return retval;
|
||||
#else
|
||||
ORT_UNUSED_PARAMETER(sync_execution_provider);
|
||||
return ExecuteGraphImpl(session_state, feeds_fetches_manager, feeds, fetches, {},
|
||||
execution_mode, terminate_flag, logger,
|
||||
only_execute_path_to_fetches,
|
||||
|
|
@ -781,6 +764,9 @@ common::Status ExecuteGraph(const SessionState& session_state,
|
|||
FeedsFetchesManager& feeds_fetches_manager,
|
||||
gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
|
||||
ExecutionMode execution_mode, const RunOptions& run_options,
|
||||
#ifdef ORT_ENABLE_STREAM
|
||||
DeviceStreamCollectionHolder& device_stream_collection_holder,
|
||||
#endif
|
||||
const logging::Logger& logger) {
|
||||
#ifdef USE_AZURE
|
||||
const auto iter = run_options.config_options.configurations.find("use_azure");
|
||||
|
|
@ -793,14 +779,15 @@ common::Status ExecuteGraph(const SessionState& session_state,
|
|||
logger);
|
||||
}
|
||||
#endif
|
||||
bool synchronize_execution_providers = run_options.config_options.GetConfigOrDefault(kOrtRunOptionsConfigDisableSynchronizeExecutionProviders, "0") == "0";
|
||||
return ExecuteGraph(session_state,
|
||||
feeds_fetches_manager,
|
||||
feeds, fetches,
|
||||
execution_mode,
|
||||
run_options.terminate,
|
||||
logger,
|
||||
synchronize_execution_providers,
|
||||
#ifdef ORT_ENABLE_STREAM
|
||||
device_stream_collection_holder,
|
||||
#endif
|
||||
run_options.only_execute_path_to_fetches);
|
||||
}
|
||||
|
||||
|
|
@ -946,7 +933,7 @@ common::Status ExecuteSubgraph(const SessionState& session_state, const FeedsFet
|
|||
Stream* parent_stream,
|
||||
bool sync_subgraph_fetches) {
|
||||
#ifdef ORT_ENABLE_STREAM
|
||||
DeviceStreamCollectionHolder device_stream_collection_holder(session_state);
|
||||
DeviceStreamCollectionHolder device_stream_collection_holder(&session_state);
|
||||
DeviceStreamCollection* device_stream_collection = device_stream_collection_holder.p_.get();
|
||||
|
||||
auto retval = ExecuteGraphImpl(session_state, feeds_fetches_manager, feeds, fetches, fetch_allocators,
|
||||
|
|
|
|||
|
|
@ -84,13 +84,19 @@ void FinalizeFeedFetchCopyInfo(FeedsFetchesManager& feeds_fetches_manager,
|
|||
common::Status ExecuteGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager,
|
||||
gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
|
||||
ExecutionMode execution_mode, const bool& terminate_flag, const logging::Logger& logger,
|
||||
bool sync_execution_provider,
|
||||
#ifdef ORT_ENABLE_STREAM
|
||||
DeviceStreamCollectionHolder& device_stream_collection_holder,
|
||||
#endif
|
||||
bool only_execute_path_to_fetches = false,
|
||||
Stream* parent_stream = nullptr);
|
||||
|
||||
common::Status ExecuteGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager,
|
||||
gsl::span<const OrtValue> feeds, std::vector<OrtValue>& fetches,
|
||||
ExecutionMode execution_mode, const RunOptions& run_options, const logging::Logger& logger);
|
||||
ExecutionMode execution_mode, const RunOptions& run_options,
|
||||
#ifdef ORT_ENABLE_STREAM
|
||||
DeviceStreamCollectionHolder& device_stream_collection_holder,
|
||||
#endif
|
||||
const logging::Logger& logger);
|
||||
|
||||
#ifdef ENABLE_TRAINING
|
||||
common::Status ExecutePartialGraph(const SessionState& session_state, FeedsFetchesManager& feeds_fetches_manager,
|
||||
|
|
|
|||
|
|
@ -221,7 +221,12 @@ class CUDAExecutionProvider : public IExecutionProvider {
|
|||
CUDAGraph cuda_graph_;
|
||||
bool is_graph_captured_ = false;
|
||||
int regular_run_count_before_graph_capture_ = 0;
|
||||
const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations.
|
||||
|
||||
// There is chance that the second regular run allocates GPU memory for causes like:
|
||||
// (1) memory pattern is enabled. (2) arena allocation for stream.
|
||||
// Since no GPU memory allocation is allowed during graph capturing, we need at least two regular runs
|
||||
// to allocate enough memory in Arena before graph capturing.
|
||||
const int min_num_runs_before_cuda_graph_capture_ = 2; // required min regular runs before graph capture for the necessary memory allocations.
|
||||
};
|
||||
|
||||
using PerThreadContextMap = std::unordered_map<const CUDAExecutionProvider*, std::weak_ptr<PerThreadContext>>;
|
||||
|
|
|
|||
|
|
@ -1533,32 +1533,30 @@ common::Status InferenceSession::Initialize() {
|
|||
// Then the CUDA EP is cached for triggering a ReplayGraph() in Run().
|
||||
auto* cuda_ep = execution_providers_.Get(onnxruntime::kCudaExecutionProvider);
|
||||
if (cuda_ep && cuda_ep->IsGraphCaptureEnabled()) {
|
||||
if (cuda_ep->IsGraphCaptureEnabled()) {
|
||||
if (HasControlflowNodes(graph)) {
|
||||
LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user "
|
||||
<< " as the model has control flow nodes which can't be supported by CUDA Graphs.";
|
||||
if (HasControlflowNodes(graph)) {
|
||||
LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user "
|
||||
<< " as the model has control flow nodes which can't be supported by CUDA Graphs.";
|
||||
|
||||
// Return error status as we don't want the session initialization to complete successfully
|
||||
// if the user has requested usage of CUDA Graph feature and we cannot honor that.
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(
|
||||
ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
|
||||
"This session cannot use the CUDA Graph feature as requested by the user "
|
||||
" as the model has control flow nodes which can't be supported by CUDA Graphs."));
|
||||
} else if (!AreAllNodesInMainGraphAssignedToOneEp(graph, onnxruntime::kCudaExecutionProvider)) {
|
||||
LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user "
|
||||
<< " as all the graph nodes have not been partitioned to the CUDA EP.";
|
||||
// Return error status as we don't want the session initialization to complete successfully
|
||||
// if the user has requested usage of CUDA Graph feature and we cannot honor that.
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(
|
||||
ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
|
||||
"This session cannot use the CUDA Graph feature as requested by the user "
|
||||
" as the model has control flow nodes which can't be supported by CUDA Graphs."));
|
||||
} else if (!AreAllNodesInMainGraphAssignedToOneEp(graph, onnxruntime::kCudaExecutionProvider)) {
|
||||
LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user "
|
||||
<< " as all the graph nodes have not been partitioned to the CUDA EP.";
|
||||
|
||||
// Return error status as we don't want the session initialization to complete successfully
|
||||
// if the user has requested usage of CUDA Graph feature and we cannot honor that.
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(
|
||||
ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
|
||||
"This session cannot use the CUDA Graph feature as requested by the user "
|
||||
" as all the graph nodes have not been partitioned to the CUDA EP."));
|
||||
// Return error status as we don't want the session initialization to complete successfully
|
||||
// if the user has requested usage of CUDA Graph feature and we cannot honor that.
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(
|
||||
ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
|
||||
"This session cannot use the CUDA Graph feature as requested by the user "
|
||||
" as all the graph nodes have not been partitioned to the CUDA EP."));
|
||||
|
||||
} else {
|
||||
LOGS(*session_logger_, INFO) << "This session will use the CUDA Graph feature as requested by the user.";
|
||||
cached_execution_provider_for_graph_replay_.SetExecutionProvider(cuda_ep);
|
||||
}
|
||||
} else {
|
||||
LOGS(*session_logger_, INFO) << "This session will use the CUDA Graph feature as requested by the user.";
|
||||
cached_execution_provider_for_graph_replay_.SetExecutionProvider(cuda_ep);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -2141,9 +2139,37 @@ Status InferenceSession::Run(const RunOptions& run_options,
|
|||
session_state_->IncrementGraphExecutionCounter();
|
||||
#endif
|
||||
|
||||
ORT_CHECK_AND_SET_RETVAL(utils::ExecuteGraph(*session_state_, feeds_fetches_manager, feeds, *p_fetches,
|
||||
session_options_.execution_mode,
|
||||
run_options, run_logger));
|
||||
#ifdef ORT_ENABLE_STREAM
|
||||
DeviceStreamCollectionHolder device_stream_collection_holder(session_state_.get());
|
||||
#endif
|
||||
|
||||
if (retval.IsOK()) {
|
||||
retval = utils::ExecuteGraph(*session_state_, feeds_fetches_manager, feeds, *p_fetches,
|
||||
session_options_.execution_mode,
|
||||
run_options,
|
||||
#ifdef ORT_ENABLE_STREAM
|
||||
device_stream_collection_holder,
|
||||
#endif
|
||||
run_logger);
|
||||
}
|
||||
|
||||
// info all execution providers InferenceSession:Run ended
|
||||
for (auto* xp : exec_providers_to_stop) {
|
||||
bool synchronize_execution_providers = run_options.config_options.GetConfigOrDefault(kOrtRunOptionsConfigDisableSynchronizeExecutionProviders, "0") == "0";
|
||||
auto status = xp->OnRunEnd(synchronize_execution_providers);
|
||||
ORT_CHECK_AND_SET_RETVAL(status);
|
||||
}
|
||||
|
||||
// Move stream cleanup from ExecuteGraph to here for cuda graph capture.
|
||||
// Cleanup will call cudaStreamSyncronize, which is not allowed for graph capture.
|
||||
// Note that graph capture ends when we call xp->OnRunEnd() in the above code so it is safe here.
|
||||
#ifdef ORT_ENABLE_STREAM
|
||||
DeviceStreamCollection* device_stream_collection = device_stream_collection_holder.p_.get();
|
||||
if (device_stream_collection) {
|
||||
bool sync_execution_provider = run_options.config_options.GetConfigOrDefault(kOrtRunOptionsConfigDisableSynchronizeExecutionProviders, "0") == "0";
|
||||
ORT_CHECK_AND_SET_RETVAL(device_stream_collection->CleanUp(sync_execution_provider));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
ORT_CATCH(const std::exception& e) {
|
||||
ORT_HANDLE_EXCEPTION([&]() {
|
||||
|
|
@ -2154,13 +2180,6 @@ Status InferenceSession::Run(const RunOptions& run_options,
|
|||
retval = Status(common::ONNXRUNTIME, common::RUNTIME_EXCEPTION, "Encountered unknown exception in Run()");
|
||||
}
|
||||
|
||||
// info all execution providers InferenceSession:Run ended
|
||||
for (auto* xp : exec_providers_to_stop) {
|
||||
bool synchronize_execution_providers = run_options.config_options.GetConfigOrDefault(kOrtRunOptionsConfigDisableSynchronizeExecutionProviders, "0") == "0";
|
||||
auto status = xp->OnRunEnd(synchronize_execution_providers);
|
||||
ORT_CHECK_AND_SET_RETVAL(status);
|
||||
}
|
||||
|
||||
if (!arenas_to_shrink.empty()) {
|
||||
ShrinkMemoryArenas(arenas_to_shrink);
|
||||
}
|
||||
|
|
@ -2192,15 +2211,13 @@ Status InferenceSession::Run(const RunOptions& run_options,
|
|||
TraceLoggingWriteStop(ortrun_activity, "OrtRun");
|
||||
#endif
|
||||
|
||||
// As two inference runs (one for memory allocation and one for graph capturing)
|
||||
// are needed before replaying the captured graph, here run the inference again
|
||||
// to capture the graph, so that users just need one session run to capture
|
||||
// the graph.
|
||||
// As N+1 inference runs (N for memory allocation and 1 for graph capturing)
|
||||
// are needed before replaying the captured graph, here run N inference runs recursively until graph captured,
|
||||
// so that users just need one session run to capture the graph.
|
||||
// N is defined in min_num_runs_before_cuda_graph_capture_ for CUDA EP, and the value could be different for other EP.
|
||||
if (retval.IsOK() && cached_execution_provider_for_graph_replay_.IsGraphCaptureEnabled() &&
|
||||
!cached_execution_provider_for_graph_replay_.IsGraphCaptured()) {
|
||||
LOGS(*session_logger_, INFO) << "Start the second Run() to capture the graph. "
|
||||
"The first one is for necessary memory allocation;"
|
||||
"The second one is for capturing the graph.";
|
||||
LOGS(*session_logger_, INFO) << "Start another run for necessary memory allocation or graph capture.";
|
||||
ORT_RETURN_IF_ERROR(Run(run_options, feed_names, feeds, output_names, p_fetches, p_fetches_device_info));
|
||||
}
|
||||
return retval;
|
||||
|
|
|
|||
|
|
@ -1,20 +1,63 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import gc # noqa: F401
|
||||
import os # noqa: F401
|
||||
import sys # noqa: F401
|
||||
import threading # noqa: F401
|
||||
import time # noqa: F401
|
||||
|
||||
# -*- coding: UTF-8 -*-
|
||||
import unittest
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
from helper import get_name
|
||||
|
||||
import onnxruntime as onnxrt
|
||||
from onnxruntime.capi.onnxruntime_pybind11_state import Fail # noqa: F401
|
||||
|
||||
|
||||
class CudaGraphHelper:
|
||||
def __init__(
|
||||
self,
|
||||
ort_session: onnxrt.InferenceSession,
|
||||
input_and_output_shape: Dict[str, List[int]],
|
||||
device_id: int = 0,
|
||||
):
|
||||
self.input_names = [input.name for input in ort_session.get_inputs()]
|
||||
self.output_names = [output.name for output in ort_session.get_outputs()]
|
||||
|
||||
self.input_and_output_shape = input_and_output_shape
|
||||
self.io_numpy_type = self.get_io_numpy_type_map(ort_session)
|
||||
self.io_binding = ort_session.io_binding()
|
||||
self.io_ort_value = {}
|
||||
|
||||
for name in self.input_names + self.output_names:
|
||||
ort_value = onnxrt.OrtValue.ortvalue_from_shape_and_type(
|
||||
input_and_output_shape[name], self.io_numpy_type[name], "cuda", device_id
|
||||
)
|
||||
self.io_ort_value[name] = ort_value
|
||||
if name in self.input_names:
|
||||
self.io_binding.bind_ortvalue_input(name, ort_value)
|
||||
else:
|
||||
self.io_binding.bind_ortvalue_output(name, ort_value)
|
||||
|
||||
def get_io_numpy_type_map(self, ort_session: onnxrt.InferenceSession):
|
||||
ort_type_to_numpy_type = {
|
||||
"tensor(int64)": np.longlong,
|
||||
"tensor(int32)": np.intc,
|
||||
"tensor(float)": np.float32,
|
||||
"tensor(float16)": np.float16,
|
||||
}
|
||||
|
||||
name_to_numpy_type = {}
|
||||
for _input in ort_session.get_inputs():
|
||||
name_to_numpy_type[_input.name] = ort_type_to_numpy_type[_input.type]
|
||||
|
||||
for output in ort_session.get_outputs():
|
||||
name_to_numpy_type[output.name] = ort_type_to_numpy_type[output.type]
|
||||
|
||||
return name_to_numpy_type
|
||||
|
||||
def update_inputs(self, inputs: Dict[str, np.ndarray]):
|
||||
for input_name in self.input_names:
|
||||
self.io_ort_value[input_name].update_inplace(inputs[input_name])
|
||||
|
||||
def get_output(self, output_name: str):
|
||||
return self.io_ort_value[output_name].numpy()
|
||||
|
||||
|
||||
class TestInferenceSessionWithCudaGraph(unittest.TestCase):
|
||||
|
|
@ -74,6 +117,44 @@ class TestInferenceSessionWithCudaGraph(unittest.TestCase):
|
|||
atol=1e-05,
|
||||
)
|
||||
|
||||
def testArenaWithCudaGraph(self): # noqa: N802
|
||||
if "CUDAExecutionProvider" in onnxrt.get_available_providers():
|
||||
# To test cuda graph catpure, we set Arena extend strategy to be SameAsRequested so as to detect any
|
||||
# potential memory allocation after the first run.
|
||||
providers = [
|
||||
("CUDAExecutionProvider", {"enable_cuda_graph": True, "arena_extend_strategy": "kSameAsRequested"})
|
||||
]
|
||||
test_model_path = get_name("squeezenet/model.onnx")
|
||||
|
||||
input_and_output_shape = {
|
||||
"data_0": [16, 3, 224, 224],
|
||||
"softmaxout_1": [16, 1000, 1, 1],
|
||||
}
|
||||
|
||||
session_options = onnxrt.SessionOptions()
|
||||
# It is optional to disable memory pattern since min_num_runs_before_cuda_graph_capture_ = 2.
|
||||
session_options.enable_mem_pattern = False
|
||||
session = onnxrt.InferenceSession(test_model_path, session_options, providers=providers)
|
||||
|
||||
cuda_graph_helper = CudaGraphHelper(session, input_and_output_shape)
|
||||
io_binding = cuda_graph_helper.io_binding
|
||||
|
||||
# Create a random input for testing.
|
||||
np.random.seed(0)
|
||||
inputs = {"data_0": np.random.randint(0, 256, size=input_and_output_shape["data_0"]).astype(np.float32)}
|
||||
|
||||
# One regular run for the necessary memory allocation and cuda graph capturing
|
||||
cuda_graph_helper.update_inputs(inputs)
|
||||
session.run_with_iobinding(io_binding)
|
||||
expected_output = cuda_graph_helper.get_output("softmaxout_1")
|
||||
|
||||
# After capturing, CUDA graph replay happens from this Run onwards
|
||||
cuda_graph_helper.update_inputs(inputs)
|
||||
session.run_with_iobinding(io_binding)
|
||||
output = cuda_graph_helper.get_output("softmaxout_1")
|
||||
|
||||
np.testing.assert_allclose(expected_output, output, rtol=1e-02, atol=1e-02)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/squeezenet/model.onnx
vendored
BIN
onnxruntime/test/testdata/squeezenet/model.onnx
vendored
Binary file not shown.
Loading…
Reference in a new issue