From d5f3aae3fd30a79d9bbeed26a70907618ac84835 Mon Sep 17 00:00:00 2001 From: raoanag <127366241+raoanag@users.noreply.github.com> Date: Fri, 17 Nov 2023 16:43:09 -0800 Subject: [PATCH] Utilize DML constant input graph node (#18267) ### Description This PR also includes, 8b0a55e7cc DML constant pow operator 7520974970 Enable custom heaps based on query- ### Motivation and Context --------- Co-authored-by: Jeff Bloomfield --- .../src/DmlGraphFusionHelper.cpp | 27 ++++- .../src/ExecutionProvider.cpp | 31 ++++++ .../src/ExecutionProvider.h | 2 + .../src/GraphDescBuilder.cpp | 104 ++++++++++++++---- .../src/GraphDescBuilder.h | 5 +- .../src/IExecutionProvider.h | 1 + .../src/MLOperatorAuthorImpl.cpp | 29 ++++- .../src/MLOperatorAuthorImpl.h | 7 ++ .../src/Operators/DmlOperatorElementWise.cpp | 38 +++++-- .../MLOperatorAuthorHelper.h | 13 +++ .../MLOperatorAuthorPrivate.h | 10 ++ 11 files changed, 229 insertions(+), 38 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp index 4f7ec18814..18cdc5d1bf 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp @@ -226,8 +226,7 @@ namespace DmlGraphFusionHelper { ComPtr initializeInputBuffer; - // D3D_FEATURE_LEVEL_1_0_CORE doesn't support Custom heaps - if (providerImpl->IsMcdmDevice()) + if (!providerImpl->CustomHeapsSupported()) { initializeInputBuffer = CreateResource(providerImpl, tensorPtr, tensorByteSize); } @@ -294,6 +293,7 @@ namespace DmlGraphFusionHelper const uint32_t inputCount, const uint32_t outputCount, _Inout_ std::vector& dmlOperatorGraphNodes, + _Inout_ std::vector& dmlConstantGraphNodes, _Inout_ std::vector& dmlGraphNodes, _Inout_ std::vector& dmlInputEdges, _Inout_ std::vector& dmlOutputEdges, @@ -302,8 +302,24 @@ namespace DmlGraphFusionHelper for (size_t i = 0; i < graphDesc.nodes.size(); ++i) { auto& nodeInfo = graphDesc.nodes[i]; - dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{nodeInfo.op.Get(), nodeInfo.name.data()}; - dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_OPERATOR, &dmlOperatorGraphNodes[i]}; + + if (std::holds_alternative>(nodeInfo.nodeDef)) + { + dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{std::get>(nodeInfo.nodeDef).Get(), nodeInfo.name.data()}; + dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_OPERATOR, &dmlOperatorGraphNodes[i]}; + } + else + { + auto& nodeDefinitionData = std::get>(nodeInfo.nodeDef); + dmlConstantGraphNodes[i] = DML_CONSTANT_DATA_GRAPH_NODE_DESC{ + nodeDefinitionData.data(), + nodeDefinitionData.size(), + nodeInfo.name.data() + }; + + // TODO: Change as new header is ingested + dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{static_cast(2), &dmlConstantGraphNodes[i]}; + } } for (size_t i = 0; i < graphDesc.inputEdges.size(); ++i) @@ -392,6 +408,8 @@ namespace DmlGraphFusionHelper // convert DML EP GraphDesc into DML_GRAPH_DESC and create IDMLCompiledOperator DML_GRAPH_DESC dmlGraphDesc = {}; std::vector dmlOperatorGraphNodes(graphDesc.nodes.size()); + std::vector dmlConstantGraphNodes(graphDesc.nodes.size()); + std::vector dmlGraphNodes(graphDesc.nodes.size()); std::vector dmlInputEdges(graphDesc.inputEdges.size()); std::vector dmlOutputEdges(graphDesc.outputEdges.size()); @@ -402,6 +420,7 @@ namespace DmlGraphFusionHelper fusedNodeInputCount, fusedNodeOutputCount, dmlOperatorGraphNodes, + dmlConstantGraphNodes, dmlGraphNodes, dmlInputEdges, dmlOutputEdges, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index 8644b8d56a..49a64c4810 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -182,6 +182,32 @@ namespace Dml } m_isMcdmDevice = (featureLevels.MaxSupportedFeatureLevel == D3D_FEATURE_LEVEL_1_0_CORE_PRIVATE); + m_areCustomHeapsSupported = !m_isMcdmDevice; + + if (m_isMcdmDevice) + { + + // TODO: Ingest updated header file + typedef struct D3D12_FEATURE_DATA_D3D12_OPTIONS19 + { + BOOL MismatchingOutputDimensionsSupported; + UINT SupportedSampleCountsWithNoOutputs; + BOOL PointSamplingAddressesNeverRoundUp; + BOOL RasterizerDesc2Supported; + BOOL NarrowQuadrilateralLinesSupported; + BOOL AnisoFilterWithPointMipSupported; + UINT MaxSamplerDescriptorHeapSize; + UINT MaxSamplerDescriptorHeapSizeWithStaticSamplers; + UINT MaxViewDescriptorHeapSize; + _Out_ BOOL ComputeOnlyCustomHeapSupported; + } D3D12_FEATURE_DATA_D3D12_OPTIONS19; + + D3D12_FEATURE_DATA_D3D12_OPTIONS19 options19 = {}; + + // The call may fail in which case the default value is false + d3d12Device->CheckFeatureSupport(static_cast(48) /*D3D12_FEATURE_D3D12_OPTIONS19*/, &options19, sizeof(options19)); + m_areCustomHeapsSupported = options19.ComputeOnlyCustomHeapSupported; + } m_context = std::make_shared(m_d3d12Device.Get(), m_dmlDevice.Get(), queue); @@ -1089,6 +1115,11 @@ namespace Dml return m_isMcdmDevice; } + bool __stdcall ExecutionProviderImpl::CustomHeapsSupported() const noexcept + { + return m_areCustomHeapsSupported; + } + bool __stdcall ExecutionProviderImpl::MetacommandsEnabled() const noexcept { return m_areMetacommandsEnabled; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h index 3aaa11cdee..ab932fb8a4 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h @@ -150,6 +150,7 @@ namespace Dml } STDMETHOD_(bool, IsMcdmDevice)() const noexcept final; + STDMETHOD_(bool, CustomHeapsSupported)() const noexcept final; STDMETHOD_(bool, MetacommandsEnabled)() const noexcept final; bool DynamicGraphFusionEnabled() const noexcept; @@ -186,6 +187,7 @@ namespace Dml ComPtr m_d3d12Device; ComPtr m_dmlDevice; bool m_isMcdmDevice = false; + bool m_areCustomHeapsSupported = false; bool m_areMetacommandsEnabled = true; bool m_dynamicGraphFusionEnabled = false; bool m_native16BitShaderOpsSupported = false; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp index 3fc8f415e5..ba022533a1 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp @@ -149,7 +149,7 @@ namespace Dml::GraphDescBuilder const std::unordered_map>& isInitializerTransferable, const std::unordered_map& graphNodePropertyMap, IDMLDevice* device, - const void* executionHandle, + const ExecutionProviderImpl* executionHandle, const onnxruntime::Path& modelPath, gsl::span subgraphNodes, gsl::span subgraphInputs, @@ -198,7 +198,7 @@ namespace Dml::GraphDescBuilder const uint32_t minNodeCountToReuseCommandList = 5; bool reuseCommandList = false; - if (subgraphNodes.size() >= minNodeCountToReuseCommandList) + if (subgraphNodes.size() >= minNodeCountToReuseCommandList || executionHandle->IsMcdmDevice()) { reuseCommandList = true; } @@ -232,14 +232,22 @@ namespace Dml::GraphDescBuilder { ComPtr tensor = nullptr; - // Check whether this specific node requested support for constant CPU inputs - if (std::find(requiredConstantCpuInputs.begin(), requiredConstantCpuInputs.end(), inputIndex) != requiredConstantCpuInputs.end()) + auto inputDefs = node.InputDefs(); + + if (inputIndex < inputDefs.size()) { - auto inputDefs = node.InputDefs(); - if (inputIndex < inputDefs.size()) + const onnxruntime::NodeArg* arg = inputDefs[inputIndex]; + tensor = constantCpuGraphInputGetter(arg->Name()); + + if (tensor == nullptr) { - const onnxruntime::NodeArg* arg = inputDefs[inputIndex]; - tensor = constantCpuGraphInputGetter(arg->Name()); + bool inputRequiredAsConstant = std::find( + requiredConstantCpuInputs.begin(), + requiredConstantCpuInputs.end(), + inputIndex) != requiredConstantCpuInputs.end(); + + // This shouldn't happen since kernel creation is deferred and repeated when required constant inputs are not present. + ORT_THROW_HR_IF(E_UNEXPECTED, inputRequiredAsConstant); } } @@ -289,6 +297,7 @@ namespace Dml::GraphDescBuilder std::unordered_map operatorGraphNodeIndexToMainGraphNodeIndexMap; uint32_t graphNodeCount = gsl::narrow_cast(graphNodes.size()); const bool isNodeAsOpDesc = graphNodeCreateInfo.nodesAsOperatorDesc.size() > 0; + size_t firstOpDescGraphNodeIndex = graphNodes.size(); if (isNodeAsOpDesc) { @@ -298,6 +307,8 @@ namespace Dml::GraphDescBuilder ORT_THROW_HR_IF(E_UNEXPECTED, !graphNodeCreateInfo.nodesAsOperatorDesc[nodeIndex]); operatorGraphNodeIndexToMainGraphNodeIndexMap.emplace(nodeIndex, graphNodeCount++); } + + graphNodes.resize(graphNodes.size() + graphNodeCreateInfo.nodeCount); } else { @@ -306,7 +317,7 @@ namespace Dml::GraphDescBuilder ORT_THROW_HR_IF(E_UNEXPECTED, !graphNodeCreateInfo.nodesAsIDMLOperator[nodeIndex].Get()); operatorGraphNodeIndexToMainGraphNodeIndexMap.emplace(nodeIndex, graphNodeCount++); NodeInfo nodeInfo = {}; - nodeInfo.op = std::move(graphNodeCreateInfo.nodesAsIDMLOperator[nodeIndex]); + nodeInfo.nodeDef = std::move(graphNodeCreateInfo.nodesAsIDMLOperator[nodeIndex]); graphNodes.push_back(std::move(nodeInfo)); } } @@ -328,21 +339,59 @@ namespace Dml::GraphDescBuilder const uint32_t dmlFusedNodeInputIndex = iter->second; - DML_INPUT_GRAPH_EDGE_DESC edge = {}; - edge.GraphInputIndex = dmlFusedNodeInputIndex; - edge.ToNodeIndex = mainGraphNodeIndex; - edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex; // ?? might need to point inputIndex - graphInputEdges.push_back(edge); - // If this is a constant input, set the appropriate flags on the desc if (isNodeAsOpDesc && dmlFusedNodeInputIndex < isConstGpuGraphInputCount && isConstGpuGraphInput[dmlFusedNodeInputIndex]) { - auto& graphInputNode = graphNodeCreateInfo.nodesAsOperatorDesc[operatorGraphInputEdge.ToNodeIndex]; - std::vector toNodeInputTensorDescs = graphInputNode->GetInputTensors(); - DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorGraphInputEdge.ToNodeInputIndex]; - tensorDesc->flags |= DML_TENSOR_FLAG_OWNED_BY_DML; + // This is a highly inefficient approach to generating constant nodes. It duplicates constant data + // across the graph input as well as every consumer's unique constant node. However it is currently + // only used for small inputs. + + // TODO: Rework this to create DML constant nodes with the minimum data size actually used by consuming + // nodes. This would allow this size to be reduced while handling the case that 1D scale and zero point + // values that have been de-duplicated with conversion to scalars in kernels. + uint32_t c_maxConstNodeDataSize = 1024 * 1024; + + ComPtr constantInput = constantCpuGraphInputGetter(arg->Name()); + + if (constantInput && constantInput->GetTensorByteSize() < c_maxConstNodeDataSize) + { + auto data = static_cast(constantInput->GetData()); + std::vector tensorData(data, data + constantInput->GetTensorByteSize()); + + NodeInfo nodeInfo = {}; + nodeInfo.nodeDef = std::move(tensorData); + graphNodes.push_back(std::move(nodeInfo)); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC edge = {}; + edge.FromNodeIndex = static_cast(graphNodes.size() - 1); + edge.FromNodeOutputIndex = 0; + edge.ToNodeIndex = mainGraphNodeIndex; + edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex; + graphIntermediateEdges.push_back(edge); + } + else + { + DML_INPUT_GRAPH_EDGE_DESC edge = {}; + edge.GraphInputIndex = dmlFusedNodeInputIndex; + edge.ToNodeIndex = mainGraphNodeIndex; + edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex; + graphInputEdges.push_back(edge); + + auto& graphInputNode = graphNodeCreateInfo.nodesAsOperatorDesc[operatorGraphInputEdge.ToNodeIndex]; + std::vector toNodeInputTensorDescs = graphInputNode->GetInputTensors(); + DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorGraphInputEdge.ToNodeInputIndex]; + tensorDesc->flags |= DML_TENSOR_FLAG_OWNED_BY_DML; + } + } + else + { + DML_INPUT_GRAPH_EDGE_DESC edge = {}; + edge.GraphInputIndex = dmlFusedNodeInputIndex; + edge.ToNodeIndex = mainGraphNodeIndex; + edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex; + graphInputEdges.push_back(edge); } } else @@ -387,17 +436,28 @@ namespace Dml::GraphDescBuilder if (isNodeAsOpDesc) { - for (auto& opDesc : graphNodeCreateInfo.nodesAsOperatorDesc) + for (size_t i = 0; i < graphNodeCreateInfo.nodesAsOperatorDesc.size(); ++i) { + auto& opDesc = graphNodeCreateInfo.nodesAsOperatorDesc[i]; + DML_OPERATOR_DESC dmlDesc = SchemaHelpers::ConvertOperatorDesc(*opDesc, &allocator); + + // TODO: Change as new header is ingested + if (dmlDesc.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING) + dmlDesc.Type = (DML_OPERATOR_TYPE) 169; + + // TODO: Change as new header is ingested + if (dmlDesc.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT) + dmlDesc.Type = (DML_OPERATOR_TYPE) 170; + ComPtr op; ORT_THROW_IF_FAILED(device->CreateOperator(&dmlDesc, IID_PPV_ARGS(&op))); allocator.Reset(); NodeInfo nodeInfo = {}; - nodeInfo.op = std::move(op); + nodeInfo.nodeDef = std::move(op); nodeInfo.name = node.Name(); - graphNodes.push_back(std::move(nodeInfo)); + graphNodes[firstOpDescGraphNodeIndex + i] = std::move(nodeInfo); } } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h index 0039678c00..c95e89b455 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h @@ -4,6 +4,7 @@ #pragma once #include "MLOperatorAuthorImpl.h" +#include "ExecutionProvider.h" namespace Dml { @@ -27,7 +28,7 @@ namespace Dml struct NodeInfo { - Microsoft::WRL::ComPtr op; + std::variant, std::vector> nodeDef; std::string name; }; @@ -47,7 +48,7 @@ namespace Dml const std::unordered_map>& isInitializerTransferable, const std::unordered_map& graphNodePropertyMap, IDMLDevice* device, - const void* executionHandle, + const ExecutionProviderImpl* executionHandle, const onnxruntime::Path& modelPath, gsl::span subgraphNodes, gsl::span subgraphInputs, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h index a8a6d6745e..17fd7c18ba 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h @@ -76,6 +76,7 @@ namespace Dml STDMETHOD(AllocatePooledResource(size_t size, AllocatorRoundingMode roundingMode, ID3D12Resource **d3dResource, IUnknown* *pooledResource)) const noexcept = 0; STDMETHOD_(bool, IsMcdmDevice)() const noexcept = 0; + STDMETHOD_(bool, CustomHeapsSupported)() const noexcept = 0; STDMETHOD_(bool, MetacommandsEnabled)() const noexcept = 0; }; } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp index 4deec620fe..dbd06abf82 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp @@ -1123,7 +1123,7 @@ namespace Windows::AI::MachineLearning::Adapter } ORT_CATCH_RETURN } - + template HRESULT STDMETHODCALLTYPE OpNodeInfoWrapper::GetConstantInputTensor(uint32_t inputIndex, IMLOperatorTensor** tensor) const noexcept { @@ -1153,6 +1153,33 @@ namespace Windows::AI::MachineLearning::Adapter ORT_CATCH_RETURN } + template + HRESULT STDMETHODCALLTYPE OpNodeInfoWrapper::TryGetConstantInputTensor(uint32_t inputIndex, IMLOperatorTensor** tensor) const noexcept + { + ORT_TRY + { + auto constantInput = m_constantInputGetter(inputIndex); + ORT_THROW_HR_IF(E_INVALIDARG, !std::holds_alternative>(constantInput)); + + auto tensorWrapper = std::get>(constantInput); + if (tensorWrapper == nullptr) + { + bool inputRequiredAsConstant = std::find( + m_requiredConstantCpuInputs.begin(), + m_requiredConstantCpuInputs.end(), + inputIndex) != m_requiredConstantCpuInputs.end(); + + // This shouldn't happen since kernel creation is deferred and repeated when required constant inputs are not present. + ORT_THROW_HR_IF(E_UNEXPECTED, inputRequiredAsConstant); + } + + *tensor = tensorWrapper.Detach(); + + return S_OK; + } + ORT_CATCH_RETURN + } + HRESULT STDMETHODCALLTYPE OpKernelInfoWrapper::GetOutputTensorShape(uint32_t outputIndex, uint32_t dimensionCount, uint32_t* dimensions) const noexcept { ORT_TRY diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h index 913997ff4a..6530d89d89 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h @@ -204,6 +204,11 @@ class OpNodeInfoWrapper : public Base1_t, public Base2_t, public Closable _Outptr_ IMLOperatorTensor** tensor ) const noexcept; + HRESULT STDMETHODCALLTYPE TryGetConstantInputTensor( + uint32_t inputIndex, + _Outptr_ IMLOperatorTensor** tensor + ) const noexcept; + protected: // Lifetime is managed by the caller and guaranteed to outlive this class const onnxruntime::OpNodeProtoHelper* m_impl = nullptr; @@ -299,6 +304,8 @@ class OnnxTensorWrapper : public WRL::Base, public Closable const onnxruntime::Tensor* GetInterface() const { return nullptr; } onnxruntime::Tensor* GetInterface() { return nullptr; } + size_t GetTensorByteSize() const { return m_tensorByteSize; } + private: size_t m_tensorByteSize = 0; std::unique_ptr m_unpackedTensor; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp index 43d3465709..f0a16da3a3 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp @@ -479,17 +479,37 @@ public: ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 2); ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1); - Initialize(kernelInfo, std::nullopt, std::nullopt, kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0)); + auto constExpTensor = kernelInfo.TryGetConstantInputTensor(1); + if (constExpTensor && constExpTensor->GetTotalElementCount() == 1) + { + std::vector> kernelInputIndices = {0}; - std::vector inputDescs = GetDmlInputDescs(); - std::vector outputDescs = GetDmlOutputDescs(); + Initialize(kernelInfo, kernelInputIndices, std::nullopt, kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0)); - DML_ELEMENT_WISE_POW_OPERATOR_DESC opDesc = {}; - opDesc.InputTensor = &inputDescs[0]; - opDesc.ExponentTensor = &inputDescs[1]; - opDesc.OutputTensor = &outputDescs[0]; + std::vector inputDescs = GetDmlInputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); - SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_POW, &opDesc}, kernelInfo); + DML_ELEMENT_WISE_CONSTANT_POW_OPERATOR_DESC opDesc = {}; + opDesc.InputTensor = &inputDescs[0]; + opDesc.OutputTensor = &outputDescs[0]; + opDesc.Exponent = static_cast(ReadScalarTensorCastToFloat64(*constExpTensor)); + + SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_CONSTANT_POW, &opDesc}, kernelInfo); + } + else + { + Initialize(kernelInfo, std::nullopt, std::nullopt, kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0)); + + std::vector inputDescs = GetDmlInputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); + + DML_ELEMENT_WISE_POW_OPERATOR_DESC opDesc = {}; + opDesc.InputTensor = &inputDescs[0]; + opDesc.ExponentTensor = &inputDescs[1]; + opDesc.OutputTensor = &outputDescs[0]; + + SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_POW, &opDesc}, kernelInfo); + } } }; @@ -565,7 +585,7 @@ public: opDesc.ScaleTensor = &inputDescs[1]; opDesc.ZeroPointTensor = &inputDescs[2]; opDesc.OutputTensor = &outputDescs[0]; - + SetDmlOperatorDesc({ApiTraits::OperatorDescTraits::Type, &opDesc}, kernelInfo); } }; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h index f94270cfad..59a1719d08 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h @@ -6,6 +6,7 @@ #include "core/providers/dml/DmlExecutionProvider/inc/MLOperatorAuthor.h" #include "MLOperatorAuthorPrivate.h" #include "core/common/gsl.h" +#include #ifdef ORT_NO_EXCEPTIONS #define ML_CHECK_BOOL(x) ORT_THROW_HR_IF(E_INVALIDARG, !(x)) @@ -604,6 +605,18 @@ public: return MLOperatorTensor(tensor.Get()); } + std::optional TryGetConstantInputTensor(uint32_t inputIndex) const + { + Microsoft::WRL::ComPtr tensor; + ORT_THROW_IF_FAILED(m_implPrivate->TryGetConstantInputTensor(inputIndex, &tensor)); + if (tensor) + { + return MLOperatorTensor(tensor.Get()); + } + + return std::nullopt; + } + uint32_t GetInputTensorDimensionCount(uint32_t inputIndex) const { auto shapeDesc = GetTensorShapeDescription(); diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h index d1a705e151..3bec8d3864 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h @@ -41,6 +41,11 @@ IMLOperatorShapeInferenceContextPrivate : public IMLOperatorShapeInferenceContex _Outptr_ IMLOperatorTensor** tensor ) const noexcept PURE; + STDMETHOD(TryGetConstantInputTensor)( + uint32_t inputIndex, + _Outptr_ IMLOperatorTensor** tensor + ) const noexcept PURE; + //! Gets the number of dimensions of a tensor output of the operator. STDMETHOD(GetSequenceInputInfo)( uint32_t inputIndex, @@ -73,6 +78,11 @@ IMLOperatorKernelCreationContextPrivate : public IMLOperatorKernelCreationContex _Outptr_ IMLOperatorTensor** tensor ) const noexcept PURE; + STDMETHOD(TryGetConstantInputTensor)( + uint32_t inputIndex, + _Outptr_ IMLOperatorTensor** tensor + ) const noexcept PURE; + STDMETHOD_(bool, IsDmlGraphNode)() const noexcept PURE; STDMETHOD(SetDmlOperator)(