Utilize DML constant input graph node (#18267)

### Description
This PR also includes:
- 8b0a55e7cc DML constant pow operator
- 7520974970 Enable custom heaps based on a device feature-support query

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

---------

Co-authored-by: Jeff Bloomfield <jeffbloo@microsoft.com>
This commit is contained in:
raoanag 2023-11-17 16:43:09 -08:00 committed by Jeff Bloomfield
parent dcfff10f57
commit d5f3aae3fd
11 changed files with 229 additions and 38 deletions

View file

@ -226,8 +226,7 @@ namespace DmlGraphFusionHelper
{
ComPtr<ID3D12Resource> initializeInputBuffer;
// D3D_FEATURE_LEVEL_1_0_CORE doesn't support Custom heaps
if (providerImpl->IsMcdmDevice())
if (!providerImpl->CustomHeapsSupported())
{
initializeInputBuffer = CreateResource(providerImpl, tensorPtr, tensorByteSize);
}
@ -294,6 +293,7 @@ namespace DmlGraphFusionHelper
const uint32_t inputCount,
const uint32_t outputCount,
_Inout_ std::vector<DML_OPERATOR_GRAPH_NODE_DESC>& dmlOperatorGraphNodes,
_Inout_ std::vector<DML_CONSTANT_DATA_GRAPH_NODE_DESC>& dmlConstantGraphNodes,
_Inout_ std::vector<DML_GRAPH_NODE_DESC>& dmlGraphNodes,
_Inout_ std::vector<DML_GRAPH_EDGE_DESC>& dmlInputEdges,
_Inout_ std::vector<DML_GRAPH_EDGE_DESC>& dmlOutputEdges,
@ -302,8 +302,24 @@ namespace DmlGraphFusionHelper
for (size_t i = 0; i < graphDesc.nodes.size(); ++i)
{
auto& nodeInfo = graphDesc.nodes[i];
dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{nodeInfo.op.Get(), nodeInfo.name.data()};
dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_OPERATOR, &dmlOperatorGraphNodes[i]};
if (std::holds_alternative<Microsoft::WRL::ComPtr<IDMLOperator>>(nodeInfo.nodeDef))
{
dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{std::get<Microsoft::WRL::ComPtr<IDMLOperator>>(nodeInfo.nodeDef).Get(), nodeInfo.name.data()};
dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_OPERATOR, &dmlOperatorGraphNodes[i]};
}
else
{
auto& nodeDefinitionData = std::get<std::vector<uint8_t>>(nodeInfo.nodeDef);
dmlConstantGraphNodes[i] = DML_CONSTANT_DATA_GRAPH_NODE_DESC{
nodeDefinitionData.data(),
nodeDefinitionData.size(),
nodeInfo.name.data()
};
// TODO: Change as new header is ingested
dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{static_cast<DML_GRAPH_NODE_TYPE>(2), &dmlConstantGraphNodes[i]};
}
}
for (size_t i = 0; i < graphDesc.inputEdges.size(); ++i)
@ -392,6 +408,8 @@ namespace DmlGraphFusionHelper
// convert DML EP GraphDesc into DML_GRAPH_DESC and create IDMLCompiledOperator
DML_GRAPH_DESC dmlGraphDesc = {};
std::vector<DML_OPERATOR_GRAPH_NODE_DESC> dmlOperatorGraphNodes(graphDesc.nodes.size());
std::vector<DML_CONSTANT_DATA_GRAPH_NODE_DESC> dmlConstantGraphNodes(graphDesc.nodes.size());
std::vector<DML_GRAPH_NODE_DESC> dmlGraphNodes(graphDesc.nodes.size());
std::vector<DML_GRAPH_EDGE_DESC> dmlInputEdges(graphDesc.inputEdges.size());
std::vector<DML_GRAPH_EDGE_DESC> dmlOutputEdges(graphDesc.outputEdges.size());
@ -402,6 +420,7 @@ namespace DmlGraphFusionHelper
fusedNodeInputCount,
fusedNodeOutputCount,
dmlOperatorGraphNodes,
dmlConstantGraphNodes,
dmlGraphNodes,
dmlInputEdges,
dmlOutputEdges,

View file

@ -182,6 +182,32 @@ namespace Dml
}
m_isMcdmDevice = (featureLevels.MaxSupportedFeatureLevel == D3D_FEATURE_LEVEL_1_0_CORE_PRIVATE);
m_areCustomHeapsSupported = !m_isMcdmDevice;
if (m_isMcdmDevice)
{
// TODO: Ingest updated header file
typedef struct D3D12_FEATURE_DATA_D3D12_OPTIONS19
{
BOOL MismatchingOutputDimensionsSupported;
UINT SupportedSampleCountsWithNoOutputs;
BOOL PointSamplingAddressesNeverRoundUp;
BOOL RasterizerDesc2Supported;
BOOL NarrowQuadrilateralLinesSupported;
BOOL AnisoFilterWithPointMipSupported;
UINT MaxSamplerDescriptorHeapSize;
UINT MaxSamplerDescriptorHeapSizeWithStaticSamplers;
UINT MaxViewDescriptorHeapSize;
_Out_ BOOL ComputeOnlyCustomHeapSupported;
} D3D12_FEATURE_DATA_D3D12_OPTIONS19;
D3D12_FEATURE_DATA_D3D12_OPTIONS19 options19 = {};
// The call may fail in which case the default value is false
d3d12Device->CheckFeatureSupport(static_cast<D3D12_FEATURE>(48) /*D3D12_FEATURE_D3D12_OPTIONS19*/, &options19, sizeof(options19));
m_areCustomHeapsSupported = options19.ComputeOnlyCustomHeapSupported;
}
m_context = std::make_shared<ExecutionContext>(m_d3d12Device.Get(), m_dmlDevice.Get(), queue);
@ -1089,6 +1115,11 @@ namespace Dml
return m_isMcdmDevice;
}
// Returns the cached m_areCustomHeapsSupported flag, which tells callers whether
// custom heaps may be used on this device.
bool __stdcall ExecutionProviderImpl::CustomHeapsSupported() const noexcept
{
return m_areCustomHeapsSupported;
}
bool __stdcall ExecutionProviderImpl::MetacommandsEnabled() const noexcept
{
return m_areMetacommandsEnabled;

View file

@ -150,6 +150,7 @@ namespace Dml
}
STDMETHOD_(bool, IsMcdmDevice)() const noexcept final;
STDMETHOD_(bool, CustomHeapsSupported)() const noexcept final;
STDMETHOD_(bool, MetacommandsEnabled)() const noexcept final;
bool DynamicGraphFusionEnabled() const noexcept;
@ -186,6 +187,7 @@ namespace Dml
ComPtr<ID3D12Device> m_d3d12Device;
ComPtr<IDMLDevice> m_dmlDevice;
bool m_isMcdmDevice = false;
bool m_areCustomHeapsSupported = false;
bool m_areMetacommandsEnabled = true;
bool m_dynamicGraphFusionEnabled = false;
bool m_native16BitShaderOpsSupported = false;

View file

@ -149,7 +149,7 @@ namespace Dml::GraphDescBuilder
const std::unordered_map<std::string, std::pair<const ONNX_NAMESPACE::TensorProto*, bool>>& isInitializerTransferable,
const std::unordered_map<std::string, GraphNodeProperties>& graphNodePropertyMap,
IDMLDevice* device,
const void* executionHandle,
const ExecutionProviderImpl* executionHandle,
const onnxruntime::Path& modelPath,
gsl::span<const onnxruntime::Node* const> subgraphNodes,
gsl::span<const onnxruntime::NodeArg* const> subgraphInputs,
@ -198,7 +198,7 @@ namespace Dml::GraphDescBuilder
const uint32_t minNodeCountToReuseCommandList = 5;
bool reuseCommandList = false;
if (subgraphNodes.size() >= minNodeCountToReuseCommandList)
if (subgraphNodes.size() >= minNodeCountToReuseCommandList || executionHandle->IsMcdmDevice())
{
reuseCommandList = true;
}
@ -232,14 +232,22 @@ namespace Dml::GraphDescBuilder
{
ComPtr<IMLOperatorTensor> tensor = nullptr;
// Check whether this specific node requested support for constant CPU inputs
if (std::find(requiredConstantCpuInputs.begin(), requiredConstantCpuInputs.end(), inputIndex) != requiredConstantCpuInputs.end())
auto inputDefs = node.InputDefs();
if (inputIndex < inputDefs.size())
{
auto inputDefs = node.InputDefs();
if (inputIndex < inputDefs.size())
const onnxruntime::NodeArg* arg = inputDefs[inputIndex];
tensor = constantCpuGraphInputGetter(arg->Name());
if (tensor == nullptr)
{
const onnxruntime::NodeArg* arg = inputDefs[inputIndex];
tensor = constantCpuGraphInputGetter(arg->Name());
bool inputRequiredAsConstant = std::find(
requiredConstantCpuInputs.begin(),
requiredConstantCpuInputs.end(),
inputIndex) != requiredConstantCpuInputs.end();
// This shouldn't happen since kernel creation is deferred and repeated when required constant inputs are not present.
ORT_THROW_HR_IF(E_UNEXPECTED, inputRequiredAsConstant);
}
}
@ -289,6 +297,7 @@ namespace Dml::GraphDescBuilder
std::unordered_map<uint32_t, uint32_t> operatorGraphNodeIndexToMainGraphNodeIndexMap;
uint32_t graphNodeCount = gsl::narrow_cast<uint32_t>(graphNodes.size());
const bool isNodeAsOpDesc = graphNodeCreateInfo.nodesAsOperatorDesc.size() > 0;
size_t firstOpDescGraphNodeIndex = graphNodes.size();
if (isNodeAsOpDesc)
{
@ -298,6 +307,8 @@ namespace Dml::GraphDescBuilder
ORT_THROW_HR_IF(E_UNEXPECTED, !graphNodeCreateInfo.nodesAsOperatorDesc[nodeIndex]);
operatorGraphNodeIndexToMainGraphNodeIndexMap.emplace(nodeIndex, graphNodeCount++);
}
graphNodes.resize(graphNodes.size() + graphNodeCreateInfo.nodeCount);
}
else
{
@ -306,7 +317,7 @@ namespace Dml::GraphDescBuilder
ORT_THROW_HR_IF(E_UNEXPECTED, !graphNodeCreateInfo.nodesAsIDMLOperator[nodeIndex].Get());
operatorGraphNodeIndexToMainGraphNodeIndexMap.emplace(nodeIndex, graphNodeCount++);
NodeInfo nodeInfo = {};
nodeInfo.op = std::move(graphNodeCreateInfo.nodesAsIDMLOperator[nodeIndex]);
nodeInfo.nodeDef = std::move(graphNodeCreateInfo.nodesAsIDMLOperator[nodeIndex]);
graphNodes.push_back(std::move(nodeInfo));
}
}
@ -328,21 +339,59 @@ namespace Dml::GraphDescBuilder
const uint32_t dmlFusedNodeInputIndex = iter->second;
DML_INPUT_GRAPH_EDGE_DESC edge = {};
edge.GraphInputIndex = dmlFusedNodeInputIndex;
edge.ToNodeIndex = mainGraphNodeIndex;
edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex; // ?? might need to point inputIndex
graphInputEdges.push_back(edge);
// If this is a constant input, set the appropriate flags on the desc
if (isNodeAsOpDesc &&
dmlFusedNodeInputIndex < isConstGpuGraphInputCount &&
isConstGpuGraphInput[dmlFusedNodeInputIndex])
{
auto& graphInputNode = graphNodeCreateInfo.nodesAsOperatorDesc[operatorGraphInputEdge.ToNodeIndex];
std::vector<DmlBufferTensorDesc*> toNodeInputTensorDescs = graphInputNode->GetInputTensors();
DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorGraphInputEdge.ToNodeInputIndex];
tensorDesc->flags |= DML_TENSOR_FLAG_OWNED_BY_DML;
// This is a highly inefficient approach to generating constant nodes. It duplicates constant data
// across the graph input as well as every consumer's unique constant node. However it is currently
// only used for small inputs.
// TODO: Rework this to create DML constant nodes with the minimum data size actually used by consuming
// nodes. This would allow this size to be reduced while handling the case that 1D scale and zero point
// values that have been de-duplicated with conversion to scalars in kernels.
uint32_t c_maxConstNodeDataSize = 1024 * 1024;
ComPtr<OnnxTensorWrapper> constantInput = constantCpuGraphInputGetter(arg->Name());
if (constantInput && constantInput->GetTensorByteSize() < c_maxConstNodeDataSize)
{
auto data = static_cast<const uint8_t*>(constantInput->GetData());
std::vector<uint8_t> tensorData(data, data + constantInput->GetTensorByteSize());
NodeInfo nodeInfo = {};
nodeInfo.nodeDef = std::move(tensorData);
graphNodes.push_back(std::move(nodeInfo));
DML_INTERMEDIATE_GRAPH_EDGE_DESC edge = {};
edge.FromNodeIndex = static_cast<UINT>(graphNodes.size() - 1);
edge.FromNodeOutputIndex = 0;
edge.ToNodeIndex = mainGraphNodeIndex;
edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex;
graphIntermediateEdges.push_back(edge);
}
else
{
DML_INPUT_GRAPH_EDGE_DESC edge = {};
edge.GraphInputIndex = dmlFusedNodeInputIndex;
edge.ToNodeIndex = mainGraphNodeIndex;
edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex;
graphInputEdges.push_back(edge);
auto& graphInputNode = graphNodeCreateInfo.nodesAsOperatorDesc[operatorGraphInputEdge.ToNodeIndex];
std::vector<DmlBufferTensorDesc*> toNodeInputTensorDescs = graphInputNode->GetInputTensors();
DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorGraphInputEdge.ToNodeInputIndex];
tensorDesc->flags |= DML_TENSOR_FLAG_OWNED_BY_DML;
}
}
else
{
DML_INPUT_GRAPH_EDGE_DESC edge = {};
edge.GraphInputIndex = dmlFusedNodeInputIndex;
edge.ToNodeIndex = mainGraphNodeIndex;
edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex;
graphInputEdges.push_back(edge);
}
}
else
@ -387,17 +436,28 @@ namespace Dml::GraphDescBuilder
if (isNodeAsOpDesc)
{
for (auto& opDesc : graphNodeCreateInfo.nodesAsOperatorDesc)
for (size_t i = 0; i < graphNodeCreateInfo.nodesAsOperatorDesc.size(); ++i)
{
auto& opDesc = graphNodeCreateInfo.nodesAsOperatorDesc[i];
DML_OPERATOR_DESC dmlDesc = SchemaHelpers::ConvertOperatorDesc(*opDesc, &allocator);
// TODO: Change as new header is ingested
if (dmlDesc.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING)
dmlDesc.Type = (DML_OPERATOR_TYPE) 169;
// TODO: Change as new header is ingested
if (dmlDesc.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT)
dmlDesc.Type = (DML_OPERATOR_TYPE) 170;
ComPtr<IDMLOperator> op;
ORT_THROW_IF_FAILED(device->CreateOperator(&dmlDesc, IID_PPV_ARGS(&op)));
allocator.Reset();
NodeInfo nodeInfo = {};
nodeInfo.op = std::move(op);
nodeInfo.nodeDef = std::move(op);
nodeInfo.name = node.Name();
graphNodes.push_back(std::move(nodeInfo));
graphNodes[firstOpDescGraphNodeIndex + i] = std::move(nodeInfo);
}
}
}

View file

@ -4,6 +4,7 @@
#pragma once
#include "MLOperatorAuthorImpl.h"
#include "ExecutionProvider.h"
namespace Dml
{
@ -27,7 +28,7 @@ namespace Dml
struct NodeInfo
{
Microsoft::WRL::ComPtr<IDMLOperator> op;
std::variant<Microsoft::WRL::ComPtr<IDMLOperator>, std::vector<uint8_t>> nodeDef;
std::string name;
};
@ -47,7 +48,7 @@ namespace Dml
const std::unordered_map<std::string, std::pair<const ONNX_NAMESPACE::TensorProto*, bool>>& isInitializerTransferable,
const std::unordered_map<std::string, GraphNodeProperties>& graphNodePropertyMap,
IDMLDevice* device,
const void* executionHandle,
const ExecutionProviderImpl* executionHandle,
const onnxruntime::Path& modelPath,
gsl::span<const onnxruntime::Node* const> subgraphNodes,
gsl::span<const onnxruntime::NodeArg* const> subgraphInputs,

View file

@ -76,6 +76,7 @@ namespace Dml
STDMETHOD(AllocatePooledResource(size_t size, AllocatorRoundingMode roundingMode, ID3D12Resource **d3dResource, IUnknown* *pooledResource)) const noexcept = 0;
STDMETHOD_(bool, IsMcdmDevice)() const noexcept = 0;
STDMETHOD_(bool, CustomHeapsSupported)() const noexcept = 0;
STDMETHOD_(bool, MetacommandsEnabled)() const noexcept = 0;
};
} // namespace Dml

View file

@ -1123,7 +1123,7 @@ namespace Windows::AI::MachineLearning::Adapter
}
ORT_CATCH_RETURN
}
template <class NodeInfoImpl_t, class Base1_t, class Base2_t>
HRESULT STDMETHODCALLTYPE OpNodeInfoWrapper<NodeInfoImpl_t, Base1_t, Base2_t>::GetConstantInputTensor(uint32_t inputIndex, IMLOperatorTensor** tensor) const noexcept
{
@ -1153,6 +1153,33 @@ namespace Windows::AI::MachineLearning::Adapter
ORT_CATCH_RETURN
}
// Non-throwing-on-absence variant of constant input lookup.
//
// Asks m_constantInputGetter for the input at inputIndex and hands the resulting
// tensor to the caller through *tensor. Unlike a hard requirement, a missing
// (null) tensor is reported as success with *tensor == nullptr, provided the
// input was not listed in m_requiredConstantCpuInputs.
//
// Failure cases (converted to an HRESULT by ORT_CATCH_RETURN):
//   E_INVALIDARG  - the getter produced something other than a single tensor.
//   E_UNEXPECTED  - the tensor is null although the input is a required constant.
template <class NodeInfoImpl_t, class Base1_t, class Base2_t>
HRESULT STDMETHODCALLTYPE OpNodeInfoWrapper<NodeInfoImpl_t, Base1_t, Base2_t>::TryGetConstantInputTensor(uint32_t inputIndex, IMLOperatorTensor** tensor) const noexcept
{
    ORT_TRY
    {
        auto inputVariant = m_constantInputGetter(inputIndex);

        const bool holdsSingleTensor = std::holds_alternative<ComPtr<IMLOperatorTensor>>(inputVariant);
        ORT_THROW_HR_IF(E_INVALIDARG, !holdsSingleTensor);

        ComPtr<IMLOperatorTensor> constantTensor = std::get<ComPtr<IMLOperatorTensor>>(inputVariant);
        if (constantTensor == nullptr)
        {
            const auto requiredBegin = m_requiredConstantCpuInputs.begin();
            const auto requiredEnd = m_requiredConstantCpuInputs.end();
            const bool isRequiredConstant = std::find(requiredBegin, requiredEnd, inputIndex) != requiredEnd;

            // This shouldn't happen since kernel creation is deferred and repeated when required constant inputs are not present.
            ORT_THROW_HR_IF(E_UNEXPECTED, isRequiredConstant);
        }

        // Ownership of the reference (possibly null) moves to the caller.
        *tensor = constantTensor.Detach();
        return S_OK;
    }
    ORT_CATCH_RETURN
}
HRESULT STDMETHODCALLTYPE OpKernelInfoWrapper::GetOutputTensorShape(uint32_t outputIndex, uint32_t dimensionCount, uint32_t* dimensions) const noexcept
{
ORT_TRY

View file

@ -204,6 +204,11 @@ class OpNodeInfoWrapper : public Base1_t, public Base2_t, public Closable
_Outptr_ IMLOperatorTensor** tensor
) const noexcept;
HRESULT STDMETHODCALLTYPE TryGetConstantInputTensor(
uint32_t inputIndex,
_Outptr_ IMLOperatorTensor** tensor
) const noexcept;
protected:
// Lifetime is managed by the caller and guaranteed to outlive this class
const onnxruntime::OpNodeProtoHelper<NodeInfoImpl_t>* m_impl = nullptr;
@ -299,6 +304,8 @@ class OnnxTensorWrapper : public WRL::Base<IMLOperatorTensor>, public Closable
const onnxruntime::Tensor* GetInterface() const { return nullptr; }
onnxruntime::Tensor* GetInterface() { return nullptr; }
// Returns the byte size recorded in m_tensorByteSize.
size_t GetTensorByteSize() const { return m_tensorByteSize; }
private:
size_t m_tensorByteSize = 0;
std::unique_ptr<std::byte[]> m_unpackedTensor;

View file

@ -479,17 +479,37 @@ public:
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 2);
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1);
Initialize(kernelInfo, std::nullopt, std::nullopt, kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0));
auto constExpTensor = kernelInfo.TryGetConstantInputTensor(1);
if (constExpTensor && constExpTensor->GetTotalElementCount() == 1)
{
std::vector<std::optional<uint32_t>> kernelInputIndices = {0};
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
Initialize(kernelInfo, kernelInputIndices, std::nullopt, kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0));
DML_ELEMENT_WISE_POW_OPERATOR_DESC opDesc = {};
opDesc.InputTensor = &inputDescs[0];
opDesc.ExponentTensor = &inputDescs[1];
opDesc.OutputTensor = &outputDescs[0];
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_POW, &opDesc}, kernelInfo);
DML_ELEMENT_WISE_CONSTANT_POW_OPERATOR_DESC opDesc = {};
opDesc.InputTensor = &inputDescs[0];
opDesc.OutputTensor = &outputDescs[0];
opDesc.Exponent = static_cast<float>(ReadScalarTensorCastToFloat64(*constExpTensor));
SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_CONSTANT_POW, &opDesc}, kernelInfo);
}
else
{
Initialize(kernelInfo, std::nullopt, std::nullopt, kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0));
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
DML_ELEMENT_WISE_POW_OPERATOR_DESC opDesc = {};
opDesc.InputTensor = &inputDescs[0];
opDesc.ExponentTensor = &inputDescs[1];
opDesc.OutputTensor = &outputDescs[0];
SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_POW, &opDesc}, kernelInfo);
}
}
};
@ -565,7 +585,7 @@ public:
opDesc.ScaleTensor = &inputDescs[1];
opDesc.ZeroPointTensor = &inputDescs[2];
opDesc.OutputTensor = &outputDescs[0];
SetDmlOperatorDesc({ApiTraits::OperatorDescTraits<TOperatorDesc>::Type, &opDesc}, kernelInfo);
}
};

View file

@ -6,6 +6,7 @@
#include "core/providers/dml/DmlExecutionProvider/inc/MLOperatorAuthor.h"
#include "MLOperatorAuthorPrivate.h"
#include "core/common/gsl.h"
#include <optional>
#ifdef ORT_NO_EXCEPTIONS
#define ML_CHECK_BOOL(x) ORT_THROW_HR_IF(E_INVALIDARG, !(x))
@ -604,6 +605,18 @@ public:
return MLOperatorTensor(tensor.Get());
}
// Looks up the constant tensor bound to inputIndex via the private context
// interface. Returns std::nullopt when the call succeeds but yields no tensor;
// throws (through ORT_THROW_IF_FAILED) when the underlying call fails.
std::optional<MLOperatorTensor> TryGetConstantInputTensor(uint32_t inputIndex) const
{
    Microsoft::WRL::ComPtr<IMLOperatorTensor> constantTensor;
    ORT_THROW_IF_FAILED(m_implPrivate->TryGetConstantInputTensor(inputIndex, &constantTensor));

    if (!constantTensor)
    {
        return std::nullopt;
    }

    return MLOperatorTensor(constantTensor.Get());
}
uint32_t GetInputTensorDimensionCount(uint32_t inputIndex) const
{
auto shapeDesc = GetTensorShapeDescription();

View file

@ -41,6 +41,11 @@ IMLOperatorShapeInferenceContextPrivate : public IMLOperatorShapeInferenceContex
_Outptr_ IMLOperatorTensor** tensor
) const noexcept PURE;
STDMETHOD(TryGetConstantInputTensor)(
uint32_t inputIndex,
_Outptr_ IMLOperatorTensor** tensor
) const noexcept PURE;
//! Gets the number of dimensions of a tensor output of the operator.
STDMETHOD(GetSequenceInputInfo)(
uint32_t inputIndex,
@ -73,6 +78,11 @@ IMLOperatorKernelCreationContextPrivate : public IMLOperatorKernelCreationContex
_Outptr_ IMLOperatorTensor** tensor
) const noexcept PURE;
STDMETHOD(TryGetConstantInputTensor)(
uint32_t inputIndex,
_Outptr_ IMLOperatorTensor** tensor
) const noexcept PURE;
STDMETHOD_(bool, IsDmlGraphNode)() const noexcept PURE;
STDMETHOD(SetDmlOperator)(