Utilize DML constant input graph node

This commit is contained in:
Jeff Bloomfield 2023-08-11 15:51:35 -07:00
parent 4dfeef48eb
commit 4c8ef4080d
4 changed files with 96 additions and 17 deletions

View file

@ -279,12 +279,30 @@ namespace DmlGraphFusionHelper
return initializerPartitionMap;
}
// Preview operator-type enumeration. Only the first value of the range is
// declared here, and nothing visible in this part of the file references it.
// NOTE(review): presumably mirrors a preview DirectML header -- confirm.
enum DML_PREVIEW_OPERATOR_TYPE
{
DML_PREVIEW_OPERATOR_FIRST = 0xC0000000,
};
// Preview graph-node type values. ConvertGraphDesc casts
// DML_GRAPH_NODE_TYPE_CONSTANT_DATA to DML_GRAPH_NODE_TYPE to tag a
// DML_GRAPH_NODE_DESC whose payload is a
// DML_CONSTANT_DATA_GRAPH_NODE_DESC_PREVIEW.
enum DML_GRAPH_NODE_TYPE_PREVIEW
{
DML_GRAPH_NODE_TYPE_CONSTANT_DATA = 0xCC000000,
};
// Description of a constant-data graph node: a block of bytes plus an optional name.
// The struct only stores raw pointers. ConvertGraphDesc fills them from storage held
// by the GraphDesc nodes (the node's byte vector and name string), so that storage
// must stay alive and unmodified for as long as this description is in use.
struct DML_CONSTANT_DATA_GRAPH_NODE_DESC_PREVIEW
{
const BYTE* data; // Start of the constant bytes.
UINT64 dataSize; // Length of `data`, in bytes.
_Field_z_ _Maybenull_ const char* Name; // Null-terminated node name; may be null.
};
void ConvertGraphDesc(
const Dml::GraphDescBuilder::GraphDesc& graphDesc,
_Out_ DML_GRAPH_DESC& dmlGraphDesc,
const uint32_t inputCount,
const uint32_t outputCount,
_Inout_ std::vector<DML_OPERATOR_GRAPH_NODE_DESC>& dmlOperatorGraphNodes,
_Inout_ std::vector<DML_CONSTANT_DATA_GRAPH_NODE_DESC_PREVIEW>& dmlConstantGraphNodes,
_Inout_ std::vector<DML_GRAPH_NODE_DESC>& dmlGraphNodes,
_Inout_ std::vector<DML_GRAPH_EDGE_DESC>& dmlInputEdges,
_Inout_ std::vector<DML_GRAPH_EDGE_DESC>& dmlOutputEdges,
@ -293,8 +311,22 @@ namespace DmlGraphFusionHelper
for (size_t i = 0; i < graphDesc.nodes.size(); ++i)
{
auto& nodeInfo = graphDesc.nodes[i];
dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{nodeInfo.op.Get(), nodeInfo.name.data()};
dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_OPERATOR, &dmlOperatorGraphNodes[i]};
if (std::holds_alternative<Microsoft::WRL::ComPtr<IDMLOperator>>(nodeInfo.nodeDef))
{
dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{std::get<Microsoft::WRL::ComPtr<IDMLOperator>>(nodeInfo.nodeDef).Get(), nodeInfo.name.data()};
dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_OPERATOR, &dmlOperatorGraphNodes[i]};
}
else
{
dmlConstantGraphNodes[i] = DML_CONSTANT_DATA_GRAPH_NODE_DESC_PREVIEW{
std::get<std::vector<uint8_t>>(nodeInfo.nodeDef).data(),
std::get<std::vector<uint8_t>>(nodeInfo.nodeDef).size(),
nodeInfo.name.data()
};
dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{(DML_GRAPH_NODE_TYPE) DML_GRAPH_NODE_TYPE_CONSTANT_DATA, &dmlConstantGraphNodes[i]};
}
}
for (size_t i = 0; i < graphDesc.inputEdges.size(); ++i)
@ -360,6 +392,8 @@ namespace DmlGraphFusionHelper
// convert DML EP GraphDesc into DML_GRAPH_DESC and create IDMLCompiledOperator
DML_GRAPH_DESC dmlGraphDesc = {};
std::vector<DML_OPERATOR_GRAPH_NODE_DESC> dmlOperatorGraphNodes(graphDesc.nodes.size());
std::vector<DML_CONSTANT_DATA_GRAPH_NODE_DESC_PREVIEW> dmlConstantGraphNodes(graphDesc.nodes.size());
std::vector<DML_GRAPH_NODE_DESC> dmlGraphNodes(graphDesc.nodes.size());
std::vector<DML_GRAPH_EDGE_DESC> dmlInputEdges(graphDesc.inputEdges.size());
std::vector<DML_GRAPH_EDGE_DESC> dmlOutputEdges(graphDesc.outputEdges.size());
@ -370,6 +404,7 @@ namespace DmlGraphFusionHelper
fusedNodeInputCount,
fusedNodeOutputCount,
dmlOperatorGraphNodes,
dmlConstantGraphNodes,
dmlGraphNodes,
dmlInputEdges,
dmlOutputEdges,

View file

@ -264,6 +264,7 @@ namespace Dml::GraphDescBuilder
std::unordered_map<uint32_t, uint32_t> operatorGraphNodeIndexToMainGraphNodeIndexMap;
uint32_t graphNodeCount = gsl::narrow_cast<uint32_t>(graphNodes.size());
const bool isNodeAsOpDesc = graphNodeCreateInfo.nodesAsOperatorDesc.size() > 0;
size_t firstOpDescGraphNode = graphNodes.size();
if (isNodeAsOpDesc)
{
@ -273,6 +274,8 @@ namespace Dml::GraphDescBuilder
ORT_THROW_HR_IF(E_UNEXPECTED, !graphNodeCreateInfo.nodesAsOperatorDesc[nodeIndex]);
operatorGraphNodeIndexToMainGraphNodeIndexMap.emplace(nodeIndex, graphNodeCount++);
}
graphNodes.resize(graphNodes.size() + graphNodeCreateInfo.nodeCount);
}
else
{
@ -281,7 +284,7 @@ namespace Dml::GraphDescBuilder
ORT_THROW_HR_IF(E_UNEXPECTED, !graphNodeCreateInfo.nodesAsIDMLOperator[nodeIndex].Get());
operatorGraphNodeIndexToMainGraphNodeIndexMap.emplace(nodeIndex, graphNodeCount++);
NodeInfo nodeInfo = {};
nodeInfo.op = std::move(graphNodeCreateInfo.nodesAsIDMLOperator[nodeIndex]);
nodeInfo.nodeDef = std::move(graphNodeCreateInfo.nodesAsIDMLOperator[nodeIndex]);
graphNodes.push_back(std::move(nodeInfo));
}
}
@ -303,21 +306,58 @@ namespace Dml::GraphDescBuilder
const uint32_t dmlFusedNodeInputIndex = iter->second;
DML_INPUT_GRAPH_EDGE_DESC edge = {};
edge.GraphInputIndex = dmlFusedNodeInputIndex;
edge.ToNodeIndex = mainGraphNodeIndex;
edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex; // ?? might need to point inputIndex
graphInputEdges.push_back(edge);
// If this is a constant input, set the appropriate flags on the desc
if (isNodeAsOpDesc &&
dmlFusedNodeInputIndex < isConstGpuGraphInputCount &&
isConstGpuGraphInput[dmlFusedNodeInputIndex])
{
auto& graphInputNode = graphNodeCreateInfo.nodesAsOperatorDesc[operatorGraphInputEdge.ToNodeIndex];
std::vector<DmlBufferTensorDesc*> toNodeInputTensorDescs = graphInputNode->GetInputTensors();
DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorGraphInputEdge.ToNodeInputIndex];
tensorDesc->flags |= DML_TENSOR_FLAG_OWNED_BY_DML;
// This is a highly inefficient approach to generating constant nodes. It duplicates constant data
// across the graph input as well as every consumer's unique constant node. However it is currently
// only used for small inputs.
uint32_t c_maxConstNodeDataSize = 64;
ComPtr<OnnxTensorWrapper> constantInput = constantCpuGraphInputGetter(arg->Name());
if (constantInput && constantInput->GetTensorByteSize() < c_maxConstNodeDataSize)
{
std::vector<uint8_t> tensorData;
tensorData.insert(
tensorData.begin(),
static_cast<const uint8_t*>(constantInput->GetData()),
static_cast<const uint8_t*>(constantInput->GetData()) + constantInput->GetTensorByteSize());
NodeInfo nodeInfo = {};
nodeInfo.nodeDef = std::move(tensorData);
graphNodes.push_back(std::move(nodeInfo));
DML_INTERMEDIATE_GRAPH_EDGE_DESC edge = {};
edge.FromNodeIndex = static_cast<UINT>(graphNodes.size() - 1);
edge.FromNodeOutputIndex = 0;
edge.ToNodeIndex = mainGraphNodeIndex;
edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex;
graphIntermediateEdges.push_back(edge);
}
else
{
DML_INPUT_GRAPH_EDGE_DESC edge = {};
edge.GraphInputIndex = dmlFusedNodeInputIndex;
edge.ToNodeIndex = mainGraphNodeIndex;
edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex;
graphInputEdges.push_back(edge);
auto& graphInputNode = graphNodeCreateInfo.nodesAsOperatorDesc[operatorGraphInputEdge.ToNodeIndex];
std::vector<DmlBufferTensorDesc*> toNodeInputTensorDescs = graphInputNode->GetInputTensors();
DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorGraphInputEdge.ToNodeInputIndex];
tensorDesc->flags |= DML_TENSOR_FLAG_OWNED_BY_DML;
}
}
else
{
DML_INPUT_GRAPH_EDGE_DESC edge = {};
edge.GraphInputIndex = dmlFusedNodeInputIndex;
edge.ToNodeIndex = mainGraphNodeIndex;
edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex;
graphInputEdges.push_back(edge);
}
}
else
@ -360,17 +400,19 @@ namespace Dml::GraphDescBuilder
if (isNodeAsOpDesc)
{
for (auto& opDesc : graphNodeCreateInfo.nodesAsOperatorDesc)
for (uint32_t i = 0; i < graphNodeCreateInfo.nodesAsOperatorDesc.size(); ++i)
{
auto& opDesc = graphNodeCreateInfo.nodesAsOperatorDesc[i];
DML_OPERATOR_DESC dmlDesc = SchemaHelpers::ConvertOperatorDesc(*opDesc, &allocator);
ComPtr<IDMLOperator> op;
ORT_THROW_IF_FAILED(device->CreateOperator(&dmlDesc, IID_PPV_ARGS(&op)));
allocator.Reset();
NodeInfo nodeInfo = {};
nodeInfo.op = std::move(op);
nodeInfo.nodeDef = std::move(op);
nodeInfo.name = node.Name();
graphNodes.push_back(std::move(nodeInfo));
graphNodes[firstOpDescGraphNode + i] = std::move(nodeInfo);
}
}
}

View file

@ -27,7 +27,7 @@ namespace Dml
// One node of the graph description assembled by GraphDescBuilder.
struct NodeInfo
{
// DML operator for this node.
// NOTE(review): nodeDef below can also hold an IDMLOperator; confirm which
// member consumers are expected to read.
Microsoft::WRL::ComPtr<IDMLOperator> op;
// Node definition: either a DML operator, or the raw bytes of a constant-data node.
std::variant<Microsoft::WRL::ComPtr<IDMLOperator>, std::vector<uint8_t>> nodeDef;
// Node name. ConvertGraphDesc passes name.data() to the DML node descriptions,
// so the string must outlive them.
std::string name;
};

View file

@ -339,6 +339,8 @@ class OnnxTensorWrapper : public WRL::Base<IMLOperatorTensor>, public Closable
// Both overloads always return nullptr.
// NOTE(review): presumably this wrapper has no underlying onnxruntime::Tensor
// to expose -- confirm against callers.
const onnxruntime::Tensor* GetInterface() const { return nullptr; }
onnxruntime::Tensor* GetInterface() { return nullptr; }
// Returns the stored tensor size in bytes (m_tensorByteSize).
size_t GetTensorByteSize() const { return m_tensorByteSize; }
private:
size_t m_tensorByteSize = 0;
std::unique_ptr<std::byte[]> m_unpackedTensor;