mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
Add API for updating CUDA EP provider option user compute stream (#17037)
Add a generic `UpdateCUDAProviderOptionsWithValue()` C API to update CUDA EP provider options whose data type is a pointer and therefore can't be represented by a string. Note: Please see the comments on the similar [PR](https://github.com/microsoft/onnxruntime/pull/16965) for the TRT EP.
This commit is contained in:
parent
a4902ee65b
commit
7361c283c7
8 changed files with 119 additions and 4 deletions
|
|
@ -4372,6 +4372,32 @@ struct OrtApi {
|
|||
* \since Version 1.16.
|
||||
*/
|
||||
ORT_API2_STATUS(GetTensorRTProviderOptionsByName, _In_ const OrtTensorRTProviderOptionsV2* tensorrt_options, _In_ const char* key, _Outptr_ void** ptr);
|
||||
|
||||
/**
|
||||
* Update a CUDA EP provider option whose value is a pointer, for example 'user_compute_stream'.
|
||||
* If the value of the provider option can be represented by a string, please use UpdateCUDAProviderOptions instead.
|
||||
* When the CUDA EP is enabled, setting 'user_compute_stream' also sets 'has_user_compute_stream' to 1.
|
||||
* Note: It is the caller's responsibility to properly manage the lifetime of the instance pointed to by this pointer.
|
||||
*
|
||||
* \param[in,out] cuda_options - OrtCUDAProviderOptionsV2 instance to update
|
||||
* \param[in] key - Name of the provider option
|
||||
* \param[in] value - A pointer to the instance that will be assigned to this provider option
|
||||
*
|
||||
* \since Version 1.16.
|
||||
*/
|
||||
ORT_API2_STATUS(UpdateCUDAProviderOptionsWithValue, _Inout_ OrtCUDAProviderOptionsV2* cuda_options, _In_ const char* key, _In_ void* value);
|
||||
|
||||
/**
|
||||
* Get a CUDA EP provider option whose value is a pointer, for example 'user_compute_stream'.
|
||||
* If the value of the provider option can be represented by a string, please use GetCUDAProviderOptionsAsString instead.
|
||||
*
|
||||
* \param[in] cuda_options - OrtCUDAProviderOptionsV2 instance to query
|
||||
* \param[in] key - Name of the provider option
|
||||
* \param[out] ptr - Receives the pointer kept by the provider option; set to nullptr for any key other than 'user_compute_stream'
|
||||
*
|
||||
* \since Version 1.16.
|
||||
*/
|
||||
ORT_API2_STATUS(GetCUDAProviderOptionsByName, _In_ const OrtCUDAProviderOptionsV2* cuda_options, _In_ const char* key, _Outptr_ void** ptr);
|
||||
};
|
||||
|
||||
/*
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ namespace onnxruntime {
|
|||
namespace cuda {
|
||||
namespace provider_option_names {
|
||||
constexpr const char* kDeviceId = "device_id";
|
||||
constexpr const char* kHasUserComputeStream = "has_user_compute_stream";
|
||||
constexpr const char* kMemLimit = "gpu_mem_limit";
|
||||
constexpr const char* kArenaExtendStrategy = "arena_extend_strategy";
|
||||
constexpr const char* kCudnnConvAlgoSearch = "cudnn_conv_algo_search";
|
||||
|
|
@ -61,6 +62,7 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P
|
|||
", must be between 0 (inclusive) and ", num_devices, " (exclusive).");
|
||||
return Status::OK();
|
||||
})
|
||||
.AddAssignmentToReference(cuda::provider_option_names::kHasUserComputeStream, info.has_user_compute_stream)
|
||||
.AddValueParser(
|
||||
cuda::provider_option_names::kGpuExternalAlloc,
|
||||
[&alloc](const std::string& value_str) -> Status {
|
||||
|
|
@ -125,6 +127,7 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P
|
|||
ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const CUDAExecutionProviderInfo& info) {
|
||||
const ProviderOptions options{
|
||||
{cuda::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
|
||||
{cuda::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)},
|
||||
{cuda::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.gpu_mem_limit)},
|
||||
{cuda::provider_option_names::kGpuExternalAlloc, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.external_allocator_info.alloc))},
|
||||
{cuda::provider_option_names::kGpuExternalFree, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.external_allocator_info.free))},
|
||||
|
|
@ -149,6 +152,7 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const CUDAExecution
|
|||
ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const OrtCUDAProviderOptionsV2& info) {
|
||||
const ProviderOptions options{
|
||||
{cuda::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
|
||||
{cuda::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)},
|
||||
{cuda::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.gpu_mem_limit)},
|
||||
{cuda::provider_option_names::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)},
|
||||
{cuda::provider_option_names::kCudnnConvAlgoSearch, EnumToName(ort_cudnn_conv_algo_search_mapping, info.cudnn_conv_algo_search)},
|
||||
|
|
|
|||
|
|
@ -226,6 +226,13 @@ struct CUDA_Provider : Provider {
|
|||
return std::make_shared<CUDAProviderFactory>(info);
|
||||
}
|
||||
|
||||
/**
|
||||
* This function will be called by the C API UpdateCUDAProviderOptions().
|
||||
*
|
||||
* What this function does is equivalent to resetting the OrtCUDAProviderOptionsV2 instance with
|
||||
* default CUDAExecutionProviderInfo instance first and then set up the provided provider options.
|
||||
* See CUDAExecutionProviderInfo::FromProviderOptions() for more details.
|
||||
*/
|
||||
void UpdateProviderOptions(void* provider_options, const ProviderOptions& options) override {
|
||||
auto internal_options = onnxruntime::CUDAExecutionProviderInfo::FromProviderOptions(options);
|
||||
auto& cuda_options = *reinterpret_cast<OrtCUDAProviderOptionsV2*>(provider_options);
|
||||
|
|
@ -236,7 +243,11 @@ struct CUDA_Provider : Provider {
|
|||
cuda_options.arena_extend_strategy = internal_options.arena_extend_strategy;
|
||||
cuda_options.do_copy_in_default_stream = internal_options.do_copy_in_default_stream;
|
||||
cuda_options.has_user_compute_stream = internal_options.has_user_compute_stream;
|
||||
cuda_options.user_compute_stream = internal_options.user_compute_stream;
|
||||
// The 'has_user_compute_stream' of the OrtCUDAProviderOptionsV2 instance can be set by the C API UpdateCUDAProviderOptionsWithValue() as well.
|
||||
// We only overwrite the 'user_compute_stream' of the OrtCUDAProviderOptionsV2 instance if 'has_user_compute_stream' is provided in options
|
||||
if (options.find("has_user_compute_stream") != options.end()) {
|
||||
cuda_options.user_compute_stream = internal_options.user_compute_stream;
|
||||
}
|
||||
cuda_options.default_memory_arena_cfg = internal_options.default_memory_arena_cfg;
|
||||
cuda_options.cudnn_conv_use_max_workspace = internal_options.cudnn_conv_use_max_workspace;
|
||||
cuda_options.enable_cuda_graph = internal_options.enable_cuda_graph;
|
||||
|
|
|
|||
|
|
@ -2713,6 +2713,8 @@ static constexpr OrtApi ort_api_1_to_16 = {
|
|||
&OrtApis::RunAsync,
|
||||
&OrtApis::UpdateTensorRTProviderOptionsWithValue,
|
||||
&OrtApis::GetTensorRTProviderOptionsByName,
|
||||
&OrtApis::UpdateCUDAProviderOptionsWithValue,
|
||||
&OrtApis::GetCUDAProviderOptionsByName,
|
||||
};
|
||||
|
||||
// OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
|
||||
|
|
|
|||
|
|
@ -487,4 +487,6 @@ ORT_API_STATUS_IMPL(RunAsync, _Inout_ OrtSession* sess, _In_opt_ const OrtRunOpt
|
|||
_In_ RunAsyncCallbackFn run_async_callback, _In_opt_ void* user_data);
|
||||
ORT_API_STATUS_IMPL(UpdateTensorRTProviderOptionsWithValue, _Inout_ OrtTensorRTProviderOptionsV2* tensorrt_options, _In_ const char* key, _In_ void* value);
|
||||
ORT_API_STATUS_IMPL(GetTensorRTProviderOptionsByName, _In_ const OrtTensorRTProviderOptionsV2* tensorrt_options, _In_ const char* key, _Outptr_ void** ptr);
|
||||
ORT_API_STATUS_IMPL(UpdateCUDAProviderOptionsWithValue, _Inout_ OrtCUDAProviderOptionsV2* cuda_options, _In_ const char* key, _In_ void* value);
|
||||
ORT_API_STATUS_IMPL(GetCUDAProviderOptionsByName, _In_ const OrtCUDAProviderOptionsV2* cuda_options, _In_ const char* key, _Outptr_ void** ptr);
|
||||
} // namespace OrtApis
|
||||
|
|
|
|||
|
|
@ -1977,6 +1977,47 @@ ORT_API_STATUS_IMPL(OrtApis::GetCUDAProviderOptionsAsString, _In_ const OrtCUDAP
|
|||
API_IMPL_END
|
||||
}
|
||||
|
||||
// Implementation of the C API UpdateCUDAProviderOptionsWithValue().
//
// Assigns a pointer-typed CUDA EP provider option on an OrtCUDAProviderOptionsV2 instance.
// The only key acted upon is "user_compute_stream": `value` is stored as the user compute
// stream and `has_user_compute_stream` is set to 1. Any other key leaves the instance
// untouched and still reports success.
//
// cuda_options - instance to update; must not be null
// key          - name of the provider option; must not be null
// value        - pointer stored as-is; this function does not take ownership of it
//
// Returns nullptr on success, an ORT_INVALID_ARGUMENT status when `cuda_options` or `key`
// is null, or an ORT_FAIL status when the build does not define USE_CUDA.
ORT_API_STATUS_IMPL(OrtApis::UpdateCUDAProviderOptionsWithValue,
                    _Inout_ OrtCUDAProviderOptionsV2* cuda_options,
                    _In_ const char* key,
                    _In_ void* value) {
  API_IMPL_BEGIN
#ifdef USE_CUDA
  // Validate before use: strcmp() on a null key and a write through a null options pointer
  // are both undefined behavior, so report them as argument errors instead.
  if (cuda_options == nullptr || key == nullptr) {
    return CreateStatus(ORT_INVALID_ARGUMENT, "cuda_options and key must not be null.");
  }

  if (strcmp(key, "user_compute_stream") == 0) {
    cuda_options->has_user_compute_stream = 1;
    cuda_options->user_compute_stream = value;
  }
  return nullptr;
#else
  ORT_UNUSED_PARAMETER(cuda_options);
  ORT_UNUSED_PARAMETER(key);
  ORT_UNUSED_PARAMETER(value);
  return CreateStatus(ORT_FAIL, "CUDA execution provider is not enabled in this build.");
#endif
  API_IMPL_END
}
|
||||
|
||||
// Implementation of the C API GetCUDAProviderOptionsByName().
//
// Reads a pointer-typed CUDA EP provider option from an OrtCUDAProviderOptionsV2 instance.
// For the key "user_compute_stream" the stored stream pointer is written to `*ptr`; for any
// other key `*ptr` is set to nullptr. Both cases report success.
//
// cuda_options - instance to query; must not be null
// key          - name of the provider option; must not be null
// ptr          - receives the stored pointer; must not be null
//
// Returns nullptr on success, an ORT_INVALID_ARGUMENT status when any argument is null,
// or an ORT_FAIL status when the build does not define USE_CUDA.
ORT_API_STATUS_IMPL(OrtApis::GetCUDAProviderOptionsByName,
                    _In_ const OrtCUDAProviderOptionsV2* cuda_options,
                    _In_ const char* key,
                    _Outptr_ void** ptr) {
  API_IMPL_BEGIN
#ifdef USE_CUDA
  // Validate before use: strcmp() on a null key, a read through a null options pointer and a
  // write through a null out-pointer are all undefined behavior.
  if (cuda_options == nullptr || key == nullptr || ptr == nullptr) {
    return CreateStatus(ORT_INVALID_ARGUMENT, "cuda_options, key and ptr must not be null.");
  }

  if (strcmp(key, "user_compute_stream") == 0) {
    *ptr = cuda_options->user_compute_stream;
  } else {
    *ptr = nullptr;
  }
  return nullptr;
#else
  ORT_UNUSED_PARAMETER(cuda_options);
  ORT_UNUSED_PARAMETER(key);
  ORT_UNUSED_PARAMETER(ptr);
  return CreateStatus(ORT_FAIL, "CUDA execution provider is not enabled in this build.");
#endif
  API_IMPL_END
}
|
||||
|
||||
ORT_API(void, OrtApis::ReleaseCUDAProviderOptions, _Frees_ptr_opt_ OrtCUDAProviderOptionsV2* ptr) {
|
||||
#ifdef USE_CUDA
|
||||
std::unique_ptr<OrtCUDAProviderOptionsV2> p(ptr);
|
||||
|
|
|
|||
|
|
@ -220,6 +220,26 @@ ORT_API_STATUS_IMPL(OrtApis::GetCUDAProviderOptionsAsString, _In_ const OrtCUDAP
|
|||
return CreateStatus(ORT_FAIL, "CUDA execution provider is not enabled in this build.");
|
||||
}
|
||||
|
||||
// Stub of UpdateCUDAProviderOptionsWithValue(): ignores every argument and returns the
// status produced by CreateNotEnabledStatus("CUDA").
ORT_API_STATUS_IMPL(OrtApis::UpdateCUDAProviderOptionsWithValue, _Inout_ OrtCUDAProviderOptionsV2* cuda_options, _In_ const char* key, _In_ void* value) {
  // None of the arguments are consulted by this stub.
  ORT_UNUSED_PARAMETER(value);
  ORT_UNUSED_PARAMETER(key);
  ORT_UNUSED_PARAMETER(cuda_options);

  return CreateNotEnabledStatus("CUDA");
}
|
||||
|
||||
// Stub of GetCUDAProviderOptionsByName(): ignores every argument (the out-pointer is left
// untouched) and returns the status produced by CreateNotEnabledStatus("CUDA").
ORT_API_STATUS_IMPL(OrtApis::GetCUDAProviderOptionsByName, _In_ const OrtCUDAProviderOptionsV2* cuda_options, _In_ const char* key, _Outptr_ void** ptr) {
  // None of the arguments are consulted by this stub.
  ORT_UNUSED_PARAMETER(ptr);
  ORT_UNUSED_PARAMETER(key);
  ORT_UNUSED_PARAMETER(cuda_options);

  return CreateNotEnabledStatus("CUDA");
}
|
||||
|
||||
// Stub of ReleaseCUDAProviderOptions(): the pointer is ignored and nothing is released.
ORT_API(void, OrtApis::ReleaseCUDAProviderOptions,
        _Frees_ptr_opt_ OrtCUDAProviderOptionsV2* ptr) {
  ORT_UNUSED_PARAMETER(ptr);
}
|
||||
|
|
|
|||
|
|
@ -2822,7 +2822,7 @@ INSTANTIATE_TEST_SUITE_P(CApiTensorRTTest, CApiTensorRTTest,
|
|||
|
||||
#ifdef USE_CUDA
|
||||
|
||||
// This test uses CreateCUDAProviderOptions/UpdateCUDAProviderOptions APIs to configure and create a CUDA Execution Provider instance
|
||||
// This test uses CreateCUDAProviderOptions/UpdateCUDAProviderOptions/UpdateCUDAProviderOptionsWithValue APIs to configure and create a CUDA Execution Provider instance
|
||||
TEST(CApiTest, TestConfigureCUDAProviderOptions) {
|
||||
const auto& api = Ort::GetApi();
|
||||
|
||||
|
|
@ -2830,12 +2830,21 @@ TEST(CApiTest, TestConfigureCUDAProviderOptions) {
|
|||
ASSERT_TRUE(api.CreateCUDAProviderOptions(&cuda_options) == nullptr);
|
||||
std::unique_ptr<OrtCUDAProviderOptionsV2, decltype(api.ReleaseCUDAProviderOptions)> rel_cuda_options(cuda_options, api.ReleaseCUDAProviderOptions);
|
||||
|
||||
// Only test updating the OrtCUDAProviderOptionsV2 instance with a user-provided compute stream; inference is not run.
|
||||
cudaStream_t compute_stream = nullptr;
|
||||
void* user_compute_stream = nullptr;
|
||||
cudaStreamCreateWithFlags(&compute_stream, cudaStreamNonBlocking);
|
||||
ASSERT_TRUE(api.UpdateCUDAProviderOptionsWithValue(rel_cuda_options.get(), "user_compute_stream", compute_stream) == nullptr);
|
||||
ASSERT_TRUE(api.GetCUDAProviderOptionsByName(rel_cuda_options.get(), "user_compute_stream", &user_compute_stream) == nullptr);
|
||||
ASSERT_TRUE(user_compute_stream == (void*)compute_stream);
|
||||
cudaStreamDestroy(compute_stream);
|
||||
|
||||
std::vector<const char*> keys{
|
||||
"device_id", "gpu_mem_limit", "arena_extend_strategy",
|
||||
"device_id", "has_user_compute_stream", "gpu_mem_limit", "arena_extend_strategy",
|
||||
"cudnn_conv_algo_search", "do_copy_in_default_stream", "cudnn_conv_use_max_workspace", "cudnn_conv1d_pad_to_nc1d"};
|
||||
|
||||
std::vector<const char*> values{
|
||||
"0", "1024", "kSameAsRequested",
|
||||
"0", "0", "1024", "kSameAsRequested",
|
||||
"DEFAULT", "1", "1"};
|
||||
|
||||
ASSERT_TRUE(api.UpdateCUDAProviderOptions(rel_cuda_options.get(), keys.data(), values.data(), 6) == nullptr);
|
||||
|
|
|
|||
Loading…
Reference in a new issue