diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp index 808a0b259a..3c6d1bcf4a 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp @@ -19,7 +19,9 @@ namespace Dml FusedGraphKernel( const onnxruntime::OpKernelInfo& kernelInfo, const std::unordered_map &graphNodePropertyMap, - std::unordered_map& transferredInitializerMap) : OpKernel(kernelInfo) + std::unordered_map& transferredInitializerMap, + const gsl::span fusedNodeInputArgOriginalNames, + const gsl::span fusedNodeOutputArgOriginalNames) : OpKernel(kernelInfo) { // Get the graph for the function which was created according to the computational // capacity returned by the execution provider's graph partitioner @@ -45,14 +47,20 @@ namespace Dml ORT_THROW_IF_FAILED(providerExecutionObject.As(&m_winmlProvider)); } - TranslateAndCompileGraph(kernelInfo, graph, node.InputDefs(), node.OutputDefs(), graphNodePropertyMap, transferredInitializerMap); + TranslateAndCompileGraph( + kernelInfo, + graph, + fusedNodeInputArgOriginalNames, + fusedNodeOutputArgOriginalNames, + graphNodePropertyMap, + transferredInitializerMap); } void TranslateAndCompileGraph( const onnxruntime::OpKernelInfo& kernelInfo, const onnxruntime::Graph& graph, - const onnxruntime::ConstPointerContainer>& fusedNodeInputDefs, - const onnxruntime::ConstPointerContainer>& fusedNodeOutputDefs, + const gsl::span fusedNodeInputArgOriginalNames, + const gsl::span fusedNodeOutputArgOriginalNames, const std::unordered_map& graphNodePropertyMap, std::unordered_map& transferredInitializerMap ) @@ -68,7 +76,7 @@ namespace Dml m_inputsConstant.resize(graphInputCount); for (uint32_t i = 0; i < graphInputCount; ++i) { - m_inputsConstant[i] = GraphKernelHelper::GetGraphInputConstness(i, kernelInfo, fusedNodeInputDefs, transferredInitializerMap); + m_inputsConstant[i] = GraphKernelHelper::GetGraphInputConstness(i, kernelInfo, fusedNodeInputArgOriginalNames, transferredInitializerMap); } GraphDescBuilder::GraphDesc graphDesc = GraphDescBuilder::BuildGraphDesc( @@ -77,8 +85,8 @@ namespace Dml m_inputsConstant.size(), transferredInitializerMap, graph, - fusedNodeInputDefs, - fusedNodeOutputDefs, + fusedNodeInputArgOriginalNames, + fusedNodeOutputArgOriginalNames, graphNodePropertyMap, device.Get(), m_executionHandle); @@ -95,7 +103,7 @@ namespace Dml m_inputsConstant, kernelInfo, graphDesc, - fusedNodeInputDefs, + fusedNodeInputArgOriginalNames, m_inputsUsed, initInputBindings, initInputResources, @@ -511,9 +519,16 @@ namespace Dml onnxruntime::OpKernel* CreateFusedGraphKernel( const onnxruntime::OpKernelInfo& info, const std::unordered_map &graphNodePropertyMap, - std::unordered_map& transferredInitializerMap + std::unordered_map& transferredInitializerMap, + const gsl::span fusedNodeInputArgOriginalNames, + const gsl::span fusedNodeOutputArgOriginalNames ) { - return new FusedGraphKernel(info, graphNodePropertyMap, transferredInitializerMap); + return new FusedGraphKernel( + info, + graphNodePropertyMap, + transferredInitializerMap, + fusedNodeInputArgOriginalNames, + fusedNodeOutputArgOriginalNames); } } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.h index d025dc0077..45a08550b9 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.h @@ -8,7 +8,9 @@ namespace Dml { onnxruntime::OpKernel* CreateFusedGraphKernel( const onnxruntime::OpKernelInfo& info, - const std::unordered_map &graphNodePropertyMap, - std::unordered_map& transferredInitializerMap + const std::unordered_map& graphNodePropertyMap, + std::unordered_map& transferredInitializerMap, + const gsl::span fusedNodeInputArgOriginalNames, + const gsl::span fusedNodeOutputArgOriginalNames ); } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp index ef0ade7a47..9db9b1f934 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp @@ -38,8 +38,8 @@ namespace Dml::GraphDescBuilder const size_t isConstGpuGraphInputCount, std::unordered_map& transferredInitializerMap, const onnxruntime::Graph& graph, - const onnxruntime::ConstPointerContainer>& fusedNodeInputDefs, - const onnxruntime::ConstPointerContainer>& fusedNodeOutputDefs, + const gsl::span fusedNodeInputArgOriginalNames, + const gsl::span fusedNodeOutputArgOriginalNames, const std::unordered_map& graphNodePropertyMap, IDMLDevice* device, const void* executionHandle) @@ -56,10 +56,9 @@ namespace Dml::GraphDescBuilder // Map from Lotus node argument names to input indices of the fused kernel node. std::unordered_map nameToFusedNodeInputIndex; - for (size_t inputIndex = 0; inputIndex < fusedNodeInputDefs.size(); ++inputIndex) + for (size_t inputIndex = 0; inputIndex < fusedNodeInputArgOriginalNames.size(); ++inputIndex) { - const onnxruntime::NodeArg* graphInput = graph.GetNodeArg( - GraphKernelHelper::GetFusedNodeArgNameMatchingGraph(fusedNodeInputDefs[inputIndex]->Name())); + const onnxruntime::NodeArg* graphInput = graph.GetNodeArg(fusedNodeInputArgOriginalNames[inputIndex]); if (!graphInput) { @@ -168,13 +167,6 @@ namespace Dml::GraphDescBuilder { auto iter = nameToFusedNodeInputIndex.find(arg->Name()); - // The graph input could be missing the suffix, so try to match without it. - // This is part of a temporary workaround; see comments in GraphKernelHelper::GetFusedNodeArgNameMatchingGraph. - if (iter == nameToFusedNodeInputIndex.end()) - { - iter = nameToFusedNodeInputIndex.find(GraphKernelHelper::GetFusedNodeArgNameMatchingGraph(arg->Name())); - } - if (iter != nameToFusedNodeInputIndex.end()) { // This is a graph input @@ -240,10 +232,9 @@ namespace Dml::GraphDescBuilder assert(graphNodes.size() == orderedNodeIndices.size()); // Add graph output nodes, which might be in a different order from the encapsulating node - for (size_t outputIndex = 0; outputIndex < fusedNodeOutputDefs.size(); ++outputIndex) + for (size_t outputIndex = 0; outputIndex < fusedNodeOutputArgOriginalNames.size(); ++outputIndex) { - const onnxruntime::NodeArg* graphOutput = graph.GetNodeArg( - GraphKernelHelper::GetFusedNodeArgNameMatchingGraph(fusedNodeOutputDefs[outputIndex]->Name())); + const onnxruntime::NodeArg* graphOutput = graph.GetNodeArg(fusedNodeOutputArgOriginalNames[outputIndex]); ORT_THROW_HR_IF_NULL_MSG(E_POINTER, graphOutput, "FusedNode's nodeArgList does not contain one of the nodeArg"); const auto& outputNodeAndIndex = nameToNodeAndIndexMap.at(graphOutput->Name()); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h index 0d013e8fe4..e787332884 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h @@ -45,8 +45,8 @@ namespace Dml const size_t isConstGpuGraphInputCount, std::unordered_map& transferredInitializerMap, const onnxruntime::Graph& graph, - const onnxruntime::ConstPointerContainer>& fusedNodeInputDefs, - const onnxruntime::ConstPointerContainer>& fusedNodeOutputDefs, + const gsl::span fusedNodeInputArgOriginalNames, + const gsl::span fusedNodeOutputArgOriginalNames, const std::unordered_map& graphNodePropertyMap, IDMLDevice* device, const void* executionHandle); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphKernelHelper.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphKernelHelper.cpp index 1ad3080505..a55a55c952 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphKernelHelper.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphKernelHelper.cpp @@ -105,11 +105,11 @@ namespace GraphKernelHelper bool GetGraphInputConstness( uint32_t index, const onnxruntime::OpKernelInfo& kernelInfo, - const onnxruntime::ConstPointerContainer>& fusedNodeInputDefs, + const gsl::span fusedNodeInputArgOriginalNames, const std::unordered_map& transferredInitializerMap) { // Transferred initializers are uploaded to GPU memory - auto iter = transferredInitializerMap.find(GetFusedNodeArgNameMatchingGraph(fusedNodeInputDefs[index]->Name())); + auto iter = transferredInitializerMap.find(fusedNodeInputArgOriginalNames[index]); if (iter != transferredInitializerMap.end()) { return true; @@ -139,7 +139,7 @@ namespace GraphKernelHelper const std::vector& inputsConstant, const onnxruntime::OpKernelInfo& kernelInfo, const Dml::GraphDescBuilder::GraphDesc& graphDesc, - const onnxruntime::ConstPointerContainer>& fusedNodeInputDefs, + const gsl::span fusedNodeInputArgOriginalNames, _Out_ std::vector& inputsUsed, _Inout_ std::vector& initInputBindings, _Inout_ std::vector>& initInputResources, @@ -154,7 +154,7 @@ namespace GraphKernelHelper std::map initializerToLastInputIndexMap; for (uint32_t i = 0; i < graphInputCount; i++) { - auto iter = transferredInitializerMap.find(GetFusedNodeArgNameMatchingGraph(fusedNodeInputDefs[i]->Name())); + auto iter = transferredInitializerMap.find(fusedNodeInputArgOriginalNames[i]); if (iter != transferredInitializerMap.end()) { initializerToLastInputIndexMap[&iter->second] = i; } @@ -172,7 +172,7 @@ namespace GraphKernelHelper // initialization or execution). So just throw away the transferred initializer and skip this input. if (!inputsUsed[i]) { - transferredInitializerMap.erase(GetFusedNodeArgNameMatchingGraph(fusedNodeInputDefs[i]->Name())); + transferredInitializerMap.erase(fusedNodeInputArgOriginalNames[i]); if (inputRawData) { @@ -183,7 +183,7 @@ namespace GraphKernelHelper } // Look for the initializer among those transferred from the graph during partitioning - auto iter = transferredInitializerMap.find(GetFusedNodeArgNameMatchingGraph(fusedNodeInputDefs[i]->Name())); + auto iter = transferredInitializerMap.find(fusedNodeInputArgOriginalNames[i]); if (iter != transferredInitializerMap.end()) { std::byte* tensorPtr = nullptr; @@ -320,43 +320,5 @@ namespace GraphKernelHelper dmlGraphDesc.IntermediateEdgeCount = gsl::narrow_cast(dmlIntermediateEdges.size()); dmlGraphDesc.IntermediateEdges = dmlIntermediateEdges.data(); } - - // TODO: This is a hack which strips the suffix added within Lotus transforms that insert mem copies. - // This shouldn't be necessary if Lotus exposes the inputs/ouputs in the same order between the kernel - // for a function, and the graph for that function exposed as a kernel property. When the ordering - // mismatch is fixed (WindowsAI: 21114358, Lotus: 1953), this workaround should be removed. - std::string GetFusedNodeArgNameMatchingGraph(const std::string& fusedNodeArgeName) - { - const char* suffix = nullptr; - - // The suffix used when inserting mem copies is equal to the below, probably followed by an incrementing number. - if (!suffix) - { - suffix = strstr(fusedNodeArgeName.c_str(), "_DmlExecutionProvider_"); - } - - // The suffix used when inserting mem copies is equal to the below, not followed by an incrementing number. - if (!suffix) - { - suffix = strstr(fusedNodeArgeName.c_str(), "_DmlExecutionProvider"); - } - - if (!suffix) - { - suffix = strstr(fusedNodeArgeName.c_str(), "_token_"); - } - - if (suffix) - { - return std::string( - fusedNodeArgeName.begin(), - fusedNodeArgeName.begin() + (suffix - fusedNodeArgeName.c_str()) - ); - } - else - { - return fusedNodeArgeName; - } - } } // namespace GraphKernelHelper } // namespace Dml \ No newline at end of file diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphKernelHelper.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphKernelHelper.h index 0f11f03217..d6b6db8e87 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphKernelHelper.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphKernelHelper.h @@ -41,7 +41,7 @@ namespace GraphKernelHelper bool GetGraphInputConstness( uint32_t index, const onnxruntime::OpKernelInfo& kernelInfo, - const onnxruntime::ConstPointerContainer>& fusedNodeInputDefs, + const gsl::span fusedNodeInputArgOriginalNames, const std::unordered_map& transferredInitializerMap); void ProcessInputData( @@ -50,7 +50,7 @@ namespace GraphKernelHelper const std::vector& inputsConstant, const onnxruntime::OpKernelInfo& kernelInfo, const Dml::GraphDescBuilder::GraphDesc& graphDesc, - const onnxruntime::ConstPointerContainer>& fusedNodeInputDefs, + const gsl::span fusedNodeInputArgOriginalNames, _Out_ std::vector& inputsUsed, _Inout_ std::vector& initInputBindings, _Inout_ std::vector>& initInputResources, @@ -68,8 +68,6 @@ namespace GraphKernelHelper _Inout_ std::vector& dmlInputEdges, _Inout_ std::vector& dmlOutputEdges, _Inout_ std::vector& dmlIntermediateEdges); - - std::string GetFusedNodeArgNameMatchingGraph(const std::string& fusedNodeArgeName); - + } // namespace GraphKernelHelper } // namespace Dml \ No newline at end of file diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp index ed30d9540e..1698733bf9 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp @@ -537,10 +537,18 @@ namespace Dml printf("\n"); #endif - auto fused_kernel_func = [partitionNodePropsMap, transferredInitializerMap](onnxruntime::FuncManager& func_mgr, const onnxruntime::OpKernelInfo& info, std::unique_ptr& out) mutable ->onnxruntime::Status + // These nodeArgNames will be used while creating DML Graph inside FusedGraphKernel.cpp + // Ordering of input/output nodeArgs in below vector will be same as Node::Definitions::input_defs because + // ORT is populating these args as it is while creating the FusedNode at Graph::CreateFusedSubGraphNode() + // Why we need these names? + // After Partitioning and before reaching to FusedGraphKernel, ORT may modify the input/output nodeArg names + // present in FusedNode (Node::Definitions::input_defs) as part of some transformers like memcopy, or L1/L2/L3 transformers. + std::vector fusedNodeInputArgOriginalNames = def->inputs; + std::vector fusedNodeOutputArgOriginalNames = def->outputs; + auto fused_kernel_func = [partitionNodePropsMap, transferredInitializerMap, fusedNodeInputArgOriginalNames, fusedNodeOutputArgOriginalNames](onnxruntime::FuncManager& func_mgr, const onnxruntime::OpKernelInfo& info, std::unique_ptr& out) mutable ->onnxruntime::Status { - out.reset(CreateFusedGraphKernel(info, partitionNodePropsMap, *transferredInitializerMap)); - return Status::OK(); + out.reset(CreateFusedGraphKernel(info, partitionNodePropsMap, *transferredInitializerMap, fusedNodeInputArgOriginalNames, fusedNodeOutputArgOriginalNames)); + return Status::OK(); }; // build the kernel definition on the fly, and register it to the fused_kernel_regisitry.