diff --git a/onnxruntime/python/backend/backend.py b/onnxruntime/python/backend/backend.py index 97b7358f2a..67423fe9b5 100644 --- a/onnxruntime/python/backend/backend.py +++ b/onnxruntime/python/backend/backend.py @@ -87,7 +87,7 @@ class OnnxRuntimeBackend(Backend): """ if device == "CUDA": device = "GPU" - return device in get_device() + return "-" + device in get_device() or device + "-" in get_device() or device == get_device() @classmethod def prepare(cls, model, device=None, **kwargs):