From 4c8ef4080d5d6c7c8def4df473e5a9ba0949e597 Mon Sep 17 00:00:00 2001 From: Jeff Bloomfield Date: Fri, 11 Aug 2023 15:51:35 -0700 Subject: [PATCH] Utilize DML constant input graph node --- .../src/DmlGraphFusionHelper.cpp | 39 ++++++++++- .../src/GraphDescBuilder.cpp | 70 +++++++++++++++---- .../src/GraphDescBuilder.h | 2 +- .../src/MLOperatorAuthorImpl.h | 2 + 4 files changed, 96 insertions(+), 17 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp index 17aa197396..a7c2fcf178 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp @@ -279,12 +279,30 @@ namespace DmlGraphFusionHelper return initializerPartitionMap; } + enum DML_PREVIEW_OPERATOR_TYPE + { + DML_PREVIEW_OPERATOR_FIRST = 0xC0000000, + }; + + enum DML_GRAPH_NODE_TYPE_PREVIEW + { + DML_GRAPH_NODE_TYPE_CONSTANT_DATA = 0xCC000000, + }; + + struct DML_CONSTANT_DATA_GRAPH_NODE_DESC_PREVIEW + { + const BYTE* data; + UINT64 dataSize; + _Field_z_ _Maybenull_ const char* Name; + }; + void ConvertGraphDesc( const Dml::GraphDescBuilder::GraphDesc& graphDesc, _Out_ DML_GRAPH_DESC& dmlGraphDesc, const uint32_t inputCount, const uint32_t outputCount, _Inout_ std::vector& dmlOperatorGraphNodes, + _Inout_ std::vector& dmlConstantGraphNodes, _Inout_ std::vector& dmlGraphNodes, _Inout_ std::vector& dmlInputEdges, _Inout_ std::vector& dmlOutputEdges, @@ -293,8 +311,22 @@ namespace DmlGraphFusionHelper for (size_t i = 0; i < graphDesc.nodes.size(); ++i) { auto& nodeInfo = graphDesc.nodes[i]; - dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{nodeInfo.op.Get(), nodeInfo.name.data()}; - dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_OPERATOR, &dmlOperatorGraphNodes[i]}; + + if (std::holds_alternative>(nodeInfo.nodeDef)) + { + dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{std::get>(nodeInfo.nodeDef).Get(), nodeInfo.name.data()}; + dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_OPERATOR, &dmlOperatorGraphNodes[i]}; + } + else + { + dmlConstantGraphNodes[i] = DML_CONSTANT_DATA_GRAPH_NODE_DESC_PREVIEW{ + std::get>(nodeInfo.nodeDef).data(), + std::get>(nodeInfo.nodeDef).size(), + nodeInfo.name.data() + }; + + dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{(DML_GRAPH_NODE_TYPE) DML_GRAPH_NODE_TYPE_CONSTANT_DATA, &dmlConstantGraphNodes[i]}; + } } for (size_t i = 0; i < graphDesc.inputEdges.size(); ++i) @@ -360,6 +392,8 @@ namespace DmlGraphFusionHelper // convert DML EP GraphDesc into DML_GRAPH_DESC and create IDMLCompiledOperator DML_GRAPH_DESC dmlGraphDesc = {}; std::vector dmlOperatorGraphNodes(graphDesc.nodes.size()); + std::vector dmlConstantGraphNodes(graphDesc.nodes.size()); + std::vector dmlGraphNodes(graphDesc.nodes.size()); std::vector dmlInputEdges(graphDesc.inputEdges.size()); std::vector dmlOutputEdges(graphDesc.outputEdges.size()); @@ -370,6 +404,7 @@ namespace DmlGraphFusionHelper fusedNodeInputCount, fusedNodeOutputCount, dmlOperatorGraphNodes, + dmlConstantGraphNodes, dmlGraphNodes, dmlInputEdges, dmlOutputEdges, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp index 7b95ae9620..dd0721df72 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp @@ -264,6 +264,7 @@ namespace Dml::GraphDescBuilder std::unordered_map operatorGraphNodeIndexToMainGraphNodeIndexMap; uint32_t graphNodeCount = gsl::narrow_cast(graphNodes.size()); const bool isNodeAsOpDesc = graphNodeCreateInfo.nodesAsOperatorDesc.size() > 0; + size_t firstOpDescGraphNode = graphNodes.size(); if (isNodeAsOpDesc) { @@ -273,6 +274,8 @@ namespace Dml::GraphDescBuilder ORT_THROW_HR_IF(E_UNEXPECTED, !graphNodeCreateInfo.nodesAsOperatorDesc[nodeIndex]); operatorGraphNodeIndexToMainGraphNodeIndexMap.emplace(nodeIndex, graphNodeCount++); } + + graphNodes.resize(graphNodes.size() + graphNodeCreateInfo.nodeCount); } else { @@ -281,7 +284,7 @@ namespace Dml::GraphDescBuilder ORT_THROW_HR_IF(E_UNEXPECTED, !graphNodeCreateInfo.nodesAsIDMLOperator[nodeIndex].Get()); operatorGraphNodeIndexToMainGraphNodeIndexMap.emplace(nodeIndex, graphNodeCount++); NodeInfo nodeInfo = {}; - nodeInfo.op = std::move(graphNodeCreateInfo.nodesAsIDMLOperator[nodeIndex]); + nodeInfo.nodeDef = std::move(graphNodeCreateInfo.nodesAsIDMLOperator[nodeIndex]); graphNodes.push_back(std::move(nodeInfo)); } } @@ -303,21 +306,58 @@ namespace Dml::GraphDescBuilder const uint32_t dmlFusedNodeInputIndex = iter->second; - DML_INPUT_GRAPH_EDGE_DESC edge = {}; - edge.GraphInputIndex = dmlFusedNodeInputIndex; - edge.ToNodeIndex = mainGraphNodeIndex; - edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex; // ?? might need to point inputIndex - graphInputEdges.push_back(edge); - // If this is a constant input, set the appropriate flags on the desc if (isNodeAsOpDesc && dmlFusedNodeInputIndex < isConstGpuGraphInputCount && isConstGpuGraphInput[dmlFusedNodeInputIndex]) { - auto& graphInputNode = graphNodeCreateInfo.nodesAsOperatorDesc[operatorGraphInputEdge.ToNodeIndex]; - std::vector toNodeInputTensorDescs = graphInputNode->GetInputTensors(); - DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorGraphInputEdge.ToNodeInputIndex]; - tensorDesc->flags |= DML_TENSOR_FLAG_OWNED_BY_DML; + // This is a highly inefficient approach to generating constant nodes. It duplicates constant data + // across the graph input as well as every consumer's unique constant node. However it is currently + // only used for small inputs. + uint32_t c_maxConstNodeDataSize = 64; + + ComPtr constantInput = constantCpuGraphInputGetter(arg->Name()); + + if (constantInput && constantInput->GetTensorByteSize() < c_maxConstNodeDataSize) + { + std::vector tensorData; + tensorData.insert( + tensorData.begin(), + static_cast(constantInput->GetData()), + static_cast(constantInput->GetData()) + constantInput->GetTensorByteSize()); + + NodeInfo nodeInfo = {}; + nodeInfo.nodeDef = std::move(tensorData); + graphNodes.push_back(std::move(nodeInfo)); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC edge = {}; + edge.FromNodeIndex = static_cast(graphNodes.size() - 1); + edge.FromNodeOutputIndex = 0; + edge.ToNodeIndex = mainGraphNodeIndex; + edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex; + graphIntermediateEdges.push_back(edge); + } + else + { + DML_INPUT_GRAPH_EDGE_DESC edge = {}; + edge.GraphInputIndex = dmlFusedNodeInputIndex; + edge.ToNodeIndex = mainGraphNodeIndex; + edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex; + graphInputEdges.push_back(edge); + + auto& graphInputNode = graphNodeCreateInfo.nodesAsOperatorDesc[operatorGraphInputEdge.ToNodeIndex]; + std::vector toNodeInputTensorDescs = graphInputNode->GetInputTensors(); + DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorGraphInputEdge.ToNodeInputIndex]; + tensorDesc->flags |= DML_TENSOR_FLAG_OWNED_BY_DML; + } + } + else + { + DML_INPUT_GRAPH_EDGE_DESC edge = {}; + edge.GraphInputIndex = dmlFusedNodeInputIndex; + edge.ToNodeIndex = mainGraphNodeIndex; + edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex; + graphInputEdges.push_back(edge); } } else @@ -360,17 +400,19 @@ namespace Dml::GraphDescBuilder if (isNodeAsOpDesc) { - for (auto& opDesc : graphNodeCreateInfo.nodesAsOperatorDesc) + for (uint32_t i = 0; i < graphNodeCreateInfo.nodesAsOperatorDesc.size(); ++i) { + auto& opDesc = graphNodeCreateInfo.nodesAsOperatorDesc[i]; + DML_OPERATOR_DESC dmlDesc = SchemaHelpers::ConvertOperatorDesc(*opDesc, &allocator); ComPtr op; ORT_THROW_IF_FAILED(device->CreateOperator(&dmlDesc, IID_PPV_ARGS(&op))); allocator.Reset(); NodeInfo nodeInfo = {}; - nodeInfo.op = std::move(op); + nodeInfo.nodeDef = std::move(op); nodeInfo.name = node.Name(); - graphNodes.push_back(std::move(nodeInfo)); + graphNodes[firstOpDescGraphNode + i] = std::move(nodeInfo); } } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h index 5c04962e55..5473e590dc 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h @@ -27,7 +27,7 @@ namespace Dml struct NodeInfo { - Microsoft::WRL::ComPtr op; + std::variant, std::vector> nodeDef; std::string name; }; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h index e7cc1e7356..a4caf89a33 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h @@ -339,6 +339,8 @@ class OnnxTensorWrapper : public WRL::Base, public Closable const onnxruntime::Tensor* GetInterface() const { return nullptr; } onnxruntime::Tensor* GetInterface() { return nullptr; } + size_t GetTensorByteSize() const { return m_tensorByteSize; } + private: size_t m_tensorByteSize = 0; std::unique_ptr m_unpackedTensor;