From ce9ad8c8bccf25d89dbd8983ed3d29609f3d6607 Mon Sep 17 00:00:00 2001 From: Sheil Kumar Date: Mon, 10 Apr 2023 13:21:40 -0700 Subject: [PATCH] =?UTF-8?q?For=20HLSL=20shader=20ops=20in=20the=20DirectML?= =?UTF-8?q?=20EP=20(STFT,DFT)=20FP16=20ops=20should=20fal=E2=80=A6=20(#154?= =?UTF-8?q?48)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit CP: [For HLSL shader ops in the DirectML EP (STFT,DFT) FP16 ops should fallback to CPU when there is no hardware support #15414 ](https://github.com/microsoft/onnxruntime/pull/15414) For HLSL shader ops in the DirectML EP (STFT,DFT) FP16 ops should fallback to CPU when there is no hardware support. --- .../src/ExecutionProvider.cpp | 39 ++++++++++++++++++- .../src/ExecutionProvider.h | 1 + .../src/Operators/DmlDFT.h | 4 ++ .../src/Operators/DmlSTFT.h | 4 ++ 4 files changed, 46 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index 71a409ea5c..90ecc7f87f 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -176,6 +176,15 @@ namespace Dml sizeof(featureLevels) )); + D3D12_FEATURE_DATA_D3D12_OPTIONS4 featureOptions = {}; + if (SUCCEEDED(d3d12Device->CheckFeatureSupport( + D3D12_FEATURE_D3D12_OPTIONS4, + &featureOptions, + sizeof(featureOptions)))) + { + m_native16BitShaderOpsSupported = featureOptions.Native16BitShaderOpsSupported; + } + m_isMcdmDevice = (featureLevels.MaxSupportedFeatureLevel == D3D_FEATURE_LEVEL_1_0_CORE_PRIVATE); m_context = std::make_shared(m_d3d12Device.Get(), m_dmlDevice.Get(), queue); @@ -650,10 +659,28 @@ namespace Dml return false; } + bool IsCustomOpShader(const onnxruntime::Node& node) + { + auto custom_ops = std::array{ + "DFT", + "STFT" + }; + + for (auto& custom_op : custom_ops) + { + if (strcmp(custom_op, node.OpType().c_str()) == 0) + { + return true; + } + } + return false; + } + bool DoesNodeContainSupportedDataTypes( const onnxruntime::Node& node, _In_opt_ const InternalRegistrationInfo* regInfo, - uint32_t supportedDeviceDataTypeMask // Each bit corresponds to each DML_TENSOR_DATA_TYPE. + uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE. + bool native16BitShaderOpsSupported ) { std::vector constantCpuInputs; @@ -698,6 +725,14 @@ namespace Dml return; } + if (onnxElementType == MLOperatorTensorDataType::Float16 && + !native16BitShaderOpsSupported && + IsCustomOpShader(node)) + { + nodeContainsSupportedDataTypes = false; + return; + } + // Allow nodeArgs that are SequenceTensor when they are actually implemented by CPU Kernels. if (edgeType == MLOperatorEdgeType::SequenceTensor) { @@ -768,7 +803,7 @@ namespace Dml } // Check whether the node uses any data types which are unsupported by the device. - if (!DoesNodeContainSupportedDataTypes(node, internalRegInfo.get(), supportedDeviceDataTypeMask)) + if (!DoesNodeContainSupportedDataTypes(node, internalRegInfo.get(), supportedDeviceDataTypeMask, m_native16BitShaderOpsSupported)) { return false; } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h index 7078382894..b9ac772095 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h @@ -184,6 +184,7 @@ namespace Dml ComPtr m_dmlDevice; bool m_isMcdmDevice = false; bool m_areMetacommandsEnabled = true; + bool m_native16BitShaderOpsSupported = false; std::shared_ptr m_context; std::unique_ptr m_uploadHeap; std::unique_ptr m_readbackHeap; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlDFT.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlDFT.h index 636e17982d..77965a9d7c 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlDFT.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlDFT.h @@ -5,6 +5,10 @@ #include "../External/D3DX12/d3dx12.h" +// NOTE: When this operator's implementation is moved into DML, the associated FP16 fallback +// should be removed from IsCustomOpShader(...) in +// onnxruntime\core\providers\dml\DmlExecutionProvider\src\ExecutionProvider.cpp + // The shader headers are produced using "GeneratedShaders/GenerateShaders.bat" namespace DFTFloat32 { diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlSTFT.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlSTFT.h index dea5545fa1..cca4911028 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlSTFT.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlSTFT.h @@ -2,6 +2,10 @@ #include "DmlDFT.h" +// NOTE: When this operator's implementation is moved into DML, the associated FP16 fallback +// should be removed from IsCustomOpShader(...) in +// onnxruntime\core\providers\dml\DmlExecutionProvider\src\ExecutionProvider.cpp + enum DmlSTFTKernelInputIndex : uint32_t { Signal,