Add API for updating CUDA EP provider option user compute stream (#17037)

Add a generic `UpdateCUDAProviderOptionsWithValue()` C API to update
CUDA EP provider options where its data type is pointer that can't be
represented by string.

Note: Please see some comments for the similar [PR
](https://github.com/microsoft/onnxruntime/pull/16965)for TRT EP.
This commit is contained in:
Chi Lo 2023-08-09 09:24:19 -07:00 committed by GitHub
parent a4902ee65b
commit 7361c283c7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 119 additions and 4 deletions

View file

@ -4372,6 +4372,32 @@ struct OrtApi {
* \since Version 1.16.
*/
ORT_API2_STATUS(GetTensorRTProviderOptionsByName, _In_ const OrtTensorRTProviderOptionsV2* tensorrt_options, _In_ const char* key, _Outptr_ void** ptr);
/**
* Update CUDA EP provider option where its data type is pointer, for example 'user_compute_stream'.
* If the data type of the provider option can be represented by string please use UpdateCUDAProviderOptions.
*
* Note: It's caller's responsibility to properly manage the lifetime of the instance pointed by this pointer.
*
* \param cuda_options - OrtCUDAProviderOptionsV2 instance
* \param key - Name of the provider option
* \param value - A pointer to the instance that will be assigned to this provider option
*
* \since Version 1.16.
*/
ORT_API2_STATUS(UpdateCUDAProviderOptionsWithValue, _Inout_ OrtCUDAProviderOptionsV2* cuda_options, _In_ const char* key, _In_ void* value);
/**
* Get CUDA EP provider option where its data type is pointer.
* If the data type of the provider option can be represented by string please use GetCUDAProviderOptionsAsString.
*
* \param cuda_options - OrtCUDAProviderOptionsV2 instance
* \param key - Name of the provider option
* \param ptr - A pointer to the instance that is kept by the provider option
*
* \since Version 1.16.
*/
ORT_API2_STATUS(GetCUDAProviderOptionsByName, _In_ const OrtCUDAProviderOptionsV2* cuda_options, _In_ const char* key, _Outptr_ void** ptr);
};
/*

View file

@ -14,6 +14,7 @@ namespace onnxruntime {
namespace cuda {
namespace provider_option_names {
constexpr const char* kDeviceId = "device_id";
constexpr const char* kHasUserComputeStream = "has_user_compute_stream";
constexpr const char* kMemLimit = "gpu_mem_limit";
constexpr const char* kArenaExtendStrategy = "arena_extend_strategy";
constexpr const char* kCudnnConvAlgoSearch = "cudnn_conv_algo_search";
@ -61,6 +62,7 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P
", must be between 0 (inclusive) and ", num_devices, " (exclusive).");
return Status::OK();
})
.AddAssignmentToReference(cuda::provider_option_names::kHasUserComputeStream, info.has_user_compute_stream)
.AddValueParser(
cuda::provider_option_names::kGpuExternalAlloc,
[&alloc](const std::string& value_str) -> Status {
@ -125,6 +127,7 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P
ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const CUDAExecutionProviderInfo& info) {
const ProviderOptions options{
{cuda::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
{cuda::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)},
{cuda::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.gpu_mem_limit)},
{cuda::provider_option_names::kGpuExternalAlloc, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.external_allocator_info.alloc))},
{cuda::provider_option_names::kGpuExternalFree, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.external_allocator_info.free))},
@ -149,6 +152,7 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const CUDAExecution
ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const OrtCUDAProviderOptionsV2& info) {
const ProviderOptions options{
{cuda::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
{cuda::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)},
{cuda::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.gpu_mem_limit)},
{cuda::provider_option_names::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)},
{cuda::provider_option_names::kCudnnConvAlgoSearch, EnumToName(ort_cudnn_conv_algo_search_mapping, info.cudnn_conv_algo_search)},

View file

@ -226,6 +226,13 @@ struct CUDA_Provider : Provider {
return std::make_shared<CUDAProviderFactory>(info);
}
/**
* This function will be called by the C API UpdateCUDAProviderOptions().
*
* What this function does is equivalent to resetting the OrtCUDAProviderOptionsV2 instance with
* default CUDAExecutionProviderInf instance first and then set up the provided provider options.
* See CUDAExecutionProviderInfo::FromProviderOptions() for more details.
*/
void UpdateProviderOptions(void* provider_options, const ProviderOptions& options) override {
auto internal_options = onnxruntime::CUDAExecutionProviderInfo::FromProviderOptions(options);
auto& cuda_options = *reinterpret_cast<OrtCUDAProviderOptionsV2*>(provider_options);
@ -236,7 +243,11 @@ struct CUDA_Provider : Provider {
cuda_options.arena_extend_strategy = internal_options.arena_extend_strategy;
cuda_options.do_copy_in_default_stream = internal_options.do_copy_in_default_stream;
cuda_options.has_user_compute_stream = internal_options.has_user_compute_stream;
cuda_options.user_compute_stream = internal_options.user_compute_stream;
// The 'has_user_compute_stream' of the OrtCUDAProviderOptionsV2 instance can be set byC API UpdateCUDAProviderOptionsWithValue() as well.
// We only set the 'has_user_compute_stream' of the OrtCUDAProviderOptionsV2 instance if it is provided in options
if (options.find("has_user_compute_stream") != options.end()) {
cuda_options.user_compute_stream = internal_options.user_compute_stream;
}
cuda_options.default_memory_arena_cfg = internal_options.default_memory_arena_cfg;
cuda_options.cudnn_conv_use_max_workspace = internal_options.cudnn_conv_use_max_workspace;
cuda_options.enable_cuda_graph = internal_options.enable_cuda_graph;

View file

@ -2713,6 +2713,8 @@ static constexpr OrtApi ort_api_1_to_16 = {
&OrtApis::RunAsync,
&OrtApis::UpdateTensorRTProviderOptionsWithValue,
&OrtApis::GetTensorRTProviderOptionsByName,
&OrtApis::UpdateCUDAProviderOptionsWithValue,
&OrtApis::GetCUDAProviderOptionsByName,
};
// OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.

View file

@ -487,4 +487,6 @@ ORT_API_STATUS_IMPL(RunAsync, _Inout_ OrtSession* sess, _In_opt_ const OrtRunOpt
_In_ RunAsyncCallbackFn run_async_callback, _In_opt_ void* user_data);
ORT_API_STATUS_IMPL(UpdateTensorRTProviderOptionsWithValue, _Inout_ OrtTensorRTProviderOptionsV2* tensorrt_options, _In_ const char* key, _In_ void* value);
ORT_API_STATUS_IMPL(GetTensorRTProviderOptionsByName, _In_ const OrtTensorRTProviderOptionsV2* tensorrt_options, _In_ const char* key, _Outptr_ void** ptr);
ORT_API_STATUS_IMPL(UpdateCUDAProviderOptionsWithValue, _Inout_ OrtCUDAProviderOptionsV2* cuda_options, _In_ const char* key, _In_ void* value);
ORT_API_STATUS_IMPL(GetCUDAProviderOptionsByName, _In_ const OrtCUDAProviderOptionsV2* cuda_options, _In_ const char* key, _Outptr_ void** ptr);
} // namespace OrtApis

View file

@ -1977,6 +1977,47 @@ ORT_API_STATUS_IMPL(OrtApis::GetCUDAProviderOptionsAsString, _In_ const OrtCUDAP
API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtApis::UpdateCUDAProviderOptionsWithValue,
_Inout_ OrtCUDAProviderOptionsV2* cuda_options,
_In_ const char* key,
_In_ void* value) {
API_IMPL_BEGIN
#ifdef USE_CUDA
if (strcmp(key, "user_compute_stream") == 0) {
cuda_options->has_user_compute_stream = 1;
cuda_options->user_compute_stream = value;
}
return nullptr;
#else
ORT_UNUSED_PARAMETER(cuda_options);
ORT_UNUSED_PARAMETER(key);
ORT_UNUSED_PARAMETER(value);
return CreateStatus(ORT_FAIL, "CUDA execution provider is not enabled in this build.");
#endif
API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtApis::GetCUDAProviderOptionsByName,
_In_ const OrtCUDAProviderOptionsV2* cuda_options,
_In_ const char* key,
_Outptr_ void** ptr) {
API_IMPL_BEGIN
#ifdef USE_CUDA
if (strcmp(key, "user_compute_stream") == 0) {
*ptr = cuda_options->user_compute_stream;
} else {
*ptr = nullptr;
}
return nullptr;
#else
ORT_UNUSED_PARAMETER(cuda_options);
ORT_UNUSED_PARAMETER(key);
ORT_UNUSED_PARAMETER(ptr);
return CreateStatus(ORT_FAIL, "CUDA execution provider is not enabled in this build.");
#endif
API_IMPL_END
}
ORT_API(void, OrtApis::ReleaseCUDAProviderOptions, _Frees_ptr_opt_ OrtCUDAProviderOptionsV2* ptr) {
#ifdef USE_CUDA
std::unique_ptr<OrtCUDAProviderOptionsV2> p(ptr);

View file

@ -220,6 +220,26 @@ ORT_API_STATUS_IMPL(OrtApis::GetCUDAProviderOptionsAsString, _In_ const OrtCUDAP
return CreateStatus(ORT_FAIL, "CUDA execution provider is not enabled in this build.");
}
ORT_API_STATUS_IMPL(OrtApis::UpdateCUDAProviderOptionsWithValue,
_Inout_ OrtCUDAProviderOptionsV2* cuda_options,
_In_ const char* key,
_In_ void* value) {
ORT_UNUSED_PARAMETER(cuda_options);
ORT_UNUSED_PARAMETER(key);
ORT_UNUSED_PARAMETER(value);
return CreateNotEnabledStatus("CUDA");
}
ORT_API_STATUS_IMPL(OrtApis::GetCUDAProviderOptionsByName,
_In_ const OrtCUDAProviderOptionsV2* cuda_options,
_In_ const char* key,
_Outptr_ void** ptr) {
ORT_UNUSED_PARAMETER(cuda_options);
ORT_UNUSED_PARAMETER(key);
ORT_UNUSED_PARAMETER(ptr);
return CreateNotEnabledStatus("CUDA");
}
ORT_API(void, OrtApis::ReleaseCUDAProviderOptions, _Frees_ptr_opt_ OrtCUDAProviderOptionsV2* ptr) {
ORT_UNUSED_PARAMETER(ptr);
}

View file

@ -2822,7 +2822,7 @@ INSTANTIATE_TEST_SUITE_P(CApiTensorRTTest, CApiTensorRTTest,
#ifdef USE_CUDA
// This test uses CreateCUDAProviderOptions/UpdateCUDAProviderOptions APIs to configure and create a CUDA Execution Provider instance
// This test uses CreateCUDAProviderOptions/UpdateCUDAProviderOptions/UpdateCUDAProviderOptionsWithValue APIs to configure and create a CUDA Execution Provider instance
TEST(CApiTest, TestConfigureCUDAProviderOptions) {
const auto& api = Ort::GetApi();
@ -2830,12 +2830,21 @@ TEST(CApiTest, TestConfigureCUDAProviderOptions) {
ASSERT_TRUE(api.CreateCUDAProviderOptions(&cuda_options) == nullptr);
std::unique_ptr<OrtCUDAProviderOptionsV2, decltype(api.ReleaseCUDAProviderOptions)> rel_cuda_options(cuda_options, api.ReleaseCUDAProviderOptions);
// Only test updating OrtCUDAProviderOptionsV2 instance with user provided compute stream not running the inference
cudaStream_t compute_stream = nullptr;
void* user_compute_stream = nullptr;
cudaStreamCreateWithFlags(&compute_stream, cudaStreamNonBlocking);
ASSERT_TRUE(api.UpdateCUDAProviderOptionsWithValue(rel_cuda_options.get(), "user_compute_stream", compute_stream) == nullptr);
ASSERT_TRUE(api.GetCUDAProviderOptionsByName(rel_cuda_options.get(), "user_compute_stream", &user_compute_stream) == nullptr);
ASSERT_TRUE(user_compute_stream == (void*)compute_stream);
cudaStreamDestroy(compute_stream);
std::vector<const char*> keys{
"device_id", "gpu_mem_limit", "arena_extend_strategy",
"device_id", "has_user_compute_stream", "gpu_mem_limit", "arena_extend_strategy",
"cudnn_conv_algo_search", "do_copy_in_default_stream", "cudnn_conv_use_max_workspace", "cudnn_conv1d_pad_to_nc1d"};
std::vector<const char*> values{
"0", "1024", "kSameAsRequested",
"0", "0", "1024", "kSameAsRequested",
"DEFAULT", "1", "1"};
ASSERT_TRUE(api.UpdateCUDAProviderOptions(rel_cuda_options.get(), keys.data(), values.data(), 6) == nullptr);