fix supports_device() in python interface (#22473)

### Description

`get_device()` returns a string of hyphen connected device names, such
as "GPU-DML". It's a problem that when CUDA is disabled but OpenVino GPU
is enabled in the build, because in this case `get_device()` returns
"CPU-OPENVINO_GPU", so `supports_device("CUDA")` will return `True` in
this build.

Splitting the value of `get_device()` by "-" and check if the input is
in the list is not an option because it seems some code in the code base
stores the value of `get_device()` and use the value to call
`supports_device()`. Using this implementation will cause
`supports_device("GPU-DML")` to return `False` for a build with
`get_device() == "GPU-DML"` because `"GPU-DML" in ["GPU", "DML"]` is
`False`.

This change also helps to avoid further problems when "WebGPU" is
introduced.
This commit is contained in:
Yulong Wang 2024-10-17 12:10:25 -07:00 committed by GitHub
parent 1247d69c28
commit 55c584954c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -87,7 +87,7 @@ class OnnxRuntimeBackend(Backend):
"""
if device == "CUDA":
device = "GPU"
return device in get_device()
return "-" + device in get_device() or device + "-" in get_device() or device == get_device()
@classmethod
def prepare(cls, model, device=None, **kwargs):