mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-26 03:00:54 +00:00
For HLSL shader ops in the DirectML EP (STFT, DFT), FP16 ops should fall back to CPU when there is no hardware support (#15448)
CP: [For HLSL shader ops in the DirectML EP (STFT, DFT), FP16 ops should fall back to CPU when there is no hardware support #15414](https://github.com/microsoft/onnxruntime/pull/15414). For HLSL shader ops in the DirectML EP (STFT, DFT), FP16 ops should fall back to CPU when there is no hardware support.
This commit is contained in:
parent
6657df9212
commit
ce9ad8c8bc
4 changed files with 46 additions and 2 deletions
|
|
@ -176,6 +176,15 @@ namespace Dml
|
|||
sizeof(featureLevels)
|
||||
));
|
||||
|
||||
D3D12_FEATURE_DATA_D3D12_OPTIONS4 featureOptions = {};
|
||||
if (SUCCEEDED(d3d12Device->CheckFeatureSupport(
|
||||
D3D12_FEATURE_D3D12_OPTIONS4,
|
||||
&featureOptions,
|
||||
sizeof(featureOptions))))
|
||||
{
|
||||
m_native16BitShaderOpsSupported = featureOptions.Native16BitShaderOpsSupported;
|
||||
}
|
||||
|
||||
m_isMcdmDevice = (featureLevels.MaxSupportedFeatureLevel == D3D_FEATURE_LEVEL_1_0_CORE_PRIVATE);
|
||||
|
||||
m_context = std::make_shared<ExecutionContext>(m_d3d12Device.Get(), m_dmlDevice.Get(), queue);
|
||||
|
|
@ -650,10 +659,28 @@ namespace Dml
|
|||
return false;
|
||||
}
|
||||
|
||||
// Returns true when the node's operator type is one of the DirectML EP operators
// listed below (DFT, STFT), i.e. the HLSL shader ops that need an FP16 fallback
// when the device lacks native 16-bit shader support.
//
// NOTE: When one of these operators' implementation is moved into DML, remove it
// from the table below so that its FP16 fallback is no longer applied.
//
// node: the graph node whose OpType() is compared against the table.
// Returns: true on an exact (case-sensitive) match with any entry, false otherwise.
bool IsCustomOpShader(const onnxruntime::Node& node)
{
    // Elements are pointers-to-const: string literals are arrays of const char, and
    // binding them to a plain char* is ill-formed in standard C++ (C++11 and later).
    static constexpr std::array<const char*, 2> customOps = {
        "DFT",
        "STFT"
    };

    // Bind once so the operator type is not re-queried on every iteration.
    const auto& opType = node.OpType();

    for (const char* customOp : customOps)
    {
        if (strcmp(customOp, opType.c_str()) == 0)
        {
            return true;
        }
    }

    return false;
}
|
||||
|
||||
bool DoesNodeContainSupportedDataTypes(
|
||||
const onnxruntime::Node& node,
|
||||
_In_opt_ const InternalRegistrationInfo* regInfo,
|
||||
uint32_t supportedDeviceDataTypeMask // Each bit corresponds to each DML_TENSOR_DATA_TYPE.
|
||||
uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE.
|
||||
bool native16BitShaderOpsSupported
|
||||
)
|
||||
{
|
||||
std::vector<onnxruntime::NodeArg const*> constantCpuInputs;
|
||||
|
|
@ -698,6 +725,14 @@ namespace Dml
|
|||
return;
|
||||
}
|
||||
|
||||
if (onnxElementType == MLOperatorTensorDataType::Float16 &&
|
||||
!native16BitShaderOpsSupported &&
|
||||
IsCustomOpShader(node))
|
||||
{
|
||||
nodeContainsSupportedDataTypes = false;
|
||||
return;
|
||||
}
|
||||
|
||||
// Allow nodeArgs that are SequenceTensor when they are actually implemented by CPU Kernels.
|
||||
if (edgeType == MLOperatorEdgeType::SequenceTensor)
|
||||
{
|
||||
|
|
@ -768,7 +803,7 @@ namespace Dml
|
|||
}
|
||||
|
||||
// Check whether the node uses any data types which are unsupported by the device.
|
||||
if (!DoesNodeContainSupportedDataTypes(node, internalRegInfo.get(), supportedDeviceDataTypeMask))
|
||||
if (!DoesNodeContainSupportedDataTypes(node, internalRegInfo.get(), supportedDeviceDataTypeMask, m_native16BitShaderOpsSupported))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -184,6 +184,7 @@ namespace Dml
|
|||
ComPtr<IDMLDevice> m_dmlDevice;
|
||||
bool m_isMcdmDevice = false;
|
||||
bool m_areMetacommandsEnabled = true;
|
||||
bool m_native16BitShaderOpsSupported = false;
|
||||
std::shared_ptr<ExecutionContext> m_context;
|
||||
std::unique_ptr<PooledUploadHeap> m_uploadHeap;
|
||||
std::unique_ptr<ReadbackHeap> m_readbackHeap;
|
||||
|
|
|
|||
|
|
@ -5,6 +5,10 @@
|
|||
|
||||
#include "../External/D3DX12/d3dx12.h"
|
||||
|
||||
// NOTE: When this operator's implementation is moved into DML, the associated FP16 fallback
|
||||
// should be removed from IsCustomOpShader(...) in
|
||||
// onnxruntime\core\providers\dml\DmlExecutionProvider\src\ExecutionProvider.cpp
|
||||
|
||||
// The shader headers are produced using "GeneratedShaders/GenerateShaders.bat"
|
||||
namespace DFTFloat32
|
||||
{
|
||||
|
|
|
|||
|
|
@ -2,6 +2,10 @@
|
|||
|
||||
#include "DmlDFT.h"
|
||||
|
||||
// NOTE: When this operator's implementation is moved into DML, the associated FP16 fallback
|
||||
// should be removed from IsCustomOpShader(...) in
|
||||
// onnxruntime\core\providers\dml\DmlExecutionProvider\src\ExecutionProvider.cpp
|
||||
|
||||
enum DmlSTFTKernelInputIndex : uint32_t
|
||||
{
|
||||
Signal,
|
||||
|
|
|
|||
Loading…
Reference in a new issue