From 732ffd12d2103c7f9bbd8dc77b9a3e0dadafd7ad Mon Sep 17 00:00:00 2001 From: Dwayne Robinson Date: Mon, 16 Nov 2020 15:29:08 -0800 Subject: [PATCH] DirectML Execution Provider integration 2020-11-13 (#5809) * Merged PR 5253310: Fix 0-sized dimension broadcasting Tensors that contain 0-sized dimensions were being broadcast to higher dimensions, which made it impossible to remove them from the graph. 0-sized dimensions represent empty tensors, so whatever operator needs to broadcast it shouldn't try to call into DML. * Merged PR 5334334: Fix asserts and failure in GraphKernelHelper.cpp This extends a workaround needed to match node inputs with Tensors to the EP code handling constant input upload. This was causing issues in a couple of models, including EfficientDet, although that model still fails due to this bug: https://microsoft.visualstudio.com/OS/_workitems/edit/29970551 Related work items: #29706035 * Merged PR 5344477: Disable GPU timeouts in DML EP command queue creation GPU timeouts have already been disabled in command queues created by Winml, but not the ones created by the DML EP within the ORT API. * Merged PR 5380534: BatchNormalization failure in autopilot - fix output size New validation [here](https://microsoft.visualstudio.com/DefaultCollection/WindowsAI/_git/WindowsAI/pullrequest/5354070?_a=files&path=%2Fdml%2FSharedValidation%2FDmlBatchNormalizationOperatorValidator.h) causes some BatchNorm cases to fail (e.g. OnnxConformanceTestsTaef::BatchNormalization (BatchNormalization_2x2x2)). I'm unsure how long this bug existed, but based on Nick's investigation, it apparently still worked anyway. Related work items: #27678610 * Merged PR 5386132: Update 8D BatchNorm Update 8D BatchNorm Related work items: #27678610 * Merged PR 5390213: Tile allow 0 in repeats 0 is a valid value in Tile's "repeats" parameter. The CPU kernel handles it fine. So should the DML EP. 
Related work items: #29970551 Co-authored-by: Justin Stoecker Co-authored-by: Jeff Bloomfield Co-authored-by: Patrice Vignola Co-authored-by: Nick Feeney --- .../src/GraphDescBuilder.cpp | 41 ++--------------- .../src/GraphKernelHelper.cpp | 46 +++++++++++++++++-- .../src/GraphKernelHelper.h | 2 + .../DmlOperatorBatchNormalization.cpp | 6 ++- .../OperatorAuthorHelper/OperatorHelper.cpp | 7 +++ .../dml/OperatorAuthorHelper/OperatorHelper.h | 2 +- .../providers/dml/dml_provider_factory.cc | 2 +- 7 files changed, 62 insertions(+), 44 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp index 9cd8f92ae2..2c0fb0f374 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp @@ -3,43 +3,12 @@ #include "precomp.h" #include "GraphDescBuilder.h" +#include "GraphKernelHelper.h" using namespace Windows::AI::MachineLearning::Adapter; namespace Dml::GraphDescBuilder { - // TODO: This is a hack which strips the suffix added within Lotus transforms that insert mem copies. - // This shouldn't be necessary if Lotus exposes the inputs/ouputs in the same order between the kernel - // for a function, and the graph for that function exposed as a kernel property. When the ordering - // mismatch is fixed (WindowsAI: 21114358, Lotus: 1953), this workaround should be removed. - static std::string GetFusedNodeArgNameMatchingGraph(const std::string& fusedNodeArgeName) - { - const char* suffix = nullptr; - - // The suffix used when inserting mem copies is equal to the below, probably followed by an incrementing number. - if (!suffix) { - suffix = strstr(fusedNodeArgeName.c_str(), "_DmlExecutionProvider_"); - } - - // The suffix used when inserting mem copies is equal to the below, not followed by an incrementing number. 
- if (!suffix) { - suffix = strstr(fusedNodeArgeName.c_str(), "_DmlExecutionProvider"); - } - - if (!suffix) { - suffix = strstr(fusedNodeArgeName.c_str(), "_token_"); - } - - if (suffix) - { - return std::string( - fusedNodeArgeName.begin(), - fusedNodeArgeName.begin() + (suffix - fusedNodeArgeName.c_str()) - ); - } else { - return fusedNodeArgeName; - } - } const std::string& GetUniqueNodeName(const onnxruntime::Node& node) { @@ -85,7 +54,7 @@ namespace Dml::GraphDescBuilder for (size_t inputIndex = 0; inputIndex < fusedNodeInputDefs.size(); ++inputIndex) { const onnxruntime::NodeArg* graphInput = graph.GetNodeArg( - GetFusedNodeArgNameMatchingGraph(fusedNodeInputDefs[inputIndex]->Name())); + GraphKernelHelper::GetFusedNodeArgNameMatchingGraph(fusedNodeInputDefs[inputIndex]->Name())); if (!graphInput) { @@ -208,10 +177,10 @@ namespace Dml::GraphDescBuilder auto iter = nameToFusedNodeInputIndex.find(arg->Name()); // The graph input could be missing the suffix, so try to match without it. - // This is part of a temporary workaround; see comments in GetFusedNodeArgNameMatchingGraph. + // This is part of a temporary workaround; see comments in GraphKernelHelper::GetFusedNodeArgNameMatchingGraph. 
if (iter == nameToFusedNodeInputIndex.end()) { - iter = nameToFusedNodeInputIndex.find(GetFusedNodeArgNameMatchingGraph(arg->Name())); + iter = nameToFusedNodeInputIndex.find(GraphKernelHelper::GetFusedNodeArgNameMatchingGraph(arg->Name())); } if (iter != nameToFusedNodeInputIndex.end()) @@ -277,7 +246,7 @@ namespace Dml::GraphDescBuilder for (size_t outputIndex = 0; outputIndex < fusedNodeOutputDefs.size(); ++outputIndex) { const onnxruntime::NodeArg* graphOutput = graph.GetNodeArg( - GetFusedNodeArgNameMatchingGraph(fusedNodeOutputDefs[outputIndex]->Name())); + GraphKernelHelper::GetFusedNodeArgNameMatchingGraph(fusedNodeOutputDefs[outputIndex]->Name())); const auto& outputNodeAndIndex = nameToNodeAndIndexMap.at(graphOutput->Name()); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphKernelHelper.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphKernelHelper.cpp index f6aa742ce0..92895d0814 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphKernelHelper.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphKernelHelper.cpp @@ -109,7 +109,7 @@ namespace GraphKernelHelper const std::unordered_map& transferredInitializerMap) { // Transferred initializers are uploaded to GPU memory - auto iter = transferredInitializerMap.find(fusedNodeInputDefs[index]->Name()); + auto iter = transferredInitializerMap.find(GetFusedNodeArgNameMatchingGraph(fusedNodeInputDefs[index]->Name())); if (iter != transferredInitializerMap.end()) { return true; @@ -154,7 +154,7 @@ namespace GraphKernelHelper std::map initializerToLastInputIndexMap; for (uint32_t i = 0; i < graphInputCount; i++) { - auto iter = transferredInitializerMap.find(fusedNodeInputDefs[i]->Name()); + auto iter = transferredInitializerMap.find(GetFusedNodeArgNameMatchingGraph(fusedNodeInputDefs[i]->Name())); if (iter != transferredInitializerMap.end()) { initializerToLastInputIndexMap[&iter->second] = i; } @@ -172,7 +172,7 @@ namespace 
GraphKernelHelper // initialization or execution). So just throw away the transferred initializer and skip this input. if (!inputsUsed[i]) { - transferredInitializerMap.erase(fusedNodeInputDefs[i]->Name()); + transferredInitializerMap.erase(GetFusedNodeArgNameMatchingGraph(fusedNodeInputDefs[i]->Name())); if (inputRawData) { @@ -183,7 +183,7 @@ namespace GraphKernelHelper } // Look for the initializer among those transferred from the graph during partitioning - auto iter = transferredInitializerMap.find(fusedNodeInputDefs[i]->Name()); + auto iter = transferredInitializerMap.find(GetFusedNodeArgNameMatchingGraph(fusedNodeInputDefs[i]->Name())); if (iter != transferredInitializerMap.end()) { std::byte* tensorPtr = nullptr; @@ -320,5 +320,43 @@ namespace GraphKernelHelper dmlGraphDesc.IntermediateEdgeCount = gsl::narrow_cast(dmlIntermediateEdges.size()); dmlGraphDesc.IntermediateEdges = dmlIntermediateEdges.data(); } + + // TODO: This is a hack which strips the suffix added within Lotus transforms that insert mem copies. + // This shouldn't be necessary if Lotus exposes the inputs/ouputs in the same order between the kernel + // for a function, and the graph for that function exposed as a kernel property. When the ordering + // mismatch is fixed (WindowsAI: 21114358, Lotus: 1953), this workaround should be removed. + std::string GetFusedNodeArgNameMatchingGraph(const std::string& fusedNodeArgeName) + { + const char* suffix = nullptr; + + // The suffix used when inserting mem copies is equal to the below, probably followed by an incrementing number. + if (!suffix) + { + suffix = strstr(fusedNodeArgeName.c_str(), "_DmlExecutionProvider_"); + } + + // The suffix used when inserting mem copies is equal to the below, not followed by an incrementing number. 
+ if (!suffix) + { + suffix = strstr(fusedNodeArgeName.c_str(), "_DmlExecutionProvider"); + } + + if (!suffix) + { + suffix = strstr(fusedNodeArgeName.c_str(), "_token_"); + } + + if (suffix) + { + return std::string( + fusedNodeArgeName.begin(), + fusedNodeArgeName.begin() + (suffix - fusedNodeArgeName.c_str()) + ); + } + else + { + return fusedNodeArgeName; + } + } } // namespace GraphKernelHelper } // namespace Dml \ No newline at end of file diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphKernelHelper.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphKernelHelper.h index 6f271612c6..63efb58300 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphKernelHelper.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphKernelHelper.h @@ -63,6 +63,8 @@ namespace GraphKernelHelper _Out_ std::vector& dmlInputEdges, _Out_ std::vector& dmlOutputEdges, _Out_ std::vector& dmlIntermediateEdges); + + std::string GetFusedNodeArgNameMatchingGraph(const std::string& fusedNodeArgeName); } // namespace GraphKernelHelper } // namespace Dml \ No newline at end of file diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBatchNormalization.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBatchNormalization.cpp index 1d7f3d08ea..5f1a116802 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBatchNormalization.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBatchNormalization.cpp @@ -36,12 +36,14 @@ public: m_inputTensorDescs[0] = CreateTensorDescFromInput(kernelCreationContext, 0, TensorAxis::DoNotCoerce, TensorAxis::N, TensorAxis::LeftAligned); - // Massage each of these 1D tensors (of length C) into 4D tensors of the form [1,C,1,1]. + // Massage each of these 1D tensors (of length C) into ND tensors of the form [1,C,1,1,...]. 
for (uint32_t i = Scale; i < OnnxInputIndex::Count; ++i) { - m_inputTensorDescs[i] = CreateTensorDescFromInput(kernelCreationContext, i, TensorAxis::DoNotCoerce, TensorAxis::C, TensorAxis::LeftAligned); + m_inputTensorDescs[i] = CreateTensorDescFromInput(kernelCreationContext, i, TensorAxis::DoNotCoerce, TensorAxis::C, TensorAxis::LeftAligned, std::nullopt, m_inputTensorDescs[0].GetDimensionCount()); } + m_outputTensorDescs[0] = CreateTensorDescFromOutput(kernelCreationContext, 0, TensorAxis::DoNotCoerce, TensorAxis::N, TensorAxis::LeftAligned, std::nullopt, m_inputTensorDescs[0].GetDimensionCount()); + std::vector inputDescs = GetDmlInputDescs(); std::vector outputDescs = GetDmlOutputDescs(); diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp index d21daa36c3..c79a984b2b 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp @@ -550,6 +550,13 @@ namespace OperatorHelper ++inDim1Iter; } + // 0-sized dimensions indicate an empty tensor and shouldn't be broadcasted to higher dimensions + if (inDimension0 == 0 || inDimension1 == 0) + { + inDimension0 = 0; + inDimension1 = 0; + } + ML_CHECK_VALID_ARGUMENT((inDimension0 == inDimension1) || (inDimension0 == 1) || (inDimension1 == 1)); *outDimIter = std::max(inDimension0, inDimension1); } diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index 3bd8064dc6..06250fddce 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -1273,7 +1273,7 @@ class TileHelper { ML_CHECK_VALID_ARGUMENT(repeatsTensor.IsCpuData(), "Tile's repeats tensor must be CPU Tensor."); for (size_t i = 0; i < dimCount; ++i) { - 
ML_CHECK_VALID_ARGUMENT(repeatsData[i] > 0, "Repeat values should be > 0."); + ML_CHECK_VALID_ARGUMENT(repeatsData[i] >= 0, "Repeat values should be >= 0."); m_repeatsData.push_back(gsl::narrow_cast(repeatsData[i])); } diff --git a/onnxruntime/core/providers/dml/dml_provider_factory.cc b/onnxruntime/core/providers/dml/dml_provider_factory.cc index 64d90ed0cc..ca85bc3383 100644 --- a/onnxruntime/core/providers/dml/dml_provider_factory.cc +++ b/onnxruntime/core/providers/dml/dml_provider_factory.cc @@ -111,7 +111,7 @@ std::shared_ptr CreateExecutionProviderFactory_DML(in D3D12_COMMAND_QUEUE_DESC cmd_queue_desc = {}; cmd_queue_desc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT; - cmd_queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_NONE; + cmd_queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_DISABLE_GPU_TIMEOUT; ComPtr cmd_queue; THROW_IF_FAILED(d3d12_device->CreateCommandQueue(&cmd_queue_desc, IID_PPV_ARGS(&cmd_queue)));