diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp index a7c2fcf178..d436877e40 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp @@ -217,8 +217,7 @@ namespace DmlGraphFusionHelper { ComPtr initializeInputBuffer; - // D3D_FEATURE_LEVEL_1_0_CORE doesn't support Custom heaps - if (providerImpl->IsMcdmDevice()) + if (!providerImpl->CustomHeapsSupported()) { initializeInputBuffer = CreateResource(providerImpl, tensorPtr, tensorByteSize); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index f97b72aa2d..36f24f2200 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -181,6 +181,31 @@ namespace Dml } m_isMcdmDevice = (featureLevels.MaxSupportedFeatureLevel == D3D_FEATURE_LEVEL_1_0_CORE_PRIVATE); + m_areCustomHeapsSupported = !m_isMcdmDevice; + + if (m_isMcdmDevice) { + + // TODO: Ingest updated header file + typedef struct D3D12_FEATURE_DATA_D3D12_OPTIONS19 + { + BOOL MismatchingOutputDimensionsSupported; + UINT SupportedSampleCountsWithNoOutputs; + BOOL PointSamplingAddressesNeverRoundUp; + BOOL RasterizerDesc2Supported; + BOOL NarrowQuadrilateralLinesSupported; + BOOL AnisoFilterWithPointMipSupported; + UINT MaxSamplerDescriptorHeapSize; + UINT MaxSamplerDescriptorHeapSizeWithStaticSamplers; + UINT MaxViewDescriptorHeapSize; + _Out_ BOOL ComputeOnlyCustomHeapSupported; + } D3D12_FEATURE_DATA_D3D12_OPTIONS19; + + D3D12_FEATURE_DATA_D3D12_OPTIONS19 options19 = {}; + + // The call may fail in which case the default value is false + d3d12Device->CheckFeatureSupport((D3D12_FEATURE) 48 /*D3D12_FEATURE_D3D12_OPTIONS19*/, &options19, sizeof(options19)); + m_areCustomHeapsSupported = options19.ComputeOnlyCustomHeapSupported; + } m_context = std::make_shared(m_d3d12Device.Get(), m_dmlDevice.Get(), queue); @@ -1088,6 +1113,11 @@ namespace Dml return m_isMcdmDevice; } + bool __stdcall ExecutionProviderImpl::CustomHeapsSupported() const noexcept + { + return m_areCustomHeapsSupported; + } + bool __stdcall ExecutionProviderImpl::MetacommandsEnabled() const noexcept { return m_areMetacommandsEnabled; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h index 31b893a2f2..09e486d27f 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h @@ -148,6 +148,7 @@ namespace Dml } STDMETHOD_(bool, IsMcdmDevice)() const noexcept final; + STDMETHOD_(bool, CustomHeapsSupported)() const noexcept final; STDMETHOD_(bool, MetacommandsEnabled)() const noexcept final; std::shared_ptr GetGpuAllocator(); @@ -183,6 +184,7 @@ namespace Dml ComPtr m_d3d12Device; ComPtr m_dmlDevice; bool m_isMcdmDevice = false; + bool m_areCustomHeapsSupported = false; bool m_areMetacommandsEnabled = true; bool m_native16BitShaderOpsSupported = false; std::shared_ptr m_context; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h index d7a0a607cd..332a574861 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h @@ -69,6 +69,7 @@ namespace Dml STDMETHOD(AllocatePooledResource(size_t size, AllocatorRoundingMode roundingMode, ID3D12Resource **d3dResource, IUnknown* *pooledResource)) const noexcept = 0; STDMETHOD_(bool, IsMcdmDevice)() const noexcept = 0; + STDMETHOD_(bool, CustomHeapsSupported)() const noexcept = 0; STDMETHOD_(bool, MetacommandsEnabled)() const noexcept = 0; }; } // namespace Dml