mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-01 23:30:35 +00:00
Utilize DML constant input graph node
This commit is contained in:
parent
4dfeef48eb
commit
4c8ef4080d
4 changed files with 96 additions and 17 deletions
|
|
@ -279,12 +279,30 @@ namespace DmlGraphFusionHelper
|
|||
return initializerPartitionMap;
|
||||
}
|
||||
|
||||
// Preview operator-type enumeration declared locally in DmlGraphFusionHelper.
// NOTE(review): presumably mirrors a DirectML preview-header definition in which
// this value marks the start of the range reserved for preview operators -
// confirm against the DirectML headers in use. Nothing in the visible code
// reads this enumerator.
enum DML_PREVIEW_OPERATOR_TYPE
{
    DML_PREVIEW_OPERATOR_FIRST = 0xC0000000,
};
|
||||
|
||||
// Preview graph-node type tag declared locally in DmlGraphFusionHelper.
// ConvertGraphDesc casts DML_GRAPH_NODE_TYPE_CONSTANT_DATA to DML_GRAPH_NODE_TYPE
// when it fills a DML_GRAPH_NODE_DESC whose payload is a
// DML_CONSTANT_DATA_GRAPH_NODE_DESC_PREVIEW rather than an operator node desc.
// NOTE(review): the numeric value presumably has to match what the DirectML
// runtime in use expects for constant-data nodes - confirm against its headers.
enum DML_GRAPH_NODE_TYPE_PREVIEW
{
    DML_GRAPH_NODE_TYPE_CONSTANT_DATA = 0xCC000000,
};
|
||||
|
||||
struct DML_CONSTANT_DATA_GRAPH_NODE_DESC_PREVIEW
|
||||
{
|
||||
const BYTE* data;
|
||||
UINT64 dataSize;
|
||||
_Field_z_ _Maybenull_ const char* Name;
|
||||
};
|
||||
|
||||
void ConvertGraphDesc(
|
||||
const Dml::GraphDescBuilder::GraphDesc& graphDesc,
|
||||
_Out_ DML_GRAPH_DESC& dmlGraphDesc,
|
||||
const uint32_t inputCount,
|
||||
const uint32_t outputCount,
|
||||
_Inout_ std::vector<DML_OPERATOR_GRAPH_NODE_DESC>& dmlOperatorGraphNodes,
|
||||
_Inout_ std::vector<DML_CONSTANT_DATA_GRAPH_NODE_DESC_PREVIEW>& dmlConstantGraphNodes,
|
||||
_Inout_ std::vector<DML_GRAPH_NODE_DESC>& dmlGraphNodes,
|
||||
_Inout_ std::vector<DML_GRAPH_EDGE_DESC>& dmlInputEdges,
|
||||
_Inout_ std::vector<DML_GRAPH_EDGE_DESC>& dmlOutputEdges,
|
||||
|
|
@ -293,8 +311,22 @@ namespace DmlGraphFusionHelper
|
|||
for (size_t i = 0; i < graphDesc.nodes.size(); ++i)
|
||||
{
|
||||
auto& nodeInfo = graphDesc.nodes[i];
|
||||
dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{nodeInfo.op.Get(), nodeInfo.name.data()};
|
||||
dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_OPERATOR, &dmlOperatorGraphNodes[i]};
|
||||
|
||||
if (std::holds_alternative<Microsoft::WRL::ComPtr<IDMLOperator>>(nodeInfo.nodeDef))
|
||||
{
|
||||
dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{std::get<Microsoft::WRL::ComPtr<IDMLOperator>>(nodeInfo.nodeDef).Get(), nodeInfo.name.data()};
|
||||
dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_OPERATOR, &dmlOperatorGraphNodes[i]};
|
||||
}
|
||||
else
|
||||
{
|
||||
dmlConstantGraphNodes[i] = DML_CONSTANT_DATA_GRAPH_NODE_DESC_PREVIEW{
|
||||
std::get<std::vector<uint8_t>>(nodeInfo.nodeDef).data(),
|
||||
std::get<std::vector<uint8_t>>(nodeInfo.nodeDef).size(),
|
||||
nodeInfo.name.data()
|
||||
};
|
||||
|
||||
dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{(DML_GRAPH_NODE_TYPE) DML_GRAPH_NODE_TYPE_CONSTANT_DATA, &dmlConstantGraphNodes[i]};
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < graphDesc.inputEdges.size(); ++i)
|
||||
|
|
@ -360,6 +392,8 @@ namespace DmlGraphFusionHelper
|
|||
// convert DML EP GraphDesc into DML_GRAPH_DESC and create IDMLCompiledOperator
|
||||
DML_GRAPH_DESC dmlGraphDesc = {};
|
||||
std::vector<DML_OPERATOR_GRAPH_NODE_DESC> dmlOperatorGraphNodes(graphDesc.nodes.size());
|
||||
std::vector<DML_CONSTANT_DATA_GRAPH_NODE_DESC_PREVIEW> dmlConstantGraphNodes(graphDesc.nodes.size());
|
||||
|
||||
std::vector<DML_GRAPH_NODE_DESC> dmlGraphNodes(graphDesc.nodes.size());
|
||||
std::vector<DML_GRAPH_EDGE_DESC> dmlInputEdges(graphDesc.inputEdges.size());
|
||||
std::vector<DML_GRAPH_EDGE_DESC> dmlOutputEdges(graphDesc.outputEdges.size());
|
||||
|
|
@ -370,6 +404,7 @@ namespace DmlGraphFusionHelper
|
|||
fusedNodeInputCount,
|
||||
fusedNodeOutputCount,
|
||||
dmlOperatorGraphNodes,
|
||||
dmlConstantGraphNodes,
|
||||
dmlGraphNodes,
|
||||
dmlInputEdges,
|
||||
dmlOutputEdges,
|
||||
|
|
|
|||
|
|
@ -264,6 +264,7 @@ namespace Dml::GraphDescBuilder
|
|||
std::unordered_map<uint32_t, uint32_t> operatorGraphNodeIndexToMainGraphNodeIndexMap;
|
||||
uint32_t graphNodeCount = gsl::narrow_cast<uint32_t>(graphNodes.size());
|
||||
const bool isNodeAsOpDesc = graphNodeCreateInfo.nodesAsOperatorDesc.size() > 0;
|
||||
size_t firstOpDescGraphNode = graphNodes.size();
|
||||
|
||||
if (isNodeAsOpDesc)
|
||||
{
|
||||
|
|
@ -273,6 +274,8 @@ namespace Dml::GraphDescBuilder
|
|||
ORT_THROW_HR_IF(E_UNEXPECTED, !graphNodeCreateInfo.nodesAsOperatorDesc[nodeIndex]);
|
||||
operatorGraphNodeIndexToMainGraphNodeIndexMap.emplace(nodeIndex, graphNodeCount++);
|
||||
}
|
||||
|
||||
graphNodes.resize(graphNodes.size() + graphNodeCreateInfo.nodeCount);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
|
@ -281,7 +284,7 @@ namespace Dml::GraphDescBuilder
|
|||
ORT_THROW_HR_IF(E_UNEXPECTED, !graphNodeCreateInfo.nodesAsIDMLOperator[nodeIndex].Get());
|
||||
operatorGraphNodeIndexToMainGraphNodeIndexMap.emplace(nodeIndex, graphNodeCount++);
|
||||
NodeInfo nodeInfo = {};
|
||||
nodeInfo.op = std::move(graphNodeCreateInfo.nodesAsIDMLOperator[nodeIndex]);
|
||||
nodeInfo.nodeDef = std::move(graphNodeCreateInfo.nodesAsIDMLOperator[nodeIndex]);
|
||||
graphNodes.push_back(std::move(nodeInfo));
|
||||
}
|
||||
}
|
||||
|
|
@ -303,21 +306,58 @@ namespace Dml::GraphDescBuilder
|
|||
|
||||
const uint32_t dmlFusedNodeInputIndex = iter->second;
|
||||
|
||||
DML_INPUT_GRAPH_EDGE_DESC edge = {};
|
||||
edge.GraphInputIndex = dmlFusedNodeInputIndex;
|
||||
edge.ToNodeIndex = mainGraphNodeIndex;
|
||||
edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex; // ?? might need to point inputIndex
|
||||
graphInputEdges.push_back(edge);
|
||||
|
||||
// If this is a constant input, set the appropriate flags on the desc
|
||||
if (isNodeAsOpDesc &&
|
||||
dmlFusedNodeInputIndex < isConstGpuGraphInputCount &&
|
||||
isConstGpuGraphInput[dmlFusedNodeInputIndex])
|
||||
{
|
||||
auto& graphInputNode = graphNodeCreateInfo.nodesAsOperatorDesc[operatorGraphInputEdge.ToNodeIndex];
|
||||
std::vector<DmlBufferTensorDesc*> toNodeInputTensorDescs = graphInputNode->GetInputTensors();
|
||||
DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorGraphInputEdge.ToNodeInputIndex];
|
||||
tensorDesc->flags |= DML_TENSOR_FLAG_OWNED_BY_DML;
|
||||
// This is a highly inefficient approach to generating constant nodes. It duplicates constant data
|
||||
// across the graph input as well as every consumer's unique constant node. However it is currently
|
||||
// only used for small inputs.
|
||||
uint32_t c_maxConstNodeDataSize = 64;
|
||||
|
||||
ComPtr<OnnxTensorWrapper> constantInput = constantCpuGraphInputGetter(arg->Name());
|
||||
|
||||
if (constantInput && constantInput->GetTensorByteSize() < c_maxConstNodeDataSize)
|
||||
{
|
||||
std::vector<uint8_t> tensorData;
|
||||
tensorData.insert(
|
||||
tensorData.begin(),
|
||||
static_cast<const uint8_t*>(constantInput->GetData()),
|
||||
static_cast<const uint8_t*>(constantInput->GetData()) + constantInput->GetTensorByteSize());
|
||||
|
||||
NodeInfo nodeInfo = {};
|
||||
nodeInfo.nodeDef = std::move(tensorData);
|
||||
graphNodes.push_back(std::move(nodeInfo));
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC edge = {};
|
||||
edge.FromNodeIndex = static_cast<UINT>(graphNodes.size() - 1);
|
||||
edge.FromNodeOutputIndex = 0;
|
||||
edge.ToNodeIndex = mainGraphNodeIndex;
|
||||
edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex;
|
||||
graphIntermediateEdges.push_back(edge);
|
||||
}
|
||||
else
|
||||
{
|
||||
DML_INPUT_GRAPH_EDGE_DESC edge = {};
|
||||
edge.GraphInputIndex = dmlFusedNodeInputIndex;
|
||||
edge.ToNodeIndex = mainGraphNodeIndex;
|
||||
edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex;
|
||||
graphInputEdges.push_back(edge);
|
||||
|
||||
auto& graphInputNode = graphNodeCreateInfo.nodesAsOperatorDesc[operatorGraphInputEdge.ToNodeIndex];
|
||||
std::vector<DmlBufferTensorDesc*> toNodeInputTensorDescs = graphInputNode->GetInputTensors();
|
||||
DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorGraphInputEdge.ToNodeInputIndex];
|
||||
tensorDesc->flags |= DML_TENSOR_FLAG_OWNED_BY_DML;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
DML_INPUT_GRAPH_EDGE_DESC edge = {};
|
||||
edge.GraphInputIndex = dmlFusedNodeInputIndex;
|
||||
edge.ToNodeIndex = mainGraphNodeIndex;
|
||||
edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex;
|
||||
graphInputEdges.push_back(edge);
|
||||
}
|
||||
}
|
||||
else
|
||||
|
|
@ -360,17 +400,19 @@ namespace Dml::GraphDescBuilder
|
|||
|
||||
if (isNodeAsOpDesc)
|
||||
{
|
||||
for (auto& opDesc : graphNodeCreateInfo.nodesAsOperatorDesc)
|
||||
for (uint32_t i = 0; i < graphNodeCreateInfo.nodesAsOperatorDesc.size(); ++i)
|
||||
{
|
||||
auto& opDesc = graphNodeCreateInfo.nodesAsOperatorDesc[i];
|
||||
|
||||
DML_OPERATOR_DESC dmlDesc = SchemaHelpers::ConvertOperatorDesc(*opDesc, &allocator);
|
||||
ComPtr<IDMLOperator> op;
|
||||
ORT_THROW_IF_FAILED(device->CreateOperator(&dmlDesc, IID_PPV_ARGS(&op)));
|
||||
allocator.Reset();
|
||||
|
||||
NodeInfo nodeInfo = {};
|
||||
nodeInfo.op = std::move(op);
|
||||
nodeInfo.nodeDef = std::move(op);
|
||||
nodeInfo.name = node.Name();
|
||||
graphNodes.push_back(std::move(nodeInfo));
|
||||
graphNodes[firstOpDescGraphNode + i] = std::move(nodeInfo);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ namespace Dml
|
|||
|
||||
struct NodeInfo
|
||||
{
|
||||
Microsoft::WRL::ComPtr<IDMLOperator> op;
|
||||
std::variant<Microsoft::WRL::ComPtr<IDMLOperator>, std::vector<uint8_t>> nodeDef;
|
||||
std::string name;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -339,6 +339,8 @@ class OnnxTensorWrapper : public WRL::Base<IMLOperatorTensor>, public Closable
|
|||
const onnxruntime::Tensor* GetInterface() const { return nullptr; }
|
||||
onnxruntime::Tensor* GetInterface() { return nullptr; }
|
||||
|
||||
size_t GetTensorByteSize() const { return m_tensorByteSize; }
|
||||
|
||||
private:
|
||||
size_t m_tensorByteSize = 0;
|
||||
std::unique_ptr<std::byte[]> m_unpackedTensor;
|
||||
|
|
|
|||
Loading…
Reference in a new issue