From f664f912980ac7adce2cf3ecaade6a70cfb97401 Mon Sep 17 00:00:00 2001 From: Sumit Agarwal Date: Thu, 18 Apr 2024 11:23:13 -0700 Subject: [PATCH] [DML EP] Expose NPU macro via build command (#20306) ### Description This fixes following things: - Expose `ENABLE_NPU_ADAPTER_ENUMERATION` macro via build command, so that a user can enable NPU support for DML EP seamlessly. - Add keyword `_dmlEp_` as part of the node name, which would be useful for debugging purpose. ### Motivation and Context --- cmake/CMakeLists.txt | 3 +++ .../DmlExecutionProvider/src/GraphDescBuilder.cpp | 12 ++++++------ 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 74edaad762..a5cadc937e 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -848,6 +848,9 @@ if (onnxruntime_USE_DML) list(APPEND ORT_PROVIDER_FLAGS -DUSE_DML=1) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_DML=1) list(APPEND ONNXRUNTIME_PROVIDER_NAMES dml) + if(onnxruntime_ENABLE_NPU_ADAPTER_ENUMERATION) + list(APPEND ORT_PROVIDER_FLAGS -DENABLE_NPU_ADAPTER_ENUMERATION=1) + endif() endif() if (onnxruntime_USE_MIGRAPHX) list(APPEND ORT_PROVIDER_FLAGS -DUSE_MIGRAPHX=1) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp index 3b0dbd5425..1c82c40a88 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp @@ -194,7 +194,7 @@ namespace Dml::GraphDescBuilder uint32_t SetAndGetDmlGraphNodeIndex( const uint32_t operatorDmlGraphNodeIndex, - const std::string& nodeNamePrefix, + const onnxruntime::Node& node, AbstractOperatorDesc& operatorDesc, /*in_out*/std::unordered_map& operatorDmlGraphToDmlGraphNodeIndexMap, /*in_out*/std::vector& dmlGraphNodes) @@ -205,7 +205,7 @@ namespace Dml::GraphDescBuilder return iter->second; } operatorDmlGraphToDmlGraphNodeIndexMap[operatorDmlGraphNodeIndex] = static_cast(dmlGraphNodes.size()); - dmlGraphNodes.push_back({operatorDesc, nodeNamePrefix + std::to_string(operatorDmlGraphNodeIndex)}); + dmlGraphNodes.push_back({operatorDesc, GetUniqueNodeName(node) + "_dmlEp_" + std::to_string(operatorDmlGraphNodeIndex)}); return operatorDmlGraphToDmlGraphNodeIndexMap[operatorDmlGraphNodeIndex]; } @@ -432,7 +432,7 @@ namespace Dml::GraphDescBuilder { uint32_t dmlGraphNodeIndex = SetAndGetDmlGraphNodeIndex( operatorDmlGraphInputEdge.ToNodeIndex, - node.Name(), + node, *operatorDmlGraphCreateInfo.nodes[operatorDmlGraphInputEdge.ToNodeIndex], operatorDmlGraphToDmlGraphNodeIndexMap, dmlGraphNodes); @@ -508,13 +508,13 @@ namespace Dml::GraphDescBuilder DmlIntermediateSerializedGraphEdge edge = {}; uint32_t shiftedFromNodeIndex = SetAndGetDmlGraphNodeIndex( operatorGraphIntermediateEdge.FromNodeIndex, - node.Name(), + node, *operatorDmlGraphCreateInfo.nodes[operatorGraphIntermediateEdge.FromNodeIndex], operatorDmlGraphToDmlGraphNodeIndexMap, dmlGraphNodes); uint32_t shiftedToNodeIndex = SetAndGetDmlGraphNodeIndex( operatorGraphIntermediateEdge.ToNodeIndex, - node.Name(), + node, *operatorDmlGraphCreateInfo.nodes[operatorGraphIntermediateEdge.ToNodeIndex], operatorDmlGraphToDmlGraphNodeIndexMap, dmlGraphNodes); @@ -535,7 +535,7 @@ namespace Dml::GraphDescBuilder { uint32_t shiftedNodeIndex = SetAndGetDmlGraphNodeIndex( operatorGraphOutputEdge.FromNodeIndex, - node.Name(), + node, *operatorDmlGraphCreateInfo.nodes[operatorGraphOutputEdge.FromNodeIndex], operatorDmlGraphToDmlGraphNodeIndexMap, dmlGraphNodes);