DirectML Execution Provider integration 2020-11-13 (#5809)

* Merged PR 5253310: Fix 0-sized dimension broadcasting

Tensors that contain 0-sized dimensions were being broadcast to higher dimensions, which made it impossible to remove them from the graph. 0-sized dimensions represent empty tensors, so any operator that needs to broadcast such a tensor shouldn't try to call into DML.

* Merged PR 5334334: Fix asserts and failure in GraphKernelHelper.cpp

This extends a workaround needed to match node inputs with Tensors to the EP code handling constant input upload.

This was causing issues in a couple of models, including EfficientDet, although that model still fails due to this bug:
https://microsoft.visualstudio.com/OS/_workitems/edit/29970551

Related work items: #29706035

* Merged PR 5344477: Disable GPU timeouts in DML EP command queue creation

GPU timeouts have already been disabled in command queues created by WinML, but not in the ones created by the DML EP within the ORT API.

* Merged PR 5380534: BatchNormalization failure in autopilot - fix output size

New validation [here](https://microsoft.visualstudio.com/DefaultCollection/WindowsAI/_git/WindowsAI/pullrequest/5354070?_a=files&path=%2Fdml%2FSharedValidation%2FDmlBatchNormalizationOperatorValidator.h) causes some BatchNorm cases to fail (e.g. OnnxConformanceTestsTaef::BatchNormalization (BatchNormalization_2x2x2)). I'm unsure how long this bug existed, but based on Nick's investigation, it apparently still worked anyway.

Related work items: #27678610

* Merged PR 5386132: Update 8D BatchNorm

Update 8D BatchNorm

Related work items: #27678610

* Merged PR 5390213: Tile allow 0 in repeats

0 is a valid value in Tile's "repeats" parameter. The CPU kernel handles it fine, and so should the DML EP.

Related work items: #29970551

Co-authored-by: Justin Stoecker
Co-authored-by: Jeff Bloomfield
Co-authored-by: Patrice Vignola
Co-authored-by: Nick Feeney
This commit is contained in:
Dwayne Robinson 2020-11-16 15:29:08 -08:00 committed by GitHub
parent 339348bc46
commit 732ffd12d2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 62 additions and 44 deletions

View file

@ -3,43 +3,12 @@
#include "precomp.h"
#include "GraphDescBuilder.h"
#include "GraphKernelHelper.h"
using namespace Windows::AI::MachineLearning::Adapter;
namespace Dml::GraphDescBuilder
{
// TODO: Workaround that strips the suffix appended by the Lotus transforms which insert mem copies.
// It would be unnecessary if Lotus exposed the inputs/outputs in the same order between the kernel
// for a function and the graph for that function exposed as a kernel property. Remove this once the
// ordering mismatch is fixed (WindowsAI: 21114358, Lotus: 1953).
//
// Returns the portion of fusedNodeArgeName that precedes the first occurrence of the
// highest-priority marker present in it, or the name unchanged when no marker is found.
static std::string GetFusedNodeArgNameMatchingGraph(const std::string& fusedNodeArgeName)
{
    // Tried in order; the first marker that occurs anywhere in the name decides the cut point.
    static constexpr const char* markers[] =
    {
        "_DmlExecutionProvider_", // mem-copy suffix, probably followed by an incrementing number
        "_DmlExecutionProvider",  // mem-copy suffix, not followed by an incrementing number
        "_token_",
    };

    const char* const nameStart = fusedNodeArgeName.c_str();
    for (const char* marker : markers)
    {
        if (const char* match = strstr(nameStart, marker))
        {
            // Keep only the characters in front of the marker.
            return fusedNodeArgeName.substr(0, static_cast<size_t>(match - nameStart));
        }
    }

    // No marker present: the name already matches the graph.
    return fusedNodeArgeName;
}
const std::string& GetUniqueNodeName(const onnxruntime::Node& node)
{
@ -85,7 +54,7 @@ namespace Dml::GraphDescBuilder
for (size_t inputIndex = 0; inputIndex < fusedNodeInputDefs.size(); ++inputIndex)
{
const onnxruntime::NodeArg* graphInput = graph.GetNodeArg(
GetFusedNodeArgNameMatchingGraph(fusedNodeInputDefs[inputIndex]->Name()));
GraphKernelHelper::GetFusedNodeArgNameMatchingGraph(fusedNodeInputDefs[inputIndex]->Name()));
if (!graphInput)
{
@ -208,10 +177,10 @@ namespace Dml::GraphDescBuilder
auto iter = nameToFusedNodeInputIndex.find(arg->Name());
// The graph input could be missing the suffix, so try to match without it.
// This is part of a temporary workaround; see comments in GetFusedNodeArgNameMatchingGraph.
// This is part of a temporary workaround; see comments in GraphKernelHelper::GetFusedNodeArgNameMatchingGraph.
if (iter == nameToFusedNodeInputIndex.end())
{
iter = nameToFusedNodeInputIndex.find(GetFusedNodeArgNameMatchingGraph(arg->Name()));
iter = nameToFusedNodeInputIndex.find(GraphKernelHelper::GetFusedNodeArgNameMatchingGraph(arg->Name()));
}
if (iter != nameToFusedNodeInputIndex.end())
@ -277,7 +246,7 @@ namespace Dml::GraphDescBuilder
for (size_t outputIndex = 0; outputIndex < fusedNodeOutputDefs.size(); ++outputIndex)
{
const onnxruntime::NodeArg* graphOutput = graph.GetNodeArg(
GetFusedNodeArgNameMatchingGraph(fusedNodeOutputDefs[outputIndex]->Name()));
GraphKernelHelper::GetFusedNodeArgNameMatchingGraph(fusedNodeOutputDefs[outputIndex]->Name()));
const auto& outputNodeAndIndex = nameToNodeAndIndexMap.at(graphOutput->Name());

View file

@ -109,7 +109,7 @@ namespace GraphKernelHelper
const std::unordered_map<std::string, onnx::TensorProto>& transferredInitializerMap)
{
// Transferred initializers are uploaded to GPU memory
auto iter = transferredInitializerMap.find(fusedNodeInputDefs[index]->Name());
auto iter = transferredInitializerMap.find(GetFusedNodeArgNameMatchingGraph(fusedNodeInputDefs[index]->Name()));
if (iter != transferredInitializerMap.end())
{
return true;
@ -154,7 +154,7 @@ namespace GraphKernelHelper
std::map<const onnx::TensorProto*, uint32_t> initializerToLastInputIndexMap;
for (uint32_t i = 0; i < graphInputCount; i++)
{
auto iter = transferredInitializerMap.find(fusedNodeInputDefs[i]->Name());
auto iter = transferredInitializerMap.find(GetFusedNodeArgNameMatchingGraph(fusedNodeInputDefs[i]->Name()));
if (iter != transferredInitializerMap.end()) {
initializerToLastInputIndexMap[&iter->second] = i;
}
@ -172,7 +172,7 @@ namespace GraphKernelHelper
// initialization or execution). So just throw away the transferred initializer and skip this input.
if (!inputsUsed[i])
{
transferredInitializerMap.erase(fusedNodeInputDefs[i]->Name());
transferredInitializerMap.erase(GetFusedNodeArgNameMatchingGraph(fusedNodeInputDefs[i]->Name()));
if (inputRawData)
{
@ -183,7 +183,7 @@ namespace GraphKernelHelper
}
// Look for the initializer among those transferred from the graph during partitioning
auto iter = transferredInitializerMap.find(fusedNodeInputDefs[i]->Name());
auto iter = transferredInitializerMap.find(GetFusedNodeArgNameMatchingGraph(fusedNodeInputDefs[i]->Name()));
if (iter != transferredInitializerMap.end())
{
std::byte* tensorPtr = nullptr;
@ -320,5 +320,43 @@ namespace GraphKernelHelper
dmlGraphDesc.IntermediateEdgeCount = gsl::narrow_cast<uint32_t>(dmlIntermediateEdges.size());
dmlGraphDesc.IntermediateEdges = dmlIntermediateEdges.data();
}
// TODO: Workaround that strips the suffix appended by the Lotus transforms which insert mem copies.
// It would be unnecessary if Lotus exposed the inputs/outputs in the same order between the kernel
// for a function and the graph for that function exposed as a kernel property. Remove this once the
// ordering mismatch is fixed (WindowsAI: 21114358, Lotus: 1953).
//
// Returns the portion of fusedNodeArgeName that precedes the first occurrence of the
// highest-priority marker present in it, or the name unchanged when no marker is found.
std::string GetFusedNodeArgNameMatchingGraph(const std::string& fusedNodeArgeName)
{
    // Tried in order; the first marker that occurs anywhere in the name decides the cut point.
    static constexpr const char* markers[] =
    {
        "_DmlExecutionProvider_", // mem-copy suffix, probably followed by an incrementing number
        "_DmlExecutionProvider",  // mem-copy suffix, not followed by an incrementing number
        "_token_",
    };

    const char* const nameStart = fusedNodeArgeName.c_str();
    for (const char* marker : markers)
    {
        if (const char* match = strstr(nameStart, marker))
        {
            // Keep only the characters in front of the marker.
            return fusedNodeArgeName.substr(0, static_cast<size_t>(match - nameStart));
        }
    }

    // No marker present: the name already matches the graph.
    return fusedNodeArgeName;
}
} // namespace GraphKernelHelper
} // namespace Dml

View file

@ -63,6 +63,8 @@ namespace GraphKernelHelper
_Out_ std::vector<DML_GRAPH_EDGE_DESC>& dmlInputEdges,
_Out_ std::vector<DML_GRAPH_EDGE_DESC>& dmlOutputEdges,
_Out_ std::vector<DML_GRAPH_EDGE_DESC>& dmlIntermediateEdges);
std::string GetFusedNodeArgNameMatchingGraph(const std::string& fusedNodeArgeName);
} // namespace GraphKernelHelper
} // namespace Dml

View file

@ -36,12 +36,14 @@ public:
m_inputTensorDescs[0] = CreateTensorDescFromInput(kernelCreationContext, 0, TensorAxis::DoNotCoerce, TensorAxis::N, TensorAxis::LeftAligned);
// Massage each of these 1D tensors (of length C) into 4D tensors of the form [1,C,1,1].
// Massage each of these 1D tensors (of length C) into ND tensors of the form [1,C,1,1,...].
for (uint32_t i = Scale; i < OnnxInputIndex::Count; ++i)
{
m_inputTensorDescs[i] = CreateTensorDescFromInput(kernelCreationContext, i, TensorAxis::DoNotCoerce, TensorAxis::C, TensorAxis::LeftAligned);
m_inputTensorDescs[i] = CreateTensorDescFromInput(kernelCreationContext, i, TensorAxis::DoNotCoerce, TensorAxis::C, TensorAxis::LeftAligned, std::nullopt, m_inputTensorDescs[0].GetDimensionCount());
}
m_outputTensorDescs[0] = CreateTensorDescFromOutput(kernelCreationContext, 0, TensorAxis::DoNotCoerce, TensorAxis::N, TensorAxis::LeftAligned, std::nullopt, m_inputTensorDescs[0].GetDimensionCount());
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();

View file

@ -550,6 +550,13 @@ namespace OperatorHelper
++inDim1Iter;
}
// 0-sized dimensions indicate an empty tensor and shouldn't be broadcasted to higher dimensions
if (inDimension0 == 0 || inDimension1 == 0)
{
inDimension0 = 0;
inDimension1 = 0;
}
ML_CHECK_VALID_ARGUMENT((inDimension0 == inDimension1) || (inDimension0 == 1) || (inDimension1 == 1));
*outDimIter = std::max(inDimension0, inDimension1);
}

View file

@ -1273,7 +1273,7 @@ class TileHelper {
ML_CHECK_VALID_ARGUMENT(repeatsTensor.IsCpuData(), "Tile's repeats tensor must be CPU Tensor.");
for (size_t i = 0; i < dimCount; ++i) {
ML_CHECK_VALID_ARGUMENT(repeatsData[i] > 0, "Repeat values should be > 0.");
ML_CHECK_VALID_ARGUMENT(repeatsData[i] >= 0, "Repeat values should be >= 0.");
m_repeatsData.push_back(gsl::narrow_cast<uint32_t>(repeatsData[i]));
}

View file

@ -111,7 +111,7 @@ std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_DML(in
D3D12_COMMAND_QUEUE_DESC cmd_queue_desc = {};
cmd_queue_desc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT;
cmd_queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_NONE;
cmd_queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_DISABLE_GPU_TIMEOUT;
ComPtr<ID3D12CommandQueue> cmd_queue;
THROW_IF_FAILED(d3d12_device->CreateCommandQueue(&cmd_queue_desc, IID_PPV_ARGS(&cmd_queue)));