diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp index 9cd8f92ae2..2c0fb0f374 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp @@ -3,43 +3,12 @@ #include "precomp.h" #include "GraphDescBuilder.h" +#include "GraphKernelHelper.h" using namespace Windows::AI::MachineLearning::Adapter; namespace Dml::GraphDescBuilder { - // TODO: This is a hack which strips the suffix added within Lotus transforms that insert mem copies. - // This shouldn't be necessary if Lotus exposes the inputs/ouputs in the same order between the kernel - // for a function, and the graph for that function exposed as a kernel property. When the ordering - // mismatch is fixed (WindowsAI: 21114358, Lotus: 1953), this workaround should be removed. - static std::string GetFusedNodeArgNameMatchingGraph(const std::string& fusedNodeArgeName) - { - const char* suffix = nullptr; - - // The suffix used when inserting mem copies is equal to the below, probably followed by an incrementing number. - if (!suffix) { - suffix = strstr(fusedNodeArgeName.c_str(), "_DmlExecutionProvider_"); - } - - // The suffix used when inserting mem copies is equal to the below, not followed by an incrementing number. - if (!suffix) { - suffix = strstr(fusedNodeArgeName.c_str(), "_DmlExecutionProvider"); - } - - if (!suffix) { - suffix = strstr(fusedNodeArgeName.c_str(), "_token_"); - } - - if (suffix) - { - return std::string( - fusedNodeArgeName.begin(), - fusedNodeArgeName.begin() + (suffix - fusedNodeArgeName.c_str()) - ); - } else { - return fusedNodeArgeName; - } - } const std::string& GetUniqueNodeName(const onnxruntime::Node& node) { @@ -85,7 +54,7 @@ namespace Dml::GraphDescBuilder for (size_t inputIndex = 0; inputIndex < fusedNodeInputDefs.size(); ++inputIndex) { const onnxruntime::NodeArg* graphInput = graph.GetNodeArg( - GetFusedNodeArgNameMatchingGraph(fusedNodeInputDefs[inputIndex]->Name())); + GraphKernelHelper::GetFusedNodeArgNameMatchingGraph(fusedNodeInputDefs[inputIndex]->Name())); if (!graphInput) { @@ -208,10 +177,10 @@ namespace Dml::GraphDescBuilder auto iter = nameToFusedNodeInputIndex.find(arg->Name()); // The graph input could be missing the suffix, so try to match without it. - // This is part of a temporary workaround; see comments in GetFusedNodeArgNameMatchingGraph. + // This is part of a temporary workaround; see comments in GraphKernelHelper::GetFusedNodeArgNameMatchingGraph. if (iter == nameToFusedNodeInputIndex.end()) { - iter = nameToFusedNodeInputIndex.find(GetFusedNodeArgNameMatchingGraph(arg->Name())); + iter = nameToFusedNodeInputIndex.find(GraphKernelHelper::GetFusedNodeArgNameMatchingGraph(arg->Name())); } if (iter != nameToFusedNodeInputIndex.end()) @@ -277,7 +246,7 @@ namespace Dml::GraphDescBuilder for (size_t outputIndex = 0; outputIndex < fusedNodeOutputDefs.size(); ++outputIndex) { const onnxruntime::NodeArg* graphOutput = graph.GetNodeArg( - GetFusedNodeArgNameMatchingGraph(fusedNodeOutputDefs[outputIndex]->Name())); + GraphKernelHelper::GetFusedNodeArgNameMatchingGraph(fusedNodeOutputDefs[outputIndex]->Name())); const auto& outputNodeAndIndex = nameToNodeAndIndexMap.at(graphOutput->Name()); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphKernelHelper.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphKernelHelper.cpp index f6aa742ce0..92895d0814 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphKernelHelper.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphKernelHelper.cpp @@ -109,7 +109,7 @@ namespace GraphKernelHelper const std::unordered_map& transferredInitializerMap) { // Transferred initializers are uploaded to GPU memory - auto iter = transferredInitializerMap.find(fusedNodeInputDefs[index]->Name()); + auto iter = transferredInitializerMap.find(GetFusedNodeArgNameMatchingGraph(fusedNodeInputDefs[index]->Name())); if (iter != transferredInitializerMap.end()) { return true; @@ -154,7 +154,7 @@ namespace GraphKernelHelper std::map initializerToLastInputIndexMap; for (uint32_t i = 0; i < graphInputCount; i++) { - auto iter = transferredInitializerMap.find(fusedNodeInputDefs[i]->Name()); + auto iter = transferredInitializerMap.find(GetFusedNodeArgNameMatchingGraph(fusedNodeInputDefs[i]->Name())); if (iter != transferredInitializerMap.end()) { initializerToLastInputIndexMap[&iter->second] = i; } @@ -172,7 +172,7 @@ namespace GraphKernelHelper // initialization or execution). So just throw away the transferred initializer and skip this input. if (!inputsUsed[i]) { - transferredInitializerMap.erase(fusedNodeInputDefs[i]->Name()); + transferredInitializerMap.erase(GetFusedNodeArgNameMatchingGraph(fusedNodeInputDefs[i]->Name())); if (inputRawData) { @@ -183,7 +183,7 @@ namespace GraphKernelHelper } // Look for the initializer among those transferred from the graph during partitioning - auto iter = transferredInitializerMap.find(fusedNodeInputDefs[i]->Name()); + auto iter = transferredInitializerMap.find(GetFusedNodeArgNameMatchingGraph(fusedNodeInputDefs[i]->Name())); if (iter != transferredInitializerMap.end()) { std::byte* tensorPtr = nullptr; @@ -320,5 +320,43 @@ namespace GraphKernelHelper dmlGraphDesc.IntermediateEdgeCount = gsl::narrow_cast(dmlIntermediateEdges.size()); dmlGraphDesc.IntermediateEdges = dmlIntermediateEdges.data(); } + + // TODO: This is a hack which strips the suffix added within Lotus transforms that insert mem copies. + // This shouldn't be necessary if Lotus exposes the inputs/ouputs in the same order between the kernel + // for a function, and the graph for that function exposed as a kernel property. When the ordering + // mismatch is fixed (WindowsAI: 21114358, Lotus: 1953), this workaround should be removed. + std::string GetFusedNodeArgNameMatchingGraph(const std::string& fusedNodeArgeName) + { + const char* suffix = nullptr; + + // The suffix used when inserting mem copies is equal to the below, probably followed by an incrementing number. + if (!suffix) + { + suffix = strstr(fusedNodeArgeName.c_str(), "_DmlExecutionProvider_"); + } + + // The suffix used when inserting mem copies is equal to the below, not followed by an incrementing number. + if (!suffix) + { + suffix = strstr(fusedNodeArgeName.c_str(), "_DmlExecutionProvider"); + } + + if (!suffix) + { + suffix = strstr(fusedNodeArgeName.c_str(), "_token_"); + } + + if (suffix) + { + return std::string( + fusedNodeArgeName.begin(), + fusedNodeArgeName.begin() + (suffix - fusedNodeArgeName.c_str()) + ); + } + else + { + return fusedNodeArgeName; + } + } } // namespace GraphKernelHelper } // namespace Dml \ No newline at end of file diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphKernelHelper.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphKernelHelper.h index 6f271612c6..63efb58300 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphKernelHelper.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphKernelHelper.h @@ -63,6 +63,8 @@ namespace GraphKernelHelper _Out_ std::vector& dmlInputEdges, _Out_ std::vector& dmlOutputEdges, _Out_ std::vector& dmlIntermediateEdges); + + std::string GetFusedNodeArgNameMatchingGraph(const std::string& fusedNodeArgeName); } // namespace GraphKernelHelper } // namespace Dml \ No newline at end of file diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBatchNormalization.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBatchNormalization.cpp index 1d7f3d08ea..5f1a116802 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBatchNormalization.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBatchNormalization.cpp @@ -36,12 +36,14 @@ public: m_inputTensorDescs[0] = CreateTensorDescFromInput(kernelCreationContext, 0, TensorAxis::DoNotCoerce, TensorAxis::N, TensorAxis::LeftAligned); - // Massage each of these 1D tensors (of length C) into 4D tensors of the form [1,C,1,1]. + // Massage each of these 1D tensors (of length C) into ND tensors of the form [1,C,1,1,...]. for (uint32_t i = Scale; i < OnnxInputIndex::Count; ++i) { - m_inputTensorDescs[i] = CreateTensorDescFromInput(kernelCreationContext, i, TensorAxis::DoNotCoerce, TensorAxis::C, TensorAxis::LeftAligned); + m_inputTensorDescs[i] = CreateTensorDescFromInput(kernelCreationContext, i, TensorAxis::DoNotCoerce, TensorAxis::C, TensorAxis::LeftAligned, std::nullopt, m_inputTensorDescs[0].GetDimensionCount()); } + m_outputTensorDescs[0] = CreateTensorDescFromOutput(kernelCreationContext, 0, TensorAxis::DoNotCoerce, TensorAxis::N, TensorAxis::LeftAligned, std::nullopt, m_inputTensorDescs[0].GetDimensionCount()); + std::vector inputDescs = GetDmlInputDescs(); std::vector outputDescs = GetDmlOutputDescs(); diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp index d21daa36c3..c79a984b2b 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp @@ -550,6 +550,13 @@ namespace OperatorHelper ++inDim1Iter; } + // 0-sized dimensions indicate an empty tensor and shouldn't be broadcasted to higher dimensions + if (inDimension0 == 0 || inDimension1 == 0) + { + inDimension0 = 0; + inDimension1 = 0; + } + ML_CHECK_VALID_ARGUMENT((inDimension0 == inDimension1) || (inDimension0 == 1) || (inDimension1 == 1)); *outDimIter = std::max(inDimension0, inDimension1); } diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index 3bd8064dc6..06250fddce 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -1273,7 +1273,7 @@ class TileHelper { ML_CHECK_VALID_ARGUMENT(repeatsTensor.IsCpuData(), "Tile's repeats tensor must be CPU Tensor."); for (size_t i = 0; i < dimCount; ++i) { - ML_CHECK_VALID_ARGUMENT(repeatsData[i] > 0, "Repeat values should be > 0."); + ML_CHECK_VALID_ARGUMENT(repeatsData[i] >= 0, "Repeat values should be >= 0."); m_repeatsData.push_back(gsl::narrow_cast(repeatsData[i])); } diff --git a/onnxruntime/core/providers/dml/dml_provider_factory.cc b/onnxruntime/core/providers/dml/dml_provider_factory.cc index 64d90ed0cc..ca85bc3383 100644 --- a/onnxruntime/core/providers/dml/dml_provider_factory.cc +++ b/onnxruntime/core/providers/dml/dml_provider_factory.cc @@ -111,7 +111,7 @@ std::shared_ptr CreateExecutionProviderFactory_DML(in D3D12_COMMAND_QUEUE_DESC cmd_queue_desc = {}; cmd_queue_desc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT; - cmd_queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_NONE; + cmd_queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_DISABLE_GPU_TIMEOUT; ComPtr cmd_queue; THROW_IF_FAILED(d3d12_device->CreateCommandQueue(&cmd_queue_desc, IID_PPV_ARGS(&cmd_queue)));