mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Add API for NPU Device Selection in the DML EP (#17612)
Co-authored-by: Sheil Kumar <sheilk@microsoft.com>
This commit is contained in:
parent
a441a71e8e
commit
b8f373b0ae
8 changed files with 390 additions and 27 deletions
|
|
@ -202,4 +202,4 @@ endif()
|
|||
|
||||
if (onnxruntime_USE_AZURE)
|
||||
include(onnxruntime_providers_azure.cmake)
|
||||
endif()
|
||||
endif()
|
||||
|
|
|
|||
|
|
@ -56,13 +56,13 @@
|
|||
if (GDK_PLATFORM STREQUAL Scarlett)
|
||||
target_link_libraries(onnxruntime_providers_dml PRIVATE ${gdk_dx_libs})
|
||||
else()
|
||||
target_link_libraries(onnxruntime_providers_dml PRIVATE dxguid.lib d3d12.lib dxgi.lib)
|
||||
target_link_libraries(onnxruntime_providers_dml PRIVATE dxguid.lib d3d12.lib dxgi.lib dxcore.lib)
|
||||
endif()
|
||||
|
||||
target_link_libraries(onnxruntime_providers_dml PRIVATE delayimp.lib)
|
||||
|
||||
if (NOT GDK_PLATFORM)
|
||||
set(onnxruntime_DELAYLOAD_FLAGS "${onnxruntime_DELAYLOAD_FLAGS} /DELAYLOAD:DirectML.dll /DELAYLOAD:d3d12.dll /DELAYLOAD:dxgi.dll /DELAYLOAD:api-ms-win-core-com-l1-1-0.dll /DELAYLOAD:shlwapi.dll /DELAYLOAD:oleaut32.dll /ignore:4199")
|
||||
set(onnxruntime_DELAYLOAD_FLAGS "${onnxruntime_DELAYLOAD_FLAGS} /DELAYLOAD:DirectML.dll /DELAYLOAD:d3d12.dll /DELAYLOAD:dxgi.dll /DELAYLOAD:api-ms-win-core-com-l1-1-0.dll /DELAYLOAD:shlwapi.dll /DELAYLOAD:oleaut32.dll /DELAYLOAD:ext-ms-win-dxcore-l1-*.dll /ignore:4199")
|
||||
endif()
|
||||
|
||||
target_compile_definitions(onnxruntime_providers_dml
|
||||
|
|
@ -88,4 +88,4 @@
|
|||
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
|
||||
FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR})
|
||||
endif()
|
||||
endif()
|
||||
|
|
|
|||
|
|
@ -30,6 +30,31 @@ typedef struct IDMLDevice IDMLDevice;
|
|||
extern "C" {
|
||||
#endif
|
||||
|
||||
enum OrtDmlPerformancePreference {
|
||||
Default = 0,
|
||||
HighPerformance = 1,
|
||||
MinimumPower = 2
|
||||
};
|
||||
|
||||
enum OrtDmlDeviceFilter : uint32_t {
|
||||
Any = 0xffffffff,
|
||||
Gpu = 1 << 0,
|
||||
Npu = 1 << 1,
|
||||
};
|
||||
|
||||
inline OrtDmlDeviceFilter operator~(OrtDmlDeviceFilter a) { return (OrtDmlDeviceFilter) ~(int)a; }
|
||||
inline OrtDmlDeviceFilter operator|(OrtDmlDeviceFilter a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter)((int)a | (int)b); }
|
||||
inline OrtDmlDeviceFilter operator&(OrtDmlDeviceFilter a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter)((int)a & (int)b); }
|
||||
inline OrtDmlDeviceFilter operator^(OrtDmlDeviceFilter a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter)((int)a ^ (int)b); }
|
||||
inline OrtDmlDeviceFilter& operator|=(OrtDmlDeviceFilter& a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter&)((int&)a |= (int)b); }
|
||||
inline OrtDmlDeviceFilter& operator&=(OrtDmlDeviceFilter& a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter&)((int&)a &= (int)b); }
|
||||
inline OrtDmlDeviceFilter& operator^=(OrtDmlDeviceFilter& a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter&)((int&)a ^= (int)b); }
|
||||
|
||||
struct OrtDmlDeviceOptions {
|
||||
OrtDmlPerformancePreference Preference;
|
||||
OrtDmlDeviceFilter Filter;
|
||||
};
|
||||
|
||||
/**
|
||||
* [[deprecated]]
|
||||
* This export is deprecated.
|
||||
|
|
@ -99,6 +124,13 @@ struct OrtDmlApi {
|
|||
* This API gets the D3D12 resource when an OrtValue has been allocated by the DML EP.
|
||||
*/
|
||||
ORT_API2_STATUS(GetD3D12ResourceFromAllocation, _In_ OrtAllocator* provider, _In_ void* dml_resource, _Out_ ID3D12Resource** d3d_resource);
|
||||
|
||||
/**
|
||||
* SessionOptionsAppendExecutionProvider_DML2
|
||||
* Creates a DirectML Execution Provider given the supplied device options that contain a performance preference
|
||||
* (high power, low power, or defult) and a device filter (None, GPU, or NPU).
|
||||
*/
|
||||
ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_DML2, _In_ OrtSessionOptions* options, OrtDmlDeviceOptions* device_opts);
|
||||
};
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
|
|
|||
|
|
@ -1,6 +1,9 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <dxcore.h>
|
||||
#include <vector>
|
||||
|
||||
#include <DirectML.h>
|
||||
#ifndef _GAMING_XBOX
|
||||
#include <dxgi1_4.h>
|
||||
|
|
@ -92,12 +95,298 @@ bool IsSoftwareAdapter(IDXGIAdapter1* adapter) {
|
|||
return isSoftwareAdapter || (isBasicRenderDriverVendorId && isBasicRenderDriverDeviceId);
|
||||
}
|
||||
|
||||
static bool IsHardwareAdapter(IDXCoreAdapter* adapter) {
|
||||
bool is_hardware = false;
|
||||
THROW_IF_FAILED(adapter->GetProperty(
|
||||
DXCoreAdapterProperty::IsHardware,
|
||||
&is_hardware));
|
||||
return is_hardware;
|
||||
}
|
||||
|
||||
static bool IsGPU(IDXCoreAdapter* compute_adapter) {
|
||||
// Only considering hardware adapters
|
||||
if (!IsHardwareAdapter(compute_adapter)) {
|
||||
return false;
|
||||
}
|
||||
return compute_adapter->IsAttributeSupported(DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS);
|
||||
}
|
||||
|
||||
static bool IsNPU(IDXCoreAdapter* compute_adapter) {
|
||||
// Only considering hardware adapters
|
||||
if (!IsHardwareAdapter(compute_adapter)) {
|
||||
return false;
|
||||
}
|
||||
return !(compute_adapter->IsAttributeSupported(DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS));
|
||||
}
|
||||
|
||||
enum class DeviceType { GPU, NPU, BadDevice };
|
||||
|
||||
static DeviceType FilterAdapterTypeQuery(IDXCoreAdapter* adapter, OrtDmlDeviceFilter filter) {
|
||||
auto allow_gpus = (filter & OrtDmlDeviceFilter::Gpu) == OrtDmlDeviceFilter::Gpu;
|
||||
if (IsGPU(adapter) && allow_gpus) {
|
||||
return DeviceType::GPU;
|
||||
}
|
||||
|
||||
auto allow_npus = (filter & OrtDmlDeviceFilter::Npu) == OrtDmlDeviceFilter::Npu;
|
||||
if (IsNPU(adapter) && allow_npus) {
|
||||
return DeviceType::NPU;
|
||||
}
|
||||
|
||||
return DeviceType::BadDevice;
|
||||
}
|
||||
|
||||
// Struct for holding each adapter
|
||||
struct AdapterInfo {
|
||||
ComPtr<IDXCoreAdapter> Adapter;
|
||||
DeviceType Type; // GPU or NPU
|
||||
};
|
||||
|
||||
static ComPtr<IDXCoreAdapterList> EnumerateDXCoreAdapters(IDXCoreAdapterFactory* adapter_factory) {
|
||||
ComPtr<IDXCoreAdapterList> adapter_list;
|
||||
|
||||
// TODO: use_dxcore_workload_enumeration should be determined by QI
|
||||
// When DXCore APIs are available QI for relevant enumeration interfaces
|
||||
constexpr bool use_dxcore_workload_enumeration = false;
|
||||
if (!use_dxcore_workload_enumeration) {
|
||||
// Get a list of all the adapters that support compute
|
||||
GUID attributes[]{ DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE };
|
||||
ORT_THROW_IF_FAILED(
|
||||
adapter_factory->CreateAdapterList(_countof(attributes),
|
||||
attributes,
|
||||
adapter_list.GetAddressOf()));
|
||||
}
|
||||
|
||||
return adapter_list;
|
||||
}
|
||||
|
||||
static void SortDXCoreAdaptersByPreference(
|
||||
IDXCoreAdapterList* adapter_list,
|
||||
OrtDmlPerformancePreference preference) {
|
||||
if (adapter_list->GetAdapterCount() <= 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
// DML prefers the HighPerformance adapter by default
|
||||
std::array<DXCoreAdapterPreference, 1> adapter_list_preferences = {
|
||||
DXCoreAdapterPreference::HighPerformance
|
||||
};
|
||||
|
||||
// If callers specify minimum power change the DXCore sort policy
|
||||
// NOTE DXCoreAdapterPrefernce does not apply to mixed adapter lists - only to GPU lists
|
||||
if (preference == OrtDmlPerformancePreference::MinimumPower) {
|
||||
adapter_list_preferences[0] = DXCoreAdapterPreference::MinimumPower;
|
||||
}
|
||||
|
||||
ORT_THROW_IF_FAILED(adapter_list->Sort(
|
||||
static_cast<uint32_t>(adapter_list_preferences.size()),
|
||||
adapter_list_preferences.data()));
|
||||
}
|
||||
|
||||
static std::vector<AdapterInfo> FilterDXCoreAdapters(
|
||||
IDXCoreAdapterList* adapter_list,
|
||||
OrtDmlDeviceFilter filter) {
|
||||
auto adapter_infos = std::vector<AdapterInfo>();
|
||||
const uint32_t count = adapter_list->GetAdapterCount();
|
||||
for (uint32_t i = 0; i < count; ++i) {
|
||||
ComPtr<IDXCoreAdapter> candidate_adapter;
|
||||
ORT_THROW_IF_FAILED(adapter_list->GetAdapter(i, candidate_adapter.GetAddressOf()));
|
||||
|
||||
// Add the adapters that are valid based on the device filter (GPU, NPU, or Both)
|
||||
auto adapter_type = FilterAdapterTypeQuery(candidate_adapter.Get(), filter);
|
||||
if (adapter_type != DeviceType::BadDevice) {
|
||||
adapter_infos.push_back(AdapterInfo{candidate_adapter, adapter_type});
|
||||
}
|
||||
}
|
||||
|
||||
return adapter_infos;
|
||||
}
|
||||
|
||||
static void SortHeterogenousDXCoreAdapterList(
|
||||
std::vector<AdapterInfo>& adapter_infos,
|
||||
OrtDmlDeviceFilter filter,
|
||||
OrtDmlPerformancePreference preference) {
|
||||
if (adapter_infos.size() <= 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
// When considering both GPUs and NPUs sort them by performance preference
|
||||
// of Default (Gpus first), HighPerformance (GPUs first), or LowPower (NPUs first)
|
||||
auto keep_npus = (filter & OrtDmlDeviceFilter::Npu) == OrtDmlDeviceFilter::Npu;
|
||||
auto only_npus = filter == OrtDmlDeviceFilter::Npu;
|
||||
if (!keep_npus || only_npus) {
|
||||
return;
|
||||
}
|
||||
|
||||
struct SortingPolicy {
|
||||
// default is false because GPUs are considered higher priority in
|
||||
// a mixed adapter environment
|
||||
bool npus_first_ = false;
|
||||
|
||||
SortingPolicy(bool npus_first = false) : npus_first_(npus_first) { }
|
||||
|
||||
bool operator()(const AdapterInfo& a, const AdapterInfo& b) {
|
||||
return npus_first_ ? a.Type < b.Type : a.Type > b.Type;
|
||||
}
|
||||
};
|
||||
|
||||
auto npus_first = (preference == OrtDmlPerformancePreference::MinimumPower);
|
||||
auto policy = SortingPolicy(npus_first);
|
||||
std::sort(adapter_infos.begin(), adapter_infos.end(), policy);
|
||||
}
|
||||
|
||||
std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::Create(int device_id) {
|
||||
return Create(device_id, /*skip_software_device_check*/ false);
|
||||
}
|
||||
|
||||
Microsoft::WRL::ComPtr<ID3D12Device> DMLProviderFactoryCreator::CreateD3D12Device(int device_id, bool skip_software_device_check)
|
||||
{
|
||||
std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::CreateFromOptions(
|
||||
OrtDmlDeviceOptions* device_options) {
|
||||
auto default_device_options = OrtDmlDeviceOptions { Default, Gpu };
|
||||
if (device_options == nullptr) {
|
||||
device_options = &default_device_options;
|
||||
}
|
||||
|
||||
OrtDmlPerformancePreference preference = device_options->Preference;
|
||||
OrtDmlDeviceFilter filter = device_options->Filter;
|
||||
|
||||
// Create DXCore Adapter Factory
|
||||
ComPtr<IDXCoreAdapterFactory> adapter_factory;
|
||||
ORT_THROW_IF_FAILED(::DXCoreCreateAdapterFactory(adapter_factory.GetAddressOf()));
|
||||
|
||||
// Get all DML compatible DXCore adapters
|
||||
ComPtr<IDXCoreAdapterList> adapter_list;
|
||||
adapter_list = EnumerateDXCoreAdapters(adapter_factory.Get());
|
||||
|
||||
if (adapter_list->GetAdapterCount() == 0) {
|
||||
ORT_THROW("No GPUs or NPUs detected.");
|
||||
}
|
||||
|
||||
// Sort the adapter list to honor DXCore hardware ordering
|
||||
SortDXCoreAdaptersByPreference(adapter_list.Get(), preference);
|
||||
|
||||
// TODO: use_dxcore_workload_enumeration should be determined by QI
|
||||
// When DXCore APIs are available QI for relevant enumeration interfaces
|
||||
constexpr bool use_dxcore_workload_enumeration = false;
|
||||
|
||||
std::vector<AdapterInfo> adapter_infos;
|
||||
if (!use_dxcore_workload_enumeration) {
|
||||
// Filter all DXCore adapters to hardware type specified by the device filter
|
||||
adapter_infos = FilterDXCoreAdapters(adapter_list.Get(), filter);
|
||||
if (adapter_infos.size() == 0) {
|
||||
ORT_THROW("No devices detected that match the filter criteria.");
|
||||
}
|
||||
}
|
||||
|
||||
// DXCore Sort ignores NPUs. When both GPUs and NPUs are present, manually sort them.
|
||||
SortHeterogenousDXCoreAdapterList(adapter_infos, filter, preference);
|
||||
|
||||
// Extract just the adapters
|
||||
auto adapters = std::vector<ComPtr<IDXCoreAdapter>>(adapter_infos.size());
|
||||
std::transform(
|
||||
adapter_infos.begin(), adapter_infos.end(),
|
||||
adapters.begin(),
|
||||
[](auto& a){ return a.Adapter; });
|
||||
|
||||
return onnxruntime::DMLProviderFactoryCreator::CreateFromAdapterList(std::move(adapters));
|
||||
}
|
||||
|
||||
static std::optional<OrtDmlPerformancePreference> ParsePerformancePreference(const ProviderOptions& provider_options) {
|
||||
static const std::string PerformancePreference = "performance_preference";
|
||||
static const std::string Default = "default";
|
||||
static const std::string HighPerformance = "high_performance";
|
||||
static const std::string MinimumPower = "minimum_power";
|
||||
|
||||
auto preference_it = provider_options.find(PerformancePreference);
|
||||
if (preference_it != provider_options.end()) {
|
||||
if (preference_it->second == Default) {
|
||||
return OrtDmlPerformancePreference::Default;
|
||||
}
|
||||
|
||||
if (preference_it->second == HighPerformance) {
|
||||
return OrtDmlPerformancePreference::HighPerformance;
|
||||
}
|
||||
|
||||
if (preference_it->second == MinimumPower) {
|
||||
return OrtDmlPerformancePreference::MinimumPower;
|
||||
}
|
||||
|
||||
ORT_THROW("Invalid PerformancePreference provided for DirectML EP device selection.");
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
static std::optional<OrtDmlDeviceFilter> ParseFilter(const ProviderOptions& provider_options) {
|
||||
static const std::string Filter = "filter";
|
||||
static const std::string Any = "any";
|
||||
static const std::string Gpu = "gpu";
|
||||
static const std::string Npu = "npu";
|
||||
|
||||
auto preference_it = provider_options.find(Filter);
|
||||
if (preference_it != provider_options.end()) {
|
||||
if (preference_it->second == Any) {
|
||||
return OrtDmlDeviceFilter::Any;
|
||||
}
|
||||
|
||||
if (preference_it->second == Gpu) {
|
||||
return OrtDmlDeviceFilter::Gpu;
|
||||
}
|
||||
|
||||
if (preference_it->second == Npu) {
|
||||
return OrtDmlDeviceFilter::Npu;
|
||||
}
|
||||
|
||||
ORT_THROW("Invalid Filter provided for DirectML EP device selection.");
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
static std::optional<int> ParseDeviceId(const ProviderOptions& provider_options) {
|
||||
static const std::string DeviceId = "device_id";
|
||||
|
||||
auto preference_it = provider_options.find(DeviceId);
|
||||
if (preference_it != provider_options.end()) {
|
||||
if (!preference_it->second.empty()) {
|
||||
return std::stoi(preference_it->second);
|
||||
}
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::CreateFromProviderOptions(
|
||||
const ProviderOptions& provider_options) {
|
||||
auto device_id = ParseDeviceId(provider_options);
|
||||
if (device_id.has_value())
|
||||
{
|
||||
return onnxruntime::DMLProviderFactoryCreator::Create(device_id.value());
|
||||
}
|
||||
|
||||
auto preference = ParsePerformancePreference(provider_options);
|
||||
auto filter = ParseFilter(provider_options);
|
||||
|
||||
// If no preference/filters are specified then create with default preference/filters.
|
||||
if (!preference.has_value() && !filter.has_value()) {
|
||||
return onnxruntime::DMLProviderFactoryCreator::CreateFromOptions(nullptr);
|
||||
}
|
||||
|
||||
if (!preference.has_value()) {
|
||||
preference = OrtDmlPerformancePreference::Default;
|
||||
}
|
||||
|
||||
if (!filter.has_value()) {
|
||||
filter = OrtDmlDeviceFilter::Gpu;
|
||||
}
|
||||
|
||||
OrtDmlDeviceOptions device_options;
|
||||
device_options.Preference = preference.value();
|
||||
device_options.Filter = filter.value();
|
||||
return onnxruntime::DMLProviderFactoryCreator::CreateFromOptions(&device_options);
|
||||
}
|
||||
|
||||
Microsoft::WRL::ComPtr<ID3D12Device> DMLProviderFactoryCreator::CreateD3D12Device(
|
||||
int device_id,
|
||||
bool skip_software_device_check) {
|
||||
#ifdef _GAMING_XBOX
|
||||
ComPtr<ID3D12Device> d3d12_device;
|
||||
D3D12XBOX_CREATE_DEVICE_PARAMETERS params = {};
|
||||
|
|
@ -128,8 +417,7 @@ Microsoft::WRL::ComPtr<ID3D12Device> DMLProviderFactoryCreator::CreateD3D12Devic
|
|||
return d3d12_device;
|
||||
}
|
||||
|
||||
Microsoft::WRL::ComPtr<IDMLDevice> DMLProviderFactoryCreator::CreateDMLDevice(ID3D12Device* d3d12_device)
|
||||
{
|
||||
Microsoft::WRL::ComPtr<IDMLDevice> DMLProviderFactoryCreator::CreateDMLDevice(ID3D12Device* d3d12_device) {
|
||||
DML_CREATE_DEVICE_FLAGS flags = DML_CREATE_DEVICE_FLAG_NONE;
|
||||
|
||||
// In debug builds, enable the DML debug layer if the D3D12 debug layer is also enabled
|
||||
|
|
@ -153,9 +441,7 @@ Microsoft::WRL::ComPtr<IDMLDevice> DMLProviderFactoryCreator::CreateDMLDevice(ID
|
|||
return dml_device;
|
||||
}
|
||||
|
||||
std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::Create(int device_id, bool skip_software_device_check) {
|
||||
ComPtr<ID3D12Device> d3d12_device = CreateD3D12Device(device_id, skip_software_device_check);
|
||||
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateDMLDeviceAndProviderFactory(ID3D12Device* d3d12_device) {
|
||||
D3D12_COMMAND_QUEUE_DESC cmd_queue_desc = {};
|
||||
cmd_queue_desc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT;
|
||||
cmd_queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_DISABLE_GPU_TIMEOUT;
|
||||
|
|
@ -163,10 +449,27 @@ std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::Create(int
|
|||
ComPtr<ID3D12CommandQueue> cmd_queue;
|
||||
ORT_THROW_IF_FAILED(d3d12_device->CreateCommandQueue(&cmd_queue_desc, IID_GRAPHICS_PPV_ARGS(cmd_queue.ReleaseAndGetAddressOf())));
|
||||
|
||||
auto dml_device = CreateDMLDevice(d3d12_device.Get());
|
||||
auto dml_device = onnxruntime::DMLProviderFactoryCreator::CreateDMLDevice(d3d12_device);
|
||||
return CreateExecutionProviderFactory_DML(dml_device.Get(), cmd_queue.Get());
|
||||
}
|
||||
|
||||
std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::Create(int device_id, bool skip_software_device_check) {
|
||||
ComPtr<ID3D12Device> d3d12_device = CreateD3D12Device(device_id, skip_software_device_check);
|
||||
return CreateDMLDeviceAndProviderFactory(d3d12_device.Get());
|
||||
}
|
||||
|
||||
std::shared_ptr<IExecutionProviderFactory> DMLProviderFactoryCreator::CreateFromAdapterList(
|
||||
std::vector<ComPtr<IDXCoreAdapter>>&& dxcore_devices) {
|
||||
// Choose the first device from the list since it's the highest priority
|
||||
auto dxcore_device = dxcore_devices[0];
|
||||
|
||||
// Create D3D12 Device from DXCore Adapter
|
||||
ComPtr<ID3D12Device> d3d12_device;
|
||||
ORT_THROW_IF_FAILED(D3D12CreateDevice(dxcore_device.Get(), D3D_FEATURE_LEVEL_11_0, IID_GRAPHICS_PPV_ARGS(d3d12_device.ReleaseAndGetAddressOf())));
|
||||
|
||||
return CreateDMLDeviceAndProviderFactory(d3d12_device.Get());
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
||||
// [[deprecated]]
|
||||
|
|
@ -211,6 +514,17 @@ ORT_API_STATUS_IMPL(FreeGPUAllocation, _In_ void* ptr) {
|
|||
API_IMPL_END
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_DML2, _In_ OrtSessionOptions* options, OrtDmlDeviceOptions* device_options) {
|
||||
API_IMPL_BEGIN
|
||||
#ifdef USE_DML
|
||||
auto factory = onnxruntime::DMLProviderFactoryCreator::CreateFromOptions(device_options);
|
||||
// return the create function for a dxcore device
|
||||
options->provider_factories.push_back(factory);
|
||||
#endif // USE_DML
|
||||
return nullptr;
|
||||
API_IMPL_END
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(GetD3D12ResourceFromAllocation, _In_ OrtAllocator* ort_allocator, _In_ void* allocation, _Out_ ID3D12Resource** d3d_resource) {
|
||||
API_IMPL_BEGIN
|
||||
#ifdef USE_DML
|
||||
|
|
|
|||
|
|
@ -7,14 +7,26 @@
|
|||
|
||||
#include <wrl/client.h>
|
||||
#include <d3d12.h>
|
||||
#include "core/framework/provider_options.h"
|
||||
#include "core/providers/providers.h"
|
||||
#include "core/providers/dml/dml_provider_factory.h"
|
||||
|
||||
#include <dxcore.h>
|
||||
#include <vector>
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
struct DMLProviderFactoryCreator {
|
||||
static std::shared_ptr<IExecutionProviderFactory> Create(int device_id);
|
||||
static std::shared_ptr<IExecutionProviderFactory> Create(int device_id, bool skip_software_device_check);
|
||||
|
||||
static std::shared_ptr<IExecutionProviderFactory> CreateFromProviderOptions(
|
||||
const ProviderOptions& provider_options_map);
|
||||
static std::shared_ptr<IExecutionProviderFactory> CreateFromOptions(OrtDmlDeviceOptions* device_options);
|
||||
|
||||
static std::shared_ptr<IExecutionProviderFactory> CreateFromAdapterList(
|
||||
std::vector<Microsoft::WRL::ComPtr<IDXCoreAdapter>>&& dxcore_devices);
|
||||
|
||||
static Microsoft::WRL::ComPtr<ID3D12Device> CreateD3D12Device(int device_id, bool skip_software_device_check);
|
||||
static Microsoft::WRL::ComPtr<IDMLDevice> CreateDMLDevice(ID3D12Device* d3d12_device);
|
||||
};
|
||||
|
|
|
|||
|
|
@ -12,6 +12,10 @@
|
|||
#include "core/session/ort_apis.h"
|
||||
#include "core/providers/openvino/openvino_provider_factory_creator.h"
|
||||
|
||||
#if defined(USE_DML)
|
||||
#include "core/providers/dml/dml_provider_factory_creator.h"
|
||||
#endif
|
||||
|
||||
using namespace onnxruntime;
|
||||
|
||||
namespace {
|
||||
|
|
@ -67,7 +71,13 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider,
|
|||
(std::string(provider_name) + " execution provider is not supported in this build. ").c_str());
|
||||
};
|
||||
|
||||
if (strcmp(provider_name, "QNN") == 0) {
|
||||
if (strcmp(provider_name, "DML") == 0) {
|
||||
#if defined(USE_DML)
|
||||
options->provider_factories.push_back(DMLProviderFactoryCreator::CreateFromProviderOptions(provider_options));
|
||||
#else
|
||||
status = create_not_supported_status();
|
||||
#endif
|
||||
} else if (strcmp(provider_name, "QNN") == 0) {
|
||||
#if defined(USE_QNN)
|
||||
options->provider_factories.push_back(QNNProviderFactoryCreator::Create(provider_options, &(options->value)));
|
||||
#else
|
||||
|
|
|
|||
|
|
@ -879,18 +879,10 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
|
|||
#endif
|
||||
} else if (type == kDmlExecutionProvider) {
|
||||
#ifdef USE_DML
|
||||
int device_id = 0;
|
||||
auto it = provider_options_map.find(type);
|
||||
if (it != provider_options_map.end()) {
|
||||
for (auto option : it->second) {
|
||||
if (option.first == "device_id") {
|
||||
if (!option.second.empty()) {
|
||||
device_id = std::stoi(option.second);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return onnxruntime::DMLProviderFactoryCreator::Create(device_id)->CreateProvider();
|
||||
auto cit = provider_options_map.find(type);
|
||||
return onnxruntime::DMLProviderFactoryCreator::CreateFromProviderOptions(
|
||||
cit == provider_options_map.end() ? ProviderOptions{} : cit->second)
|
||||
->CreateProvider();
|
||||
#endif
|
||||
} else if (type == kNnapiExecutionProvider) {
|
||||
#if defined(USE_NNAPI)
|
||||
|
|
|
|||
|
|
@ -662,9 +662,12 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
|
|||
#endif
|
||||
} else if (provider_name == onnxruntime::kDmlExecutionProvider) {
|
||||
#ifdef USE_DML
|
||||
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_DML(session_options, 0));
|
||||
std::unordered_map<std::string, std::string> dml_options;
|
||||
dml_options["performance_preference"] = "high_performance";
|
||||
dml_options["device_filter"] = "gpu";
|
||||
session_options.AppendExecutionProvider("DML", dml_options);
|
||||
#else
|
||||
ORT_THROW("DirectML is not supported in this build\n");
|
||||
ORT_THROW("DML is not supported in this build\n");
|
||||
#endif
|
||||
} else if (provider_name == onnxruntime::kAclExecutionProvider) {
|
||||
#ifdef USE_ACL
|
||||
|
|
|
|||
Loading…
Reference in a new issue