Add conditional check in Get/Set current GPU device id (#20932)

### Description

Add conditional check in Get/Set current GPU device id


### Motivation and Context

Currently with ROCm build, calling `GetCurrentGpuDeviceId` will still
try to find CUDA libraries and log the following error message:

```text
[E:onnxruntime:, provider_bridge_ort.cc:1836 TryGetProviderInfo_CUDA] /onnxruntime_src/onnxruntime/core/session/provider_bridge_ort.cc:1511 onnxruntime::Provider& onnxruntime::ProviderLibrary::Get() [ONNXRuntimeError] : 1 : FAIL : Failed to load library libonnxruntime_providers_cuda.so with error: libonnxruntime_providers_cuda.so: cannot open shared object file: No such file or directory
```

This is unnecessary and confusing.
This commit is contained in:
Chester Liu 2024-06-06 17:10:14 +08:00 committed by GitHub
parent 3ecf48e3b5
commit 5b87544aab
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -2099,22 +2099,36 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_CUDA, _In_ OrtSessi
return OrtApis::SessionOptionsAppendExecutionProvider_CUDA(options, &provider_options);
}
ORT_API_STATUS_IMPL(OrtApis::SetCurrentGpuDeviceId, _In_ int device_id) {
ORT_API_STATUS_IMPL(OrtApis::SetCurrentGpuDeviceId, [[maybe_unused]] _In_ int device_id) {
API_IMPL_BEGIN
#ifdef USE_CUDA
if (auto* info = onnxruntime::TryGetProviderInfo_CUDA())
return info->SetCurrentGpuDeviceId(device_id);
#endif
#ifdef USE_ROCM
if (auto* info = onnxruntime::TryGetProviderInfo_ROCM())
return info->SetCurrentGpuDeviceId(device_id);
#endif
return CreateStatus(ORT_FAIL, "CUDA and/or ROCM execution provider is either not enabled or not available.");
API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtApis::GetCurrentGpuDeviceId, _In_ int* device_id) {
ORT_API_STATUS_IMPL(OrtApis::GetCurrentGpuDeviceId, [[maybe_unused]] _In_ int* device_id) {
API_IMPL_BEGIN
#ifdef USE_CUDA
if (auto* info = onnxruntime::TryGetProviderInfo_CUDA())
return info->GetCurrentGpuDeviceId(device_id);
#endif
#ifdef USE_ROCM
if (auto* info = onnxruntime::TryGetProviderInfo_ROCM())
return info->GetCurrentGpuDeviceId(device_id);
#endif
return CreateStatus(ORT_FAIL, "CUDA and/or ROCM execution provider is either not enabled or not available.");
API_IMPL_END
}