diff --git a/include/onnxruntime/core/framework/ortdevice.h b/include/onnxruntime/core/framework/ortdevice.h index f15543f22f..6f658ab65b 100644 --- a/include/onnxruntime/core/framework/ortdevice.h +++ b/include/onnxruntime/core/framework/ortdevice.h @@ -17,6 +17,7 @@ struct OrtDevice { static const DeviceType GPU = 1; // Nvidia or AMD static const DeviceType FPGA = 2; static const DeviceType NPU = 3; // Ascend + static const DeviceType DML = 4; struct MemType { // Pre-defined memory types. diff --git a/onnxruntime/core/framework/allocator.cc b/onnxruntime/core/framework/allocator.cc index b6dc8ad56f..26b98b0a04 100644 --- a/onnxruntime/core/framework/allocator.cc +++ b/onnxruntime/core/framework/allocator.cc @@ -139,13 +139,16 @@ ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo, _In_ const char* name1, enum OrtA *out = new OrtMemoryInfo(onnxruntime::CPU, type, OrtDevice(), id1, mem_type1); } else if (strcmp(name1, onnxruntime::CUDA) == 0 || strcmp(name1, onnxruntime::OpenVINO_GPU) == 0 || - strcmp(name1, onnxruntime::DML) == 0 || strcmp(name1, onnxruntime::HIP) == 0 || strcmp(name1, onnxruntime::WEBGPU_BUFFER) == 0 || strcmp(name1, onnxruntime::WEBNN_TENSOR) == 0) { *out = new OrtMemoryInfo( name1, type, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast(id1)), id1, mem_type1); + } else if (strcmp(name1, onnxruntime::DML) == 0) { + *out = new OrtMemoryInfo( + name1, type, OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, static_cast(id1)), id1, + mem_type1); } else if (strcmp(name1, onnxruntime::OpenVINO_RT_NPU) == 0) { *out = new OrtMemoryInfo( name1, type, OrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT, static_cast(id1)), id1, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp index 801cceb3bd..68b9b3fe57 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp @@ -41,23 +41,19 @@ namespace Dml D3D12_HEAP_FLAGS heapFlags, D3D12_RESOURCE_FLAGS resourceFlags, D3D12_RESOURCE_STATES initialState, - std::unique_ptr&& subAllocator - ) + std::unique_ptr&& subAllocator) : onnxruntime::IAllocator( - OrtMemoryInfo( - "DML", - OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0) - ) - ), - m_device(device), - m_heapProperties(heapProps), - m_heapFlags(heapFlags), - m_resourceFlags(resourceFlags), - m_initialState(initialState), - m_context(context), - m_subAllocator(std::move(subAllocator)) - { + OrtMemoryInfo( + "DML", + OrtAllocatorType::OrtDeviceAllocator, + OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, 0))), + m_device(device), + m_heapProperties(heapProps), + m_heapFlags(heapFlags), + m_resourceFlags(resourceFlags), + m_initialState(initialState), + m_context(context), + m_subAllocator(std::move(subAllocator)) { } /*static*/ gsl::index BucketizedBufferAllocator::GetBucketIndexFromSize(uint64_t size) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlExternalBufferAllocator.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlExternalBufferAllocator.h index b22f0b2853..6aae05c999 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlExternalBufferAllocator.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlExternalBufferAllocator.h @@ -20,15 +20,13 @@ namespace Dml class DmlExternalBufferAllocator : public onnxruntime::IAllocator { public: - DmlExternalBufferAllocator(int device_id) : onnxruntime::IAllocator( - OrtMemoryInfo( - "DML", - OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0) - )) - { - m_device = onnxruntime::DMLProviderFactoryCreator::CreateD3D12Device(device_id, false); - } + DmlExternalBufferAllocator(int device_id) : onnxruntime::IAllocator( + OrtMemoryInfo( + "DML", + OrtAllocatorType::OrtDeviceAllocator, + OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, 0))) { + m_device = onnxruntime::DMLProviderFactoryCreator::CreateD3D12Device(device_id, false); + } void* Alloc(size_t size) final { diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index 6b0faaad43..2deb83ec13 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -73,20 +73,17 @@ namespace Dml bool enableMetacommands, bool enableGraphCapture, bool enableSyncSpinning, - bool disableMemoryArena) : - IExecutionProvider(onnxruntime::kDmlExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)) - { - D3D12_COMMAND_LIST_TYPE queueType = executionContext->GetCommandListTypeForQueue(); - if (queueType != D3D12_COMMAND_LIST_TYPE_DIRECT && queueType != D3D12_COMMAND_LIST_TYPE_COMPUTE) - { - // DML requires either DIRECT or COMPUTE command queues. - ORT_THROW_HR(E_INVALIDARG); - } + bool disableMemoryArena) : IExecutionProvider(onnxruntime::kDmlExecutionProvider, OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, 0)) { + D3D12_COMMAND_LIST_TYPE queueType = executionContext->GetCommandListTypeForQueue(); + if (queueType != D3D12_COMMAND_LIST_TYPE_DIRECT && queueType != D3D12_COMMAND_LIST_TYPE_COMPUTE) { + // DML requires either DIRECT or COMPUTE command queues. + ORT_THROW_HR(E_INVALIDARG); + } - ComPtr device; - GRAPHICS_THROW_IF_FAILED(dmlDevice->GetParentDevice(IID_GRAPHICS_PPV_ARGS(device.GetAddressOf()))); + ComPtr device; + GRAPHICS_THROW_IF_FAILED(dmlDevice->GetParentDevice(IID_GRAPHICS_PPV_ARGS(device.GetAddressOf()))); - m_impl = wil::MakeOrThrow(dmlDevice, device.Get(), executionContext, enableMetacommands, enableGraphCapture, enableSyncSpinning, disableMemoryArena); + m_impl = wil::MakeOrThrow(dmlDevice, device.Get(), executionContext, enableMetacommands, enableGraphCapture, enableSyncSpinning, disableMemoryArena); } std::vector> diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h index c20969250f..32a5b9add3 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h @@ -242,8 +242,8 @@ namespace Dml bool CanCopy(const OrtDevice& srcDevice, const OrtDevice& dstDevice) const final { - return (srcDevice.Type() == OrtDevice::GPU) || - (dstDevice.Type() == OrtDevice::GPU); + return (srcDevice.Type() == OrtDevice::DML) || + (dstDevice.Type() == OrtDevice::DML); } private: diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 4be107758d..42a2b4ef3e 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1662,6 +1662,23 @@ static void ResolveMemoryPatternFlags(SessionState& session_state) { } } } + +// This function is called when the session is being initialized. +// For now, this function only checks for invalid combination of DML EP with other EPs. +// TODO: extend this function to check for other invalid combinations of EPs. +common::Status InferenceSession::HasInvalidCombinationOfExecutionProviders() const { + // DML EP is only allowed with CPU EP + bool has_dml_ep = execution_providers_.Get(kDmlExecutionProvider) != nullptr; + if (has_dml_ep) { + const auto& ep_list = execution_providers_.GetIds(); + for (const auto& ep : ep_list) { + if (ep == kDmlExecutionProvider || ep == kCpuExecutionProvider) continue; + return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "DML EP can be used with only CPU EP."); + } + } + return Status::OK(); +} + #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(push) // VC++ reports: "Releasing unheld lock 'l' in function 'onnxruntime::InferenceSession::Initialize'". But I don't see anything wrong. @@ -1719,6 +1736,11 @@ common::Status InferenceSession::Initialize() { execution_providers_.SetCpuProviderWasImplicitlyAdded(true); } + // Check for the presence of an invalid combination of execution providers in the session + // For e.g. we don't support DML EP and other GPU EPs to be present in the same session + // This check is placed here because it serves as a common place for all language bindings. + ORT_RETURN_IF_ERROR_SESSIONID_(HasInvalidCombinationOfExecutionProviders()); + // re-acquire mutex std::lock_guard l(session_mutex_); diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 514a478e3f..0675f64848 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -620,7 +620,7 @@ class InferenceSession { const Environment& session_env); void ConstructorCommon(const SessionOptions& session_options, const Environment& session_env); - + [[nodiscard]] common::Status HasInvalidCombinationOfExecutionProviders() const; [[nodiscard]] common::Status SaveModelMetadata(const onnxruntime::Model& model); #if !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc index ebb1a54fac..74bd20461e 100644 --- a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc @@ -291,7 +291,7 @@ void DmlToCpuMemCpy(void* dst, const void* src, size_t num_bytes) { const std::unordered_map* GetDmlToHostMemCpyFunction() { static std::unordered_map map{ - {OrtDevice::GPU, DmlToCpuMemCpy}}; + {OrtDevice::DML, DmlToCpuMemCpy}}; return ↦ } diff --git a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc index 18785cd607..6a57fc5f90 100644 --- a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc @@ -96,16 +96,22 @@ void addOrtValueMethods(pybind11::module& m) { // Likewise, there is no need to specify the name (as the name was previously used to lookup the def list) // TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in CUDA CreateGenericMLValue(nullptr, GetRocmAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToRocmMemCpy); -#elif USE_DML - // InputDeflist is null because OrtValue creation is not tied to a specific model - // Likewise, there is no need to specify the name (as the name was previously used to lookup the def list) - // TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in DML - CreateGenericMLValue( - nullptr, GetDmlAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToDmlMemCpy); #else - throw std::runtime_error( - "Can't allocate memory on the CUDA device using this package of OnnxRuntime. " - "Please use the CUDA package of OnnxRuntime to use this feature."); + throw std::runtime_error( + "Can't allocate memory on the CUDA device using this package of OnnxRuntime. " + "Please use the CUDA package of OnnxRuntime to use this feature."); +#endif + } else if (device.Type() == OrtDevice::DML) { +#if USE_DML + // InputDeflist is null because OrtValue creation is not tied to a specific model + // Likewise, there is no need to specify the name (as the name was previously used to lookup the def list) + // TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in DML + CreateGenericMLValue( + nullptr, GetDmlAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToDmlMemCpy); +#else + throw std::runtime_error( + "Can't allocate memory on the CUDA device using this package of OnnxRuntime. " + "Please use the CUDA package of OnnxRuntime to use this feature."); #endif } else if (device.Type() == OrtDevice::NPU) { #ifdef USE_CANN @@ -116,9 +122,9 @@ void addOrtValueMethods(pybind11::module& m) { CreateGenericMLValue(nullptr, GetCannAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToCannMemCpy); #else - throw std::runtime_error( - "Can't allocate memory on the CANN device using this package of OnnxRuntime. " - "Please use the CANN package of OnnxRuntime to use this feature."); + throw std::runtime_error( + "Can't allocate memory on the CANN device using this package of OnnxRuntime. " + "Please use the CANN package of OnnxRuntime to use this feature."); #endif } else { throw std::runtime_error("Unsupported device: Cannot place the OrtValue on this device"); @@ -160,19 +166,24 @@ void addOrtValueMethods(pybind11::module& m) { } onnxruntime::python::CopyDataToTensor( - py_values, - values_type, - *(ml_value->GetMutable()), - CpuToRocmMemCpy); -#elif USE_DML - onnxruntime::python::CopyDataToTensor( - py_values, - values_type, - *(ml_value->GetMutable()), - CpuToDmlMemCpy); + py_values, + values_type, + *(ml_value->GetMutable()), + CpuToRocmMemCpy); #else - throw std::runtime_error( - "Unsupported GPU device: Cannot find the supported GPU device."); + throw std::runtime_error( + "Unsupported GPU device: Cannot find the supported GPU device."); +#endif + } else if (device.Type() == OrtDevice::DML) { +#if USE_DML + onnxruntime::python::CopyDataToTensor( + py_values, + values_type, + *(ml_value->GetMutable()), + CpuToDmlMemCpy); +#else + throw std::runtime_error( + "Unsupported GPU device: Cannot find the supported GPU device."); #endif } else { throw std::runtime_error("Unsupported device: Cannot update the OrtValue on this device"); diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index b8c8293746..4d9583be0e 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -288,11 +288,9 @@ const char* GetDeviceName(const OrtDevice& device) { case OrtDevice::CPU: return CPU; case OrtDevice::GPU: -#ifdef USE_DML - return DML; -#else return CUDA; -#endif + case OrtDevice::DML: + return DML; case OrtDevice::FPGA: return "FPGA"; case OrtDevice::NPU: @@ -1579,7 +1577,7 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra .def_static("cann", []() { return OrtDevice::NPU; }) .def_static("fpga", []() { return OrtDevice::FPGA; }) .def_static("npu", []() { return OrtDevice::NPU; }) - .def_static("dml", []() { return OrtDevice::GPU; }) + .def_static("dml", []() { return OrtDevice::DML; }) .def_static("webgpu", []() { return OrtDevice::GPU; }) .def_static("default_memory", []() { return OrtDevice::MemType::DEFAULT; });