For HLSL shader ops in the DirectML EP (STFT,DFT) FP16 ops should fal… (#15448)

CP: [For HLSL shader ops in the DirectML EP (STFT, DFT), FP16 ops should
fall back to CPU when there is no hardware support #15414](https://github.com/microsoft/onnxruntime/pull/15414)

For HLSL shader ops in the DirectML EP (STFT, DFT), FP16 ops should
fall back to CPU when there is no hardware support.
This commit is contained in:
Sheil Kumar 2023-04-10 13:21:40 -07:00 committed by GitHub
parent 6657df9212
commit ce9ad8c8bc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 46 additions and 2 deletions

View file

@ -176,6 +176,15 @@ namespace Dml
sizeof(featureLevels)
));
D3D12_FEATURE_DATA_D3D12_OPTIONS4 featureOptions = {};
if (SUCCEEDED(d3d12Device->CheckFeatureSupport(
D3D12_FEATURE_D3D12_OPTIONS4,
&featureOptions,
sizeof(featureOptions))))
{
m_native16BitShaderOpsSupported = featureOptions.Native16BitShaderOpsSupported;
}
m_isMcdmDevice = (featureLevels.MaxSupportedFeatureLevel == D3D_FEATURE_LEVEL_1_0_CORE_PRIVATE);
m_context = std::make_shared<ExecutionContext>(m_d3d12Device.Get(), m_dmlDevice.Get(), queue);
@ -650,10 +659,28 @@ namespace Dml
return false;
}
// Returns true when the node's operator type is one of the DirectML EP
// operators implemented as custom HLSL shaders ("DFT" or "STFT").
//
// The result is used together with the device's native 16-bit shader support
// to decide whether a Float16 node must be rejected by this provider.
//
// NOTE: When one of these operators is moved into DML itself, its entry
// should be removed from the table below.
bool IsCustomOpShader(const onnxruntime::Node& node)
{
    // String literals are immutable, so the element type must be const char*;
    // binding them to plain char* is ill-formed in standard C++ (C++11 and later).
    static constexpr std::array<const char*, 2> customShaderOps = {
        "DFT",
        "STFT",
    };

    const auto& opType = node.OpType();
    for (const char* customShaderOp : customShaderOps)
    {
        // Exact, case-sensitive match against the operator type name.
        if (opType == customShaderOp)
        {
            return true;
        }
    }

    return false;
}
bool DoesNodeContainSupportedDataTypes(
const onnxruntime::Node& node,
_In_opt_ const InternalRegistrationInfo* regInfo,
uint32_t supportedDeviceDataTypeMask // Each bit corresponds to each DML_TENSOR_DATA_TYPE.
uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE.
bool native16BitShaderOpsSupported
)
{
std::vector<onnxruntime::NodeArg const*> constantCpuInputs;
@ -698,6 +725,14 @@ namespace Dml
return;
}
if (onnxElementType == MLOperatorTensorDataType::Float16 &&
!native16BitShaderOpsSupported &&
IsCustomOpShader(node))
{
nodeContainsSupportedDataTypes = false;
return;
}
// Allow nodeArgs that are SequenceTensor when they are actually implemented by CPU Kernels.
if (edgeType == MLOperatorEdgeType::SequenceTensor)
{
@ -768,7 +803,7 @@ namespace Dml
}
// Check whether the node uses any data types which are unsupported by the device.
if (!DoesNodeContainSupportedDataTypes(node, internalRegInfo.get(), supportedDeviceDataTypeMask))
if (!DoesNodeContainSupportedDataTypes(node, internalRegInfo.get(), supportedDeviceDataTypeMask, m_native16BitShaderOpsSupported))
{
return false;
}

View file

@ -184,6 +184,7 @@ namespace Dml
ComPtr<IDMLDevice> m_dmlDevice;
bool m_isMcdmDevice = false;
bool m_areMetacommandsEnabled = true;
bool m_native16BitShaderOpsSupported = false;
std::shared_ptr<ExecutionContext> m_context;
std::unique_ptr<PooledUploadHeap> m_uploadHeap;
std::unique_ptr<ReadbackHeap> m_readbackHeap;

View file

@ -5,6 +5,10 @@
#include "../External/D3DX12/d3dx12.h"
// NOTE: When this operator's implementation is moved into DML, the associated FP16 fallback
// should be removed from IsCustomOpShader(...) in
// onnxruntime\core\providers\dml\DmlExecutionProvider\src\ExecutionProvider.cpp
// The shader headers are produced using "GeneratedShaders/GenerateShaders.bat"
namespace DFTFloat32
{

View file

@ -2,6 +2,10 @@
#include "DmlDFT.h"
// NOTE: When this operator's implementation is moved into DML, the associated FP16 fallback
// should be removed from IsCustomOpShader(...) in
// onnxruntime\core\providers\dml\DmlExecutionProvider\src\ExecutionProvider.cpp
enum DmlSTFTKernelInputIndex : uint32_t
{
Signal,