mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-01 03:45:06 +00:00
Utilize DML constant input graph node (#18267)
### Description This PR also includes,8b0a55e7ccDML constant pow operator7520974970Enable custom heaps based on query- ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> --------- Co-authored-by: Jeff Bloomfield <jeffbloo@microsoft.com>
This commit is contained in:
parent
dcfff10f57
commit
d5f3aae3fd
11 changed files with 229 additions and 38 deletions
|
|
@ -226,8 +226,7 @@ namespace DmlGraphFusionHelper
|
|||
{
|
||||
ComPtr<ID3D12Resource> initializeInputBuffer;
|
||||
|
||||
// D3D_FEATURE_LEVEL_1_0_CORE doesn't support Custom heaps
|
||||
if (providerImpl->IsMcdmDevice())
|
||||
if (!providerImpl->CustomHeapsSupported())
|
||||
{
|
||||
initializeInputBuffer = CreateResource(providerImpl, tensorPtr, tensorByteSize);
|
||||
}
|
||||
|
|
@ -294,6 +293,7 @@ namespace DmlGraphFusionHelper
|
|||
const uint32_t inputCount,
|
||||
const uint32_t outputCount,
|
||||
_Inout_ std::vector<DML_OPERATOR_GRAPH_NODE_DESC>& dmlOperatorGraphNodes,
|
||||
_Inout_ std::vector<DML_CONSTANT_DATA_GRAPH_NODE_DESC>& dmlConstantGraphNodes,
|
||||
_Inout_ std::vector<DML_GRAPH_NODE_DESC>& dmlGraphNodes,
|
||||
_Inout_ std::vector<DML_GRAPH_EDGE_DESC>& dmlInputEdges,
|
||||
_Inout_ std::vector<DML_GRAPH_EDGE_DESC>& dmlOutputEdges,
|
||||
|
|
@ -302,8 +302,24 @@ namespace DmlGraphFusionHelper
|
|||
for (size_t i = 0; i < graphDesc.nodes.size(); ++i)
|
||||
{
|
||||
auto& nodeInfo = graphDesc.nodes[i];
|
||||
dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{nodeInfo.op.Get(), nodeInfo.name.data()};
|
||||
dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_OPERATOR, &dmlOperatorGraphNodes[i]};
|
||||
|
||||
if (std::holds_alternative<Microsoft::WRL::ComPtr<IDMLOperator>>(nodeInfo.nodeDef))
|
||||
{
|
||||
dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{std::get<Microsoft::WRL::ComPtr<IDMLOperator>>(nodeInfo.nodeDef).Get(), nodeInfo.name.data()};
|
||||
dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_OPERATOR, &dmlOperatorGraphNodes[i]};
|
||||
}
|
||||
else
|
||||
{
|
||||
auto& nodeDefinitionData = std::get<std::vector<uint8_t>>(nodeInfo.nodeDef);
|
||||
dmlConstantGraphNodes[i] = DML_CONSTANT_DATA_GRAPH_NODE_DESC{
|
||||
nodeDefinitionData.data(),
|
||||
nodeDefinitionData.size(),
|
||||
nodeInfo.name.data()
|
||||
};
|
||||
|
||||
// TODO: Change as new header is ingested
|
||||
dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{static_cast<DML_GRAPH_NODE_TYPE>(2), &dmlConstantGraphNodes[i]};
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < graphDesc.inputEdges.size(); ++i)
|
||||
|
|
@ -392,6 +408,8 @@ namespace DmlGraphFusionHelper
|
|||
// convert DML EP GraphDesc into DML_GRAPH_DESC and create IDMLCompiledOperator
|
||||
DML_GRAPH_DESC dmlGraphDesc = {};
|
||||
std::vector<DML_OPERATOR_GRAPH_NODE_DESC> dmlOperatorGraphNodes(graphDesc.nodes.size());
|
||||
std::vector<DML_CONSTANT_DATA_GRAPH_NODE_DESC> dmlConstantGraphNodes(graphDesc.nodes.size());
|
||||
|
||||
std::vector<DML_GRAPH_NODE_DESC> dmlGraphNodes(graphDesc.nodes.size());
|
||||
std::vector<DML_GRAPH_EDGE_DESC> dmlInputEdges(graphDesc.inputEdges.size());
|
||||
std::vector<DML_GRAPH_EDGE_DESC> dmlOutputEdges(graphDesc.outputEdges.size());
|
||||
|
|
@ -402,6 +420,7 @@ namespace DmlGraphFusionHelper
|
|||
fusedNodeInputCount,
|
||||
fusedNodeOutputCount,
|
||||
dmlOperatorGraphNodes,
|
||||
dmlConstantGraphNodes,
|
||||
dmlGraphNodes,
|
||||
dmlInputEdges,
|
||||
dmlOutputEdges,
|
||||
|
|
|
|||
|
|
@ -182,6 +182,32 @@ namespace Dml
|
|||
}
|
||||
|
||||
m_isMcdmDevice = (featureLevels.MaxSupportedFeatureLevel == D3D_FEATURE_LEVEL_1_0_CORE_PRIVATE);
|
||||
m_areCustomHeapsSupported = !m_isMcdmDevice;
|
||||
|
||||
if (m_isMcdmDevice)
|
||||
{
|
||||
|
||||
// TODO: Ingest updated header file
|
||||
typedef struct D3D12_FEATURE_DATA_D3D12_OPTIONS19
|
||||
{
|
||||
BOOL MismatchingOutputDimensionsSupported;
|
||||
UINT SupportedSampleCountsWithNoOutputs;
|
||||
BOOL PointSamplingAddressesNeverRoundUp;
|
||||
BOOL RasterizerDesc2Supported;
|
||||
BOOL NarrowQuadrilateralLinesSupported;
|
||||
BOOL AnisoFilterWithPointMipSupported;
|
||||
UINT MaxSamplerDescriptorHeapSize;
|
||||
UINT MaxSamplerDescriptorHeapSizeWithStaticSamplers;
|
||||
UINT MaxViewDescriptorHeapSize;
|
||||
_Out_ BOOL ComputeOnlyCustomHeapSupported;
|
||||
} D3D12_FEATURE_DATA_D3D12_OPTIONS19;
|
||||
|
||||
D3D12_FEATURE_DATA_D3D12_OPTIONS19 options19 = {};
|
||||
|
||||
// The call may fail in which case the default value is false
|
||||
d3d12Device->CheckFeatureSupport(static_cast<D3D12_FEATURE>(48) /*D3D12_FEATURE_D3D12_OPTIONS19*/, &options19, sizeof(options19));
|
||||
m_areCustomHeapsSupported = options19.ComputeOnlyCustomHeapSupported;
|
||||
}
|
||||
|
||||
m_context = std::make_shared<ExecutionContext>(m_d3d12Device.Get(), m_dmlDevice.Get(), queue);
|
||||
|
||||
|
|
@ -1089,6 +1115,11 @@ namespace Dml
|
|||
return m_isMcdmDevice;
|
||||
}
|
||||
|
||||
bool __stdcall ExecutionProviderImpl::CustomHeapsSupported() const noexcept
|
||||
{
|
||||
return m_areCustomHeapsSupported;
|
||||
}
|
||||
|
||||
bool __stdcall ExecutionProviderImpl::MetacommandsEnabled() const noexcept
|
||||
{
|
||||
return m_areMetacommandsEnabled;
|
||||
|
|
|
|||
|
|
@ -150,6 +150,7 @@ namespace Dml
|
|||
}
|
||||
|
||||
STDMETHOD_(bool, IsMcdmDevice)() const noexcept final;
|
||||
STDMETHOD_(bool, CustomHeapsSupported)() const noexcept final;
|
||||
|
||||
STDMETHOD_(bool, MetacommandsEnabled)() const noexcept final;
|
||||
bool DynamicGraphFusionEnabled() const noexcept;
|
||||
|
|
@ -186,6 +187,7 @@ namespace Dml
|
|||
ComPtr<ID3D12Device> m_d3d12Device;
|
||||
ComPtr<IDMLDevice> m_dmlDevice;
|
||||
bool m_isMcdmDevice = false;
|
||||
bool m_areCustomHeapsSupported = false;
|
||||
bool m_areMetacommandsEnabled = true;
|
||||
bool m_dynamicGraphFusionEnabled = false;
|
||||
bool m_native16BitShaderOpsSupported = false;
|
||||
|
|
|
|||
|
|
@ -149,7 +149,7 @@ namespace Dml::GraphDescBuilder
|
|||
const std::unordered_map<std::string, std::pair<const ONNX_NAMESPACE::TensorProto*, bool>>& isInitializerTransferable,
|
||||
const std::unordered_map<std::string, GraphNodeProperties>& graphNodePropertyMap,
|
||||
IDMLDevice* device,
|
||||
const void* executionHandle,
|
||||
const ExecutionProviderImpl* executionHandle,
|
||||
const onnxruntime::Path& modelPath,
|
||||
gsl::span<const onnxruntime::Node* const> subgraphNodes,
|
||||
gsl::span<const onnxruntime::NodeArg* const> subgraphInputs,
|
||||
|
|
@ -198,7 +198,7 @@ namespace Dml::GraphDescBuilder
|
|||
const uint32_t minNodeCountToReuseCommandList = 5;
|
||||
bool reuseCommandList = false;
|
||||
|
||||
if (subgraphNodes.size() >= minNodeCountToReuseCommandList)
|
||||
if (subgraphNodes.size() >= minNodeCountToReuseCommandList || executionHandle->IsMcdmDevice())
|
||||
{
|
||||
reuseCommandList = true;
|
||||
}
|
||||
|
|
@ -232,14 +232,22 @@ namespace Dml::GraphDescBuilder
|
|||
{
|
||||
ComPtr<IMLOperatorTensor> tensor = nullptr;
|
||||
|
||||
// Check whether this specific node requested support for constant CPU inputs
|
||||
if (std::find(requiredConstantCpuInputs.begin(), requiredConstantCpuInputs.end(), inputIndex) != requiredConstantCpuInputs.end())
|
||||
auto inputDefs = node.InputDefs();
|
||||
|
||||
if (inputIndex < inputDefs.size())
|
||||
{
|
||||
auto inputDefs = node.InputDefs();
|
||||
if (inputIndex < inputDefs.size())
|
||||
const onnxruntime::NodeArg* arg = inputDefs[inputIndex];
|
||||
tensor = constantCpuGraphInputGetter(arg->Name());
|
||||
|
||||
if (tensor == nullptr)
|
||||
{
|
||||
const onnxruntime::NodeArg* arg = inputDefs[inputIndex];
|
||||
tensor = constantCpuGraphInputGetter(arg->Name());
|
||||
bool inputRequiredAsConstant = std::find(
|
||||
requiredConstantCpuInputs.begin(),
|
||||
requiredConstantCpuInputs.end(),
|
||||
inputIndex) != requiredConstantCpuInputs.end();
|
||||
|
||||
// This shouldn't happen since kernel creation is deferred and repeated when required constant inputs are not present.
|
||||
ORT_THROW_HR_IF(E_UNEXPECTED, inputRequiredAsConstant);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -289,6 +297,7 @@ namespace Dml::GraphDescBuilder
|
|||
std::unordered_map<uint32_t, uint32_t> operatorGraphNodeIndexToMainGraphNodeIndexMap;
|
||||
uint32_t graphNodeCount = gsl::narrow_cast<uint32_t>(graphNodes.size());
|
||||
const bool isNodeAsOpDesc = graphNodeCreateInfo.nodesAsOperatorDesc.size() > 0;
|
||||
size_t firstOpDescGraphNodeIndex = graphNodes.size();
|
||||
|
||||
if (isNodeAsOpDesc)
|
||||
{
|
||||
|
|
@ -298,6 +307,8 @@ namespace Dml::GraphDescBuilder
|
|||
ORT_THROW_HR_IF(E_UNEXPECTED, !graphNodeCreateInfo.nodesAsOperatorDesc[nodeIndex]);
|
||||
operatorGraphNodeIndexToMainGraphNodeIndexMap.emplace(nodeIndex, graphNodeCount++);
|
||||
}
|
||||
|
||||
graphNodes.resize(graphNodes.size() + graphNodeCreateInfo.nodeCount);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
|
@ -306,7 +317,7 @@ namespace Dml::GraphDescBuilder
|
|||
ORT_THROW_HR_IF(E_UNEXPECTED, !graphNodeCreateInfo.nodesAsIDMLOperator[nodeIndex].Get());
|
||||
operatorGraphNodeIndexToMainGraphNodeIndexMap.emplace(nodeIndex, graphNodeCount++);
|
||||
NodeInfo nodeInfo = {};
|
||||
nodeInfo.op = std::move(graphNodeCreateInfo.nodesAsIDMLOperator[nodeIndex]);
|
||||
nodeInfo.nodeDef = std::move(graphNodeCreateInfo.nodesAsIDMLOperator[nodeIndex]);
|
||||
graphNodes.push_back(std::move(nodeInfo));
|
||||
}
|
||||
}
|
||||
|
|
@ -328,21 +339,59 @@ namespace Dml::GraphDescBuilder
|
|||
|
||||
const uint32_t dmlFusedNodeInputIndex = iter->second;
|
||||
|
||||
DML_INPUT_GRAPH_EDGE_DESC edge = {};
|
||||
edge.GraphInputIndex = dmlFusedNodeInputIndex;
|
||||
edge.ToNodeIndex = mainGraphNodeIndex;
|
||||
edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex; // ?? might need to point inputIndex
|
||||
graphInputEdges.push_back(edge);
|
||||
|
||||
// If this is a constant input, set the appropriate flags on the desc
|
||||
if (isNodeAsOpDesc &&
|
||||
dmlFusedNodeInputIndex < isConstGpuGraphInputCount &&
|
||||
isConstGpuGraphInput[dmlFusedNodeInputIndex])
|
||||
{
|
||||
auto& graphInputNode = graphNodeCreateInfo.nodesAsOperatorDesc[operatorGraphInputEdge.ToNodeIndex];
|
||||
std::vector<DmlBufferTensorDesc*> toNodeInputTensorDescs = graphInputNode->GetInputTensors();
|
||||
DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorGraphInputEdge.ToNodeInputIndex];
|
||||
tensorDesc->flags |= DML_TENSOR_FLAG_OWNED_BY_DML;
|
||||
// This is a highly inefficient approach to generating constant nodes. It duplicates constant data
|
||||
// across the graph input as well as every consumer's unique constant node. However it is currently
|
||||
// only used for small inputs.
|
||||
|
||||
// TODO: Rework this to create DML constant nodes with the minimum data size actually used by consuming
|
||||
// nodes. This would allow this size to be reduced while handling the case that 1D scale and zero point
|
||||
// values that have been de-duplicated with conversion to scalars in kernels.
|
||||
uint32_t c_maxConstNodeDataSize = 1024 * 1024;
|
||||
|
||||
ComPtr<OnnxTensorWrapper> constantInput = constantCpuGraphInputGetter(arg->Name());
|
||||
|
||||
if (constantInput && constantInput->GetTensorByteSize() < c_maxConstNodeDataSize)
|
||||
{
|
||||
auto data = static_cast<const uint8_t*>(constantInput->GetData());
|
||||
std::vector<uint8_t> tensorData(data, data + constantInput->GetTensorByteSize());
|
||||
|
||||
NodeInfo nodeInfo = {};
|
||||
nodeInfo.nodeDef = std::move(tensorData);
|
||||
graphNodes.push_back(std::move(nodeInfo));
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC edge = {};
|
||||
edge.FromNodeIndex = static_cast<UINT>(graphNodes.size() - 1);
|
||||
edge.FromNodeOutputIndex = 0;
|
||||
edge.ToNodeIndex = mainGraphNodeIndex;
|
||||
edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex;
|
||||
graphIntermediateEdges.push_back(edge);
|
||||
}
|
||||
else
|
||||
{
|
||||
DML_INPUT_GRAPH_EDGE_DESC edge = {};
|
||||
edge.GraphInputIndex = dmlFusedNodeInputIndex;
|
||||
edge.ToNodeIndex = mainGraphNodeIndex;
|
||||
edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex;
|
||||
graphInputEdges.push_back(edge);
|
||||
|
||||
auto& graphInputNode = graphNodeCreateInfo.nodesAsOperatorDesc[operatorGraphInputEdge.ToNodeIndex];
|
||||
std::vector<DmlBufferTensorDesc*> toNodeInputTensorDescs = graphInputNode->GetInputTensors();
|
||||
DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorGraphInputEdge.ToNodeInputIndex];
|
||||
tensorDesc->flags |= DML_TENSOR_FLAG_OWNED_BY_DML;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
DML_INPUT_GRAPH_EDGE_DESC edge = {};
|
||||
edge.GraphInputIndex = dmlFusedNodeInputIndex;
|
||||
edge.ToNodeIndex = mainGraphNodeIndex;
|
||||
edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex;
|
||||
graphInputEdges.push_back(edge);
|
||||
}
|
||||
}
|
||||
else
|
||||
|
|
@ -387,17 +436,28 @@ namespace Dml::GraphDescBuilder
|
|||
|
||||
if (isNodeAsOpDesc)
|
||||
{
|
||||
for (auto& opDesc : graphNodeCreateInfo.nodesAsOperatorDesc)
|
||||
for (size_t i = 0; i < graphNodeCreateInfo.nodesAsOperatorDesc.size(); ++i)
|
||||
{
|
||||
auto& opDesc = graphNodeCreateInfo.nodesAsOperatorDesc[i];
|
||||
|
||||
DML_OPERATOR_DESC dmlDesc = SchemaHelpers::ConvertOperatorDesc(*opDesc, &allocator);
|
||||
|
||||
// TODO: Change as new header is ingested
|
||||
if (dmlDesc.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING)
|
||||
dmlDesc.Type = (DML_OPERATOR_TYPE) 169;
|
||||
|
||||
// TODO: Change as new header is ingested
|
||||
if (dmlDesc.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT)
|
||||
dmlDesc.Type = (DML_OPERATOR_TYPE) 170;
|
||||
|
||||
ComPtr<IDMLOperator> op;
|
||||
ORT_THROW_IF_FAILED(device->CreateOperator(&dmlDesc, IID_PPV_ARGS(&op)));
|
||||
allocator.Reset();
|
||||
|
||||
NodeInfo nodeInfo = {};
|
||||
nodeInfo.op = std::move(op);
|
||||
nodeInfo.nodeDef = std::move(op);
|
||||
nodeInfo.name = node.Name();
|
||||
graphNodes.push_back(std::move(nodeInfo));
|
||||
graphNodes[firstOpDescGraphNodeIndex + i] = std::move(nodeInfo);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
#pragma once
|
||||
|
||||
#include "MLOperatorAuthorImpl.h"
|
||||
#include "ExecutionProvider.h"
|
||||
|
||||
namespace Dml
|
||||
{
|
||||
|
|
@ -27,7 +28,7 @@ namespace Dml
|
|||
|
||||
struct NodeInfo
|
||||
{
|
||||
Microsoft::WRL::ComPtr<IDMLOperator> op;
|
||||
std::variant<Microsoft::WRL::ComPtr<IDMLOperator>, std::vector<uint8_t>> nodeDef;
|
||||
std::string name;
|
||||
};
|
||||
|
||||
|
|
@ -47,7 +48,7 @@ namespace Dml
|
|||
const std::unordered_map<std::string, std::pair<const ONNX_NAMESPACE::TensorProto*, bool>>& isInitializerTransferable,
|
||||
const std::unordered_map<std::string, GraphNodeProperties>& graphNodePropertyMap,
|
||||
IDMLDevice* device,
|
||||
const void* executionHandle,
|
||||
const ExecutionProviderImpl* executionHandle,
|
||||
const onnxruntime::Path& modelPath,
|
||||
gsl::span<const onnxruntime::Node* const> subgraphNodes,
|
||||
gsl::span<const onnxruntime::NodeArg* const> subgraphInputs,
|
||||
|
|
|
|||
|
|
@ -76,6 +76,7 @@ namespace Dml
|
|||
STDMETHOD(AllocatePooledResource(size_t size, AllocatorRoundingMode roundingMode, ID3D12Resource **d3dResource, IUnknown* *pooledResource)) const noexcept = 0;
|
||||
|
||||
STDMETHOD_(bool, IsMcdmDevice)() const noexcept = 0;
|
||||
STDMETHOD_(bool, CustomHeapsSupported)() const noexcept = 0;
|
||||
STDMETHOD_(bool, MetacommandsEnabled)() const noexcept = 0;
|
||||
};
|
||||
} // namespace Dml
|
||||
|
|
|
|||
|
|
@ -1123,7 +1123,7 @@ namespace Windows::AI::MachineLearning::Adapter
|
|||
}
|
||||
ORT_CATCH_RETURN
|
||||
}
|
||||
|
||||
|
||||
template <class NodeInfoImpl_t, class Base1_t, class Base2_t>
|
||||
HRESULT STDMETHODCALLTYPE OpNodeInfoWrapper<NodeInfoImpl_t, Base1_t, Base2_t>::GetConstantInputTensor(uint32_t inputIndex, IMLOperatorTensor** tensor) const noexcept
|
||||
{
|
||||
|
|
@ -1153,6 +1153,33 @@ namespace Windows::AI::MachineLearning::Adapter
|
|||
ORT_CATCH_RETURN
|
||||
}
|
||||
|
||||
template <class NodeInfoImpl_t, class Base1_t, class Base2_t>
|
||||
HRESULT STDMETHODCALLTYPE OpNodeInfoWrapper<NodeInfoImpl_t, Base1_t, Base2_t>::TryGetConstantInputTensor(uint32_t inputIndex, IMLOperatorTensor** tensor) const noexcept
|
||||
{
|
||||
ORT_TRY
|
||||
{
|
||||
auto constantInput = m_constantInputGetter(inputIndex);
|
||||
ORT_THROW_HR_IF(E_INVALIDARG, !std::holds_alternative<ComPtr<IMLOperatorTensor>>(constantInput));
|
||||
|
||||
auto tensorWrapper = std::get<ComPtr<IMLOperatorTensor>>(constantInput);
|
||||
if (tensorWrapper == nullptr)
|
||||
{
|
||||
bool inputRequiredAsConstant = std::find(
|
||||
m_requiredConstantCpuInputs.begin(),
|
||||
m_requiredConstantCpuInputs.end(),
|
||||
inputIndex) != m_requiredConstantCpuInputs.end();
|
||||
|
||||
// This shouldn't happen since kernel creation is deferred and repeated when required constant inputs are not present.
|
||||
ORT_THROW_HR_IF(E_UNEXPECTED, inputRequiredAsConstant);
|
||||
}
|
||||
|
||||
*tensor = tensorWrapper.Detach();
|
||||
|
||||
return S_OK;
|
||||
}
|
||||
ORT_CATCH_RETURN
|
||||
}
|
||||
|
||||
HRESULT STDMETHODCALLTYPE OpKernelInfoWrapper::GetOutputTensorShape(uint32_t outputIndex, uint32_t dimensionCount, uint32_t* dimensions) const noexcept
|
||||
{
|
||||
ORT_TRY
|
||||
|
|
|
|||
|
|
@ -204,6 +204,11 @@ class OpNodeInfoWrapper : public Base1_t, public Base2_t, public Closable
|
|||
_Outptr_ IMLOperatorTensor** tensor
|
||||
) const noexcept;
|
||||
|
||||
HRESULT STDMETHODCALLTYPE TryGetConstantInputTensor(
|
||||
uint32_t inputIndex,
|
||||
_Outptr_ IMLOperatorTensor** tensor
|
||||
) const noexcept;
|
||||
|
||||
protected:
|
||||
// Lifetime is managed by the caller and guaranteed to outlive this class
|
||||
const onnxruntime::OpNodeProtoHelper<NodeInfoImpl_t>* m_impl = nullptr;
|
||||
|
|
@ -299,6 +304,8 @@ class OnnxTensorWrapper : public WRL::Base<IMLOperatorTensor>, public Closable
|
|||
const onnxruntime::Tensor* GetInterface() const { return nullptr; }
|
||||
onnxruntime::Tensor* GetInterface() { return nullptr; }
|
||||
|
||||
size_t GetTensorByteSize() const { return m_tensorByteSize; }
|
||||
|
||||
private:
|
||||
size_t m_tensorByteSize = 0;
|
||||
std::unique_ptr<std::byte[]> m_unpackedTensor;
|
||||
|
|
|
|||
|
|
@ -479,17 +479,37 @@ public:
|
|||
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 2);
|
||||
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1);
|
||||
|
||||
Initialize(kernelInfo, std::nullopt, std::nullopt, kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0));
|
||||
auto constExpTensor = kernelInfo.TryGetConstantInputTensor(1);
|
||||
if (constExpTensor && constExpTensor->GetTotalElementCount() == 1)
|
||||
{
|
||||
std::vector<std::optional<uint32_t>> kernelInputIndices = {0};
|
||||
|
||||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
Initialize(kernelInfo, kernelInputIndices, std::nullopt, kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0));
|
||||
|
||||
DML_ELEMENT_WISE_POW_OPERATOR_DESC opDesc = {};
|
||||
opDesc.InputTensor = &inputDescs[0];
|
||||
opDesc.ExponentTensor = &inputDescs[1];
|
||||
opDesc.OutputTensor = &outputDescs[0];
|
||||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
|
||||
SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_POW, &opDesc}, kernelInfo);
|
||||
DML_ELEMENT_WISE_CONSTANT_POW_OPERATOR_DESC opDesc = {};
|
||||
opDesc.InputTensor = &inputDescs[0];
|
||||
opDesc.OutputTensor = &outputDescs[0];
|
||||
opDesc.Exponent = static_cast<float>(ReadScalarTensorCastToFloat64(*constExpTensor));
|
||||
|
||||
SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_CONSTANT_POW, &opDesc}, kernelInfo);
|
||||
}
|
||||
else
|
||||
{
|
||||
Initialize(kernelInfo, std::nullopt, std::nullopt, kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0));
|
||||
|
||||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
|
||||
DML_ELEMENT_WISE_POW_OPERATOR_DESC opDesc = {};
|
||||
opDesc.InputTensor = &inputDescs[0];
|
||||
opDesc.ExponentTensor = &inputDescs[1];
|
||||
opDesc.OutputTensor = &outputDescs[0];
|
||||
|
||||
SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_POW, &opDesc}, kernelInfo);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -565,7 +585,7 @@ public:
|
|||
opDesc.ScaleTensor = &inputDescs[1];
|
||||
opDesc.ZeroPointTensor = &inputDescs[2];
|
||||
opDesc.OutputTensor = &outputDescs[0];
|
||||
|
||||
|
||||
SetDmlOperatorDesc({ApiTraits::OperatorDescTraits<TOperatorDesc>::Type, &opDesc}, kernelInfo);
|
||||
}
|
||||
};
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
#include "core/providers/dml/DmlExecutionProvider/inc/MLOperatorAuthor.h"
|
||||
#include "MLOperatorAuthorPrivate.h"
|
||||
#include "core/common/gsl.h"
|
||||
#include <optional>
|
||||
|
||||
#ifdef ORT_NO_EXCEPTIONS
|
||||
#define ML_CHECK_BOOL(x) ORT_THROW_HR_IF(E_INVALIDARG, !(x))
|
||||
|
|
@ -604,6 +605,18 @@ public:
|
|||
return MLOperatorTensor(tensor.Get());
|
||||
}
|
||||
|
||||
std::optional<MLOperatorTensor> TryGetConstantInputTensor(uint32_t inputIndex) const
|
||||
{
|
||||
Microsoft::WRL::ComPtr<IMLOperatorTensor> tensor;
|
||||
ORT_THROW_IF_FAILED(m_implPrivate->TryGetConstantInputTensor(inputIndex, &tensor));
|
||||
if (tensor)
|
||||
{
|
||||
return MLOperatorTensor(tensor.Get());
|
||||
}
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
uint32_t GetInputTensorDimensionCount(uint32_t inputIndex) const
|
||||
{
|
||||
auto shapeDesc = GetTensorShapeDescription();
|
||||
|
|
|
|||
|
|
@ -41,6 +41,11 @@ IMLOperatorShapeInferenceContextPrivate : public IMLOperatorShapeInferenceContex
|
|||
_Outptr_ IMLOperatorTensor** tensor
|
||||
) const noexcept PURE;
|
||||
|
||||
STDMETHOD(TryGetConstantInputTensor)(
|
||||
uint32_t inputIndex,
|
||||
_Outptr_ IMLOperatorTensor** tensor
|
||||
) const noexcept PURE;
|
||||
|
||||
//! Gets the number of dimensions of a tensor output of the operator.
|
||||
STDMETHOD(GetSequenceInputInfo)(
|
||||
uint32_t inputIndex,
|
||||
|
|
@ -73,6 +78,11 @@ IMLOperatorKernelCreationContextPrivate : public IMLOperatorKernelCreationContex
|
|||
_Outptr_ IMLOperatorTensor** tensor
|
||||
) const noexcept PURE;
|
||||
|
||||
STDMETHOD(TryGetConstantInputTensor)(
|
||||
uint32_t inputIndex,
|
||||
_Outptr_ IMLOperatorTensor** tensor
|
||||
) const noexcept PURE;
|
||||
|
||||
STDMETHOD_(bool, IsDmlGraphNode)() const noexcept PURE;
|
||||
|
||||
STDMETHOD(SetDmlOperator)(
|
||||
|
|
|
|||
Loading…
Reference in a new issue