mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
fix supports_device() in python interface (#22473)
### Description
`get_device()` returns a string of hyphen connected device names, such
as "GPU-DML". It's a problem that when CUDA is disabled but OpenVino GPU
is enabled in the build, because in this case `get_device()` returns
"CPU-OPENVINO_GPU", so `supports_device("CUDA")` will return `True` in
this build.
Splitting the value of `get_device()` by "-" and check if the input is
in the list is not an option because it seems some code in the code base
stores the value of `get_device()` and use the value to call
`supports_device()`. Using this implementation will cause
`supports_device("GPU-DML")` to return `False` for a build with
`get_device() == "GPU-DML"` because `"GPU-DML" in ["GPU", "DML"]` is
`False`.
This change also helps to avoid further problems when "WebGPU" is
introduced.
This commit is contained in:
parent
1247d69c28
commit
55c584954c
1 changed files with 1 additions and 1 deletions
|
|
@ -87,7 +87,7 @@ class OnnxRuntimeBackend(Backend):
|
|||
"""
|
||||
if device == "CUDA":
|
||||
device = "GPU"
|
||||
return device in get_device()
|
||||
return "-" + device in get_device() or device + "-" in get_device() or device == get_device()
|
||||
|
||||
@classmethod
|
||||
def prepare(cls, model, device=None, **kwargs):
|
||||
|
|
|
|||
Loading…
Reference in a new issue