[CUDA] support user_compute_stream in python API (#19229)

### Description
It is an important feature to pass user cuda stream to avoid
synchronization in python API. Here we allow user to pass cuda stream
for CUDA provider. Note that TRT or ROCm provider need similar change,
which are not included in this pull request.

Note that we will set `has_user_compute_stream` automatically based on
whether there is cuda stream passed, so setting
`has_user_compute_stream` through python API has no effect.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

https://github.com/microsoft/onnxruntime/issues/19094
This commit is contained in:
Tianlei Wu 2024-01-26 10:34:43 -08:00 committed by GitHub
parent 7d4dc66846
commit d7ff81dfb7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 35 additions and 0 deletions

View file

@ -16,6 +16,7 @@ namespace cuda {
namespace provider_option_names {
constexpr const char* kDeviceId = "device_id";
constexpr const char* kHasUserComputeStream = "has_user_compute_stream";
constexpr const char* kUserComputeStream = "user_compute_stream";
constexpr const char* kMemLimit = "gpu_mem_limit";
constexpr const char* kArenaExtendStrategy = "arena_extend_strategy";
constexpr const char* kCudnnConvAlgoSearch = "cudnn_conv_algo_search";
@ -51,6 +52,7 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P
void* alloc = nullptr;
void* free = nullptr;
void* empty_cache = nullptr;
void* user_compute_stream = nullptr;
ORT_THROW_IF_ERROR(
ProviderOptionsParser{}
.AddValueParser(
@ -66,6 +68,14 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P
return Status::OK();
})
.AddAssignmentToReference(cuda::provider_option_names::kHasUserComputeStream, info.has_user_compute_stream)
.AddValueParser(
cuda::provider_option_names::kUserComputeStream,
[&user_compute_stream](const std::string& value_str) -> Status {
size_t address;
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address));
user_compute_stream = reinterpret_cast<void*>(address);
return Status::OK();
})
.AddValueParser(
cuda::provider_option_names::kGpuExternalAlloc,
[&alloc](const std::string& value_str) -> Status {
@ -126,6 +136,10 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P
CUDAExecutionProviderExternalAllocatorInfo alloc_info{alloc, free, empty_cache};
info.external_allocator_info = alloc_info;
info.user_compute_stream = user_compute_stream;
info.has_user_compute_stream = (user_compute_stream != nullptr);
return info;
}
@ -133,6 +147,7 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const CUDAExecution
const ProviderOptions options{
{cuda::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
{cuda::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)},
{cuda::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.user_compute_stream))},
{cuda::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.gpu_mem_limit)},
{cuda::provider_option_names::kGpuExternalAlloc, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.external_allocator_info.alloc))},
{cuda::provider_option_names::kGpuExternalFree, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.external_allocator_info.free))},
@ -160,6 +175,7 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const OrtCUDAProvid
const ProviderOptions options{
{cuda::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
{cuda::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)},
{cuda::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.user_compute_stream))},
{cuda::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.gpu_mem_limit)},
{cuda::provider_option_names::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)},
{cuda::provider_option_names::kCudnnConvAlgoSearch, EnumToName(ort_cudnn_conv_algo_search_mapping, info.cudnn_conv_algo_search)},

View file

@ -434,6 +434,25 @@ class TestInferenceSession(unittest.TestCase):
self.assertEqual(options["CUDAExecutionProvider"]["gpu_external_alloc"], "0")
self.assertEqual(options["CUDAExecutionProvider"]["gpu_external_free"], "0")
self.assertEqual(options["CUDAExecutionProvider"]["gpu_external_empty_cache"], "0")
option["user_compute_stream"] = "0"
sess.set_providers(["CUDAExecutionProvider"], [option])
options = sess.get_provider_options()
self.assertEqual(options["CUDAExecutionProvider"]["user_compute_stream"], "0")
try:
import torch
if torch.cuda.is_available():
s = torch.cuda.Stream()
option["user_compute_stream"] = str(s.cuda_stream)
sess.set_providers(["CUDAExecutionProvider"], [option])
options = sess.get_provider_options()
self.assertEqual(options["CUDAExecutionProvider"]["user_compute_stream"], str(s.cuda_stream))
self.assertEqual(options["CUDAExecutionProvider"]["has_user_compute_stream"], "1")
except ImportError:
print("torch is not installed, skip testing setting user_compute_stream from torch cuda stream")
#
# Note: Tests that throw an exception leave an empty session due to how set_providers currently works,
# so run them last. Each set_providers call will attempt to re-create a session, so it's