From 7361c283c7b32583cffd2fa1bb2e0e6038f2f354 Mon Sep 17 00:00:00 2001 From: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Date: Wed, 9 Aug 2023 09:24:19 -0700 Subject: [PATCH] Add API for updating CUDA EP provider option user compute stream (#17037) Add a generic `UpdateCUDAProviderOptionsWithValue()` C API to update CUDA EP provider options where its data type is pointer that can't be represented by string. Note: Please see some comments for the similar [PR ](https://github.com/microsoft/onnxruntime/pull/16965)for TRT EP. --- .../core/session/onnxruntime_c_api.h | 26 ++++++++++++ .../cuda/cuda_execution_provider_info.cc | 4 ++ .../providers/cuda/cuda_provider_factory.cc | 13 +++++- onnxruntime/core/session/onnxruntime_c_api.cc | 2 + onnxruntime/core/session/ort_apis.h | 2 + .../core/session/provider_bridge_ort.cc | 41 +++++++++++++++++++ .../core/session/provider_registration.cc | 20 +++++++++ onnxruntime/test/shared_lib/test_inference.cc | 15 +++++-- 8 files changed, 119 insertions(+), 4 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index cf50681153..c39f0b8b1e 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -4372,6 +4372,32 @@ struct OrtApi { * \since Version 1.16. */ ORT_API2_STATUS(GetTensorRTProviderOptionsByName, _In_ const OrtTensorRTProviderOptionsV2* tensorrt_options, _In_ const char* key, _Outptr_ void** ptr); + + /** + * Update CUDA EP provider option where its data type is pointer, for example 'user_compute_stream'. + * If the data type of the provider option can be represented by string please use UpdateCUDAProviderOptions. + * + * Note: It's caller's responsibility to properly manage the lifetime of the instance pointed by this pointer. + * + * \param cuda_options - OrtCUDAProviderOptionsV2 instance + * \param key - Name of the provider option + * \param value - A pointer to the instance that will be assigned to this provider option + * + * \since Version 1.16. + */ + ORT_API2_STATUS(UpdateCUDAProviderOptionsWithValue, _Inout_ OrtCUDAProviderOptionsV2* cuda_options, _In_ const char* key, _In_ void* value); + + /** + * Get CUDA EP provider option where its data type is pointer. + * If the data type of the provider option can be represented by string please use GetCUDAProviderOptionsAsString. + * + * \param cuda_options - OrtCUDAProviderOptionsV2 instance + * \param key - Name of the provider option + * \param ptr - A pointer to the instance that is kept by the provider option + * + * \since Version 1.16. + */ + ORT_API2_STATUS(GetCUDAProviderOptionsByName, _In_ const OrtCUDAProviderOptionsV2* cuda_options, _In_ const char* key, _Outptr_ void** ptr); }; /* diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc index e713dfc132..ca88b3474b 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc @@ -14,6 +14,7 @@ namespace onnxruntime { namespace cuda { namespace provider_option_names { constexpr const char* kDeviceId = "device_id"; +constexpr const char* kHasUserComputeStream = "has_user_compute_stream"; constexpr const char* kMemLimit = "gpu_mem_limit"; constexpr const char* kArenaExtendStrategy = "arena_extend_strategy"; constexpr const char* kCudnnConvAlgoSearch = "cudnn_conv_algo_search"; @@ -61,6 +62,7 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P ", must be between 0 (inclusive) and ", num_devices, " (exclusive)."); return Status::OK(); }) + .AddAssignmentToReference(cuda::provider_option_names::kHasUserComputeStream, info.has_user_compute_stream) .AddValueParser( cuda::provider_option_names::kGpuExternalAlloc, [&alloc](const std::string& value_str) -> Status { @@ -125,6 +127,7 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const CUDAExecutionProviderInfo& info) { const ProviderOptions options{ {cuda::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, + {cuda::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)}, {cuda::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.gpu_mem_limit)}, {cuda::provider_option_names::kGpuExternalAlloc, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.alloc))}, {cuda::provider_option_names::kGpuExternalFree, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.free))}, @@ -149,6 +152,7 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const CUDAExecution ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const OrtCUDAProviderOptionsV2& info) { const ProviderOptions options{ {cuda::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, + {cuda::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)}, {cuda::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.gpu_mem_limit)}, {cuda::provider_option_names::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)}, {cuda::provider_option_names::kCudnnConvAlgoSearch, EnumToName(ort_cudnn_conv_algo_search_mapping, info.cudnn_conv_algo_search)}, diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index 1c90f04a69..5a11f2529f 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -226,6 +226,13 @@ struct CUDA_Provider : Provider { return std::make_shared(info); } + /** + * This function will be called by the C API UpdateCUDAProviderOptions(). + * + * What this function does is equivalent to resetting the OrtCUDAProviderOptionsV2 instance with + * default CUDAExecutionProviderInf instance first and then set up the provided provider options. + * See CUDAExecutionProviderInfo::FromProviderOptions() for more details. + */ void UpdateProviderOptions(void* provider_options, const ProviderOptions& options) override { auto internal_options = onnxruntime::CUDAExecutionProviderInfo::FromProviderOptions(options); auto& cuda_options = *reinterpret_cast(provider_options); @@ -236,7 +243,11 @@ struct CUDA_Provider : Provider { cuda_options.arena_extend_strategy = internal_options.arena_extend_strategy; cuda_options.do_copy_in_default_stream = internal_options.do_copy_in_default_stream; cuda_options.has_user_compute_stream = internal_options.has_user_compute_stream; - cuda_options.user_compute_stream = internal_options.user_compute_stream; + // The 'has_user_compute_stream' of the OrtCUDAProviderOptionsV2 instance can be set byC API UpdateCUDAProviderOptionsWithValue() as well. + // We only set the 'has_user_compute_stream' of the OrtCUDAProviderOptionsV2 instance if it is provided in options + if (options.find("has_user_compute_stream") != options.end()) { + cuda_options.user_compute_stream = internal_options.user_compute_stream; + } cuda_options.default_memory_arena_cfg = internal_options.default_memory_arena_cfg; cuda_options.cudnn_conv_use_max_workspace = internal_options.cudnn_conv_use_max_workspace; cuda_options.enable_cuda_graph = internal_options.enable_cuda_graph; diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 808e7c3c3c..8795c2db50 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2713,6 +2713,8 @@ static constexpr OrtApi ort_api_1_to_16 = { &OrtApis::RunAsync, &OrtApis::UpdateTensorRTProviderOptionsWithValue, &OrtApis::GetTensorRTProviderOptionsByName, + &OrtApis::UpdateCUDAProviderOptionsWithValue, + &OrtApis::GetCUDAProviderOptionsByName, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 4cefe3fb96..5a8b2e0069 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -487,4 +487,6 @@ ORT_API_STATUS_IMPL(RunAsync, _Inout_ OrtSession* sess, _In_opt_ const OrtRunOpt _In_ RunAsyncCallbackFn run_async_callback, _In_opt_ void* user_data); ORT_API_STATUS_IMPL(UpdateTensorRTProviderOptionsWithValue, _Inout_ OrtTensorRTProviderOptionsV2* tensorrt_options, _In_ const char* key, _In_ void* value); ORT_API_STATUS_IMPL(GetTensorRTProviderOptionsByName, _In_ const OrtTensorRTProviderOptionsV2* tensorrt_options, _In_ const char* key, _Outptr_ void** ptr); +ORT_API_STATUS_IMPL(UpdateCUDAProviderOptionsWithValue, _Inout_ OrtCUDAProviderOptionsV2* cuda_options, _In_ const char* key, _In_ void* value); +ORT_API_STATUS_IMPL(GetCUDAProviderOptionsByName, _In_ const OrtCUDAProviderOptionsV2* cuda_options, _In_ const char* key, _Outptr_ void** ptr); } // namespace OrtApis diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 7357feccbf..255c7e36b3 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1977,6 +1977,47 @@ ORT_API_STATUS_IMPL(OrtApis::GetCUDAProviderOptionsAsString, _In_ const OrtCUDAP API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::UpdateCUDAProviderOptionsWithValue, + _Inout_ OrtCUDAProviderOptionsV2* cuda_options, + _In_ const char* key, + _In_ void* value) { + API_IMPL_BEGIN +#ifdef USE_CUDA + if (strcmp(key, "user_compute_stream") == 0) { + cuda_options->has_user_compute_stream = 1; + cuda_options->user_compute_stream = value; + } + return nullptr; +#else + ORT_UNUSED_PARAMETER(cuda_options); + ORT_UNUSED_PARAMETER(key); + ORT_UNUSED_PARAMETER(value); + return CreateStatus(ORT_FAIL, "CUDA execution provider is not enabled in this build."); +#endif + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::GetCUDAProviderOptionsByName, + _In_ const OrtCUDAProviderOptionsV2* cuda_options, + _In_ const char* key, + _Outptr_ void** ptr) { + API_IMPL_BEGIN +#ifdef USE_CUDA + if (strcmp(key, "user_compute_stream") == 0) { + *ptr = cuda_options->user_compute_stream; + } else { + *ptr = nullptr; + } + return nullptr; +#else + ORT_UNUSED_PARAMETER(cuda_options); + ORT_UNUSED_PARAMETER(key); + ORT_UNUSED_PARAMETER(ptr); + return CreateStatus(ORT_FAIL, "CUDA execution provider is not enabled in this build."); +#endif + API_IMPL_END +} + ORT_API(void, OrtApis::ReleaseCUDAProviderOptions, _Frees_ptr_opt_ OrtCUDAProviderOptionsV2* ptr) { #ifdef USE_CUDA std::unique_ptr p(ptr); diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index bfd66463eb..4cea84a590 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -220,6 +220,26 @@ ORT_API_STATUS_IMPL(OrtApis::GetCUDAProviderOptionsAsString, _In_ const OrtCUDAP return CreateStatus(ORT_FAIL, "CUDA execution provider is not enabled in this build."); } +ORT_API_STATUS_IMPL(OrtApis::UpdateCUDAProviderOptionsWithValue, + _Inout_ OrtCUDAProviderOptionsV2* cuda_options, + _In_ const char* key, + _In_ void* value) { + ORT_UNUSED_PARAMETER(cuda_options); + ORT_UNUSED_PARAMETER(key); + ORT_UNUSED_PARAMETER(value); + return CreateNotEnabledStatus("CUDA"); +} + +ORT_API_STATUS_IMPL(OrtApis::GetCUDAProviderOptionsByName, + _In_ const OrtCUDAProviderOptionsV2* cuda_options, + _In_ const char* key, + _Outptr_ void** ptr) { + ORT_UNUSED_PARAMETER(cuda_options); + ORT_UNUSED_PARAMETER(key); + ORT_UNUSED_PARAMETER(ptr); + return CreateNotEnabledStatus("CUDA"); +} + ORT_API(void, OrtApis::ReleaseCUDAProviderOptions, _Frees_ptr_opt_ OrtCUDAProviderOptionsV2* ptr) { ORT_UNUSED_PARAMETER(ptr); } diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index fd50a1e4bd..b846c8882e 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -2822,7 +2822,7 @@ INSTANTIATE_TEST_SUITE_P(CApiTensorRTTest, CApiTensorRTTest, #ifdef USE_CUDA -// This test uses CreateCUDAProviderOptions/UpdateCUDAProviderOptions APIs to configure and create a CUDA Execution Provider instance +// This test uses CreateCUDAProviderOptions/UpdateCUDAProviderOptions/UpdateCUDAProviderOptionsWithValue APIs to configure and create a CUDA Execution Provider instance TEST(CApiTest, TestConfigureCUDAProviderOptions) { const auto& api = Ort::GetApi(); @@ -2830,12 +2830,21 @@ TEST(CApiTest, TestConfigureCUDAProviderOptions) { ASSERT_TRUE(api.CreateCUDAProviderOptions(&cuda_options) == nullptr); std::unique_ptr rel_cuda_options(cuda_options, api.ReleaseCUDAProviderOptions); + // Only test updating OrtCUDAProviderOptionsV2 instance with user provided compute stream not running the inference + cudaStream_t compute_stream = nullptr; + void* user_compute_stream = nullptr; + cudaStreamCreateWithFlags(&compute_stream, cudaStreamNonBlocking); + ASSERT_TRUE(api.UpdateCUDAProviderOptionsWithValue(rel_cuda_options.get(), "user_compute_stream", compute_stream) == nullptr); + ASSERT_TRUE(api.GetCUDAProviderOptionsByName(rel_cuda_options.get(), "user_compute_stream", &user_compute_stream) == nullptr); + ASSERT_TRUE(user_compute_stream == (void*)compute_stream); + cudaStreamDestroy(compute_stream); + std::vector keys{ - "device_id", "gpu_mem_limit", "arena_extend_strategy", + "device_id", "has_user_compute_stream", "gpu_mem_limit", "arena_extend_strategy", "cudnn_conv_algo_search", "do_copy_in_default_stream", "cudnn_conv_use_max_workspace", "cudnn_conv1d_pad_to_nc1d"}; std::vector values{ - "0", "1024", "kSameAsRequested", + "0", "0", "1024", "kSameAsRequested", "DEFAULT", "1", "1"}; ASSERT_TRUE(api.UpdateCUDAProviderOptions(rel_cuda_options.get(), keys.data(), values.data(), 6) == nullptr);