From d7ff81dfb77989a8ce975db29457e5cdfc00f9e3 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 26 Jan 2024 10:34:43 -0800 Subject: [PATCH] [CUDA] support user_compute_stream in python API (#19229) ### Description It is an important feature to pass user cuda stream to avoid synchronization in python API. Here we allow user to pass cuda stream for CUDA provider. Note that TRT or ROCm provider need similar change, which are not included in this pull request. Note that we will set `has_user_compute_stream` automatically based on whether there is cuda stream passed, so setting `has_user_compute_stream` through python API has no effect. ### Motivation and Context https://github.com/microsoft/onnxruntime/issues/19094 --- .../cuda/cuda_execution_provider_info.cc | 16 ++++++++++++++++ .../test/python/onnxruntime_test_python.py | 19 +++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc index daa3b5ff3d..7b507296d5 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc @@ -16,6 +16,7 @@ namespace cuda { namespace provider_option_names { constexpr const char* kDeviceId = "device_id"; constexpr const char* kHasUserComputeStream = "has_user_compute_stream"; +constexpr const char* kUserComputeStream = "user_compute_stream"; constexpr const char* kMemLimit = "gpu_mem_limit"; constexpr const char* kArenaExtendStrategy = "arena_extend_strategy"; constexpr const char* kCudnnConvAlgoSearch = "cudnn_conv_algo_search"; @@ -51,6 +52,7 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P void* alloc = nullptr; void* free = nullptr; void* empty_cache = nullptr; + void* user_compute_stream = nullptr; ORT_THROW_IF_ERROR( ProviderOptionsParser{} .AddValueParser( @@ -66,6 +68,14 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P return Status::OK(); }) .AddAssignmentToReference(cuda::provider_option_names::kHasUserComputeStream, info.has_user_compute_stream) + .AddValueParser( + cuda::provider_option_names::kUserComputeStream, + [&user_compute_stream](const std::string& value_str) -> Status { + size_t address; + ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); + user_compute_stream = reinterpret_cast(address); + return Status::OK(); + }) .AddValueParser( cuda::provider_option_names::kGpuExternalAlloc, [&alloc](const std::string& value_str) -> Status { @@ -126,6 +136,10 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P CUDAExecutionProviderExternalAllocatorInfo alloc_info{alloc, free, empty_cache}; info.external_allocator_info = alloc_info; + + info.user_compute_stream = user_compute_stream; + info.has_user_compute_stream = (user_compute_stream != nullptr); + return info; } @@ -133,6 +147,7 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const CUDAExecution const ProviderOptions options{ {cuda::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, {cuda::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)}, + {cuda::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast(info.user_compute_stream))}, {cuda::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.gpu_mem_limit)}, {cuda::provider_option_names::kGpuExternalAlloc, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.alloc))}, {cuda::provider_option_names::kGpuExternalFree, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.free))}, @@ -160,6 +175,7 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const OrtCUDAProvid const ProviderOptions options{ {cuda::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, {cuda::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)}, + {cuda::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast(info.user_compute_stream))}, {cuda::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.gpu_mem_limit)}, {cuda::provider_option_names::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)}, {cuda::provider_option_names::kCudnnConvAlgoSearch, EnumToName(ort_cudnn_conv_algo_search_mapping, info.cudnn_conv_algo_search)}, diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index 8c23286e45..e210917e7a 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -434,6 +434,25 @@ class TestInferenceSession(unittest.TestCase): self.assertEqual(options["CUDAExecutionProvider"]["gpu_external_alloc"], "0") self.assertEqual(options["CUDAExecutionProvider"]["gpu_external_free"], "0") self.assertEqual(options["CUDAExecutionProvider"]["gpu_external_empty_cache"], "0") + + option["user_compute_stream"] = "0" + sess.set_providers(["CUDAExecutionProvider"], [option]) + options = sess.get_provider_options() + self.assertEqual(options["CUDAExecutionProvider"]["user_compute_stream"], "0") + + try: + import torch + + if torch.cuda.is_available(): + s = torch.cuda.Stream() + option["user_compute_stream"] = str(s.cuda_stream) + sess.set_providers(["CUDAExecutionProvider"], [option]) + options = sess.get_provider_options() + self.assertEqual(options["CUDAExecutionProvider"]["user_compute_stream"], str(s.cuda_stream)) + self.assertEqual(options["CUDAExecutionProvider"]["has_user_compute_stream"], "1") + except ImportError: + print("torch is not installed, skip testing setting user_compute_stream from torch cuda stream") + # # Note: Tests that throw an exception leave an empty session due to how set_providers currently works, # so run them last. Each set_providers call will attempt to re-create a session, so it's