mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
[CUDA] support user_compute_stream in python API (#19229)
### Description It is an important feature to pass user cuda stream to avoid synchronization in python API. Here we allow user to pass cuda stream for CUDA provider. Note that TRT or ROCm provider need similar change, which are not included in this pull request. Note that we will set `has_user_compute_stream` automatically based on whether there is cuda stream passed, so setting `has_user_compute_stream` through python API has no effect. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> https://github.com/microsoft/onnxruntime/issues/19094
This commit is contained in:
parent
7d4dc66846
commit
d7ff81dfb7
2 changed files with 35 additions and 0 deletions
|
|
@ -16,6 +16,7 @@ namespace cuda {
|
|||
namespace provider_option_names {
|
||||
constexpr const char* kDeviceId = "device_id";
|
||||
constexpr const char* kHasUserComputeStream = "has_user_compute_stream";
|
||||
constexpr const char* kUserComputeStream = "user_compute_stream";
|
||||
constexpr const char* kMemLimit = "gpu_mem_limit";
|
||||
constexpr const char* kArenaExtendStrategy = "arena_extend_strategy";
|
||||
constexpr const char* kCudnnConvAlgoSearch = "cudnn_conv_algo_search";
|
||||
|
|
@ -51,6 +52,7 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P
|
|||
void* alloc = nullptr;
|
||||
void* free = nullptr;
|
||||
void* empty_cache = nullptr;
|
||||
void* user_compute_stream = nullptr;
|
||||
ORT_THROW_IF_ERROR(
|
||||
ProviderOptionsParser{}
|
||||
.AddValueParser(
|
||||
|
|
@ -66,6 +68,14 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P
|
|||
return Status::OK();
|
||||
})
|
||||
.AddAssignmentToReference(cuda::provider_option_names::kHasUserComputeStream, info.has_user_compute_stream)
|
||||
.AddValueParser(
|
||||
cuda::provider_option_names::kUserComputeStream,
|
||||
[&user_compute_stream](const std::string& value_str) -> Status {
|
||||
size_t address;
|
||||
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address));
|
||||
user_compute_stream = reinterpret_cast<void*>(address);
|
||||
return Status::OK();
|
||||
})
|
||||
.AddValueParser(
|
||||
cuda::provider_option_names::kGpuExternalAlloc,
|
||||
[&alloc](const std::string& value_str) -> Status {
|
||||
|
|
@ -126,6 +136,10 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P
|
|||
|
||||
CUDAExecutionProviderExternalAllocatorInfo alloc_info{alloc, free, empty_cache};
|
||||
info.external_allocator_info = alloc_info;
|
||||
|
||||
info.user_compute_stream = user_compute_stream;
|
||||
info.has_user_compute_stream = (user_compute_stream != nullptr);
|
||||
|
||||
return info;
|
||||
}
|
||||
|
||||
|
|
@ -133,6 +147,7 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const CUDAExecution
|
|||
const ProviderOptions options{
|
||||
{cuda::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
|
||||
{cuda::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)},
|
||||
{cuda::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.user_compute_stream))},
|
||||
{cuda::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.gpu_mem_limit)},
|
||||
{cuda::provider_option_names::kGpuExternalAlloc, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.external_allocator_info.alloc))},
|
||||
{cuda::provider_option_names::kGpuExternalFree, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.external_allocator_info.free))},
|
||||
|
|
@ -160,6 +175,7 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const OrtCUDAProvid
|
|||
const ProviderOptions options{
|
||||
{cuda::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
|
||||
{cuda::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)},
|
||||
{cuda::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.user_compute_stream))},
|
||||
{cuda::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.gpu_mem_limit)},
|
||||
{cuda::provider_option_names::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)},
|
||||
{cuda::provider_option_names::kCudnnConvAlgoSearch, EnumToName(ort_cudnn_conv_algo_search_mapping, info.cudnn_conv_algo_search)},
|
||||
|
|
|
|||
|
|
@ -434,6 +434,25 @@ class TestInferenceSession(unittest.TestCase):
|
|||
self.assertEqual(options["CUDAExecutionProvider"]["gpu_external_alloc"], "0")
|
||||
self.assertEqual(options["CUDAExecutionProvider"]["gpu_external_free"], "0")
|
||||
self.assertEqual(options["CUDAExecutionProvider"]["gpu_external_empty_cache"], "0")
|
||||
|
||||
option["user_compute_stream"] = "0"
|
||||
sess.set_providers(["CUDAExecutionProvider"], [option])
|
||||
options = sess.get_provider_options()
|
||||
self.assertEqual(options["CUDAExecutionProvider"]["user_compute_stream"], "0")
|
||||
|
||||
try:
|
||||
import torch
|
||||
|
||||
if torch.cuda.is_available():
|
||||
s = torch.cuda.Stream()
|
||||
option["user_compute_stream"] = str(s.cuda_stream)
|
||||
sess.set_providers(["CUDAExecutionProvider"], [option])
|
||||
options = sess.get_provider_options()
|
||||
self.assertEqual(options["CUDAExecutionProvider"]["user_compute_stream"], str(s.cuda_stream))
|
||||
self.assertEqual(options["CUDAExecutionProvider"]["has_user_compute_stream"], "1")
|
||||
except ImportError:
|
||||
print("torch is not installed, skip testing setting user_compute_stream from torch cuda stream")
|
||||
|
||||
#
|
||||
# Note: Tests that throw an exception leave an empty session due to how set_providers currently works,
|
||||
# so run them last. Each set_providers call will attempt to re-create a session, so it's
|
||||
|
|
|
|||
Loading…
Reference in a new issue