mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-01 23:30:35 +00:00
DirectML Execution Provider integration 2020-11-13 (#5809)
* Merged PR 5253310: Fix 0-sized dimension broadcasting Tensors that contain 0-sized dimensions were being broadcasted to higher dimensions, which would remove the possibility to remove them from the graph. 0-sized dimensions represent empty tensors, so whatever operator needs to broadcast it shouldn't try to call into DML. * Merged PR 5334334: Fix asserts and failure in GraphKernelHelper.cpp This extends a workaround needed to match node inputs with Tensors to the EP code handling constant input upload. This was causing issues in a couple of models, including EfficientDet, although that model still fails due to this bug: https://microsoft.visualstudio.com/OS/_workitems/edit/29970551 Related work items: #29706035 * Merged PR 5344477: Disable GPU timeouts in DML EP command queue creation GPU timeouts have already been disabled in command queues created by Winml, but not the ones created by the DML EP within the ORT API * Merged PR 5380534: BatchNormalization failure in autopilot - fix output size New validation [here](https://microsoft.visualstudio.com/DefaultCollection/WindowsAI/_git/WindowsAI/pullrequest/5354070?_a=files&path=%2Fdml%2FSharedValidation%2FDmlBatchNormalizationOperatorValidator.h) causes some BatchNorm cases to fail (e.g. OnnxConformanceTestsTaef::BatchNormalization (BatchNormalization_2x2x2)). I'm unsure how long this bug existed, but based on Nick's investigation, it apparently still worked anyway. Related work items: #27678610 * Merged PR 5386132: Update 8D BatchNorm Update 8D BatchNorm Related work items: #27678610 * Merged PR 5390213: Tile allow 0 in repeats 0 is valid in Tile in "repeats" parameter. The CPU kernel handles it fine. So should the DML EP. Related work items: #29970551 Co-authored-by: Justin Stoecker Co-authored-by: Jeff Bloomfield Co-authored-by: Patrice Vignola Co-authored-by: Nick Feeney
This commit is contained in:
parent
339348bc46
commit
732ffd12d2
7 changed files with 62 additions and 44 deletions
|
|
@ -3,43 +3,12 @@
|
|||
|
||||
#include "precomp.h"
|
||||
#include "GraphDescBuilder.h"
|
||||
#include "GraphKernelHelper.h"
|
||||
|
||||
using namespace Windows::AI::MachineLearning::Adapter;
|
||||
|
||||
namespace Dml::GraphDescBuilder
|
||||
{
|
||||
// TODO: This is a hack which strips the suffix added within Lotus transforms that insert mem copies.
|
||||
// This shouldn't be necessary if Lotus exposes the inputs/ouputs in the same order between the kernel
|
||||
// for a function, and the graph for that function exposed as a kernel property. When the ordering
|
||||
// mismatch is fixed (WindowsAI: 21114358, Lotus: 1953), this workaround should be removed.
|
||||
static std::string GetFusedNodeArgNameMatchingGraph(const std::string& fusedNodeArgeName)
|
||||
{
|
||||
const char* suffix = nullptr;
|
||||
|
||||
// The suffix used when inserting mem copies is equal to the below, probably followed by an incrementing number.
|
||||
if (!suffix) {
|
||||
suffix = strstr(fusedNodeArgeName.c_str(), "_DmlExecutionProvider_");
|
||||
}
|
||||
|
||||
// The suffix used when inserting mem copies is equal to the below, not followed by an incrementing number.
|
||||
if (!suffix) {
|
||||
suffix = strstr(fusedNodeArgeName.c_str(), "_DmlExecutionProvider");
|
||||
}
|
||||
|
||||
if (!suffix) {
|
||||
suffix = strstr(fusedNodeArgeName.c_str(), "_token_");
|
||||
}
|
||||
|
||||
if (suffix)
|
||||
{
|
||||
return std::string(
|
||||
fusedNodeArgeName.begin(),
|
||||
fusedNodeArgeName.begin() + (suffix - fusedNodeArgeName.c_str())
|
||||
);
|
||||
} else {
|
||||
return fusedNodeArgeName;
|
||||
}
|
||||
}
|
||||
|
||||
const std::string& GetUniqueNodeName(const onnxruntime::Node& node)
|
||||
{
|
||||
|
|
@ -85,7 +54,7 @@ namespace Dml::GraphDescBuilder
|
|||
for (size_t inputIndex = 0; inputIndex < fusedNodeInputDefs.size(); ++inputIndex)
|
||||
{
|
||||
const onnxruntime::NodeArg* graphInput = graph.GetNodeArg(
|
||||
GetFusedNodeArgNameMatchingGraph(fusedNodeInputDefs[inputIndex]->Name()));
|
||||
GraphKernelHelper::GetFusedNodeArgNameMatchingGraph(fusedNodeInputDefs[inputIndex]->Name()));
|
||||
|
||||
if (!graphInput)
|
||||
{
|
||||
|
|
@ -208,10 +177,10 @@ namespace Dml::GraphDescBuilder
|
|||
auto iter = nameToFusedNodeInputIndex.find(arg->Name());
|
||||
|
||||
// The graph input could be missing the suffix, so try to match without it.
|
||||
// This is part of a temporary workaround; see comments in GetFusedNodeArgNameMatchingGraph.
|
||||
// This is part of a temporary workaround; see comments in GraphKernelHelper::GetFusedNodeArgNameMatchingGraph.
|
||||
if (iter == nameToFusedNodeInputIndex.end())
|
||||
{
|
||||
iter = nameToFusedNodeInputIndex.find(GetFusedNodeArgNameMatchingGraph(arg->Name()));
|
||||
iter = nameToFusedNodeInputIndex.find(GraphKernelHelper::GetFusedNodeArgNameMatchingGraph(arg->Name()));
|
||||
}
|
||||
|
||||
if (iter != nameToFusedNodeInputIndex.end())
|
||||
|
|
@ -277,7 +246,7 @@ namespace Dml::GraphDescBuilder
|
|||
for (size_t outputIndex = 0; outputIndex < fusedNodeOutputDefs.size(); ++outputIndex)
|
||||
{
|
||||
const onnxruntime::NodeArg* graphOutput = graph.GetNodeArg(
|
||||
GetFusedNodeArgNameMatchingGraph(fusedNodeOutputDefs[outputIndex]->Name()));
|
||||
GraphKernelHelper::GetFusedNodeArgNameMatchingGraph(fusedNodeOutputDefs[outputIndex]->Name()));
|
||||
|
||||
const auto& outputNodeAndIndex = nameToNodeAndIndexMap.at(graphOutput->Name());
|
||||
|
||||
|
|
|
|||
|
|
@ -109,7 +109,7 @@ namespace GraphKernelHelper
|
|||
const std::unordered_map<std::string, onnx::TensorProto>& transferredInitializerMap)
|
||||
{
|
||||
// Transferred initializers are uploaded to GPU memory
|
||||
auto iter = transferredInitializerMap.find(fusedNodeInputDefs[index]->Name());
|
||||
auto iter = transferredInitializerMap.find(GetFusedNodeArgNameMatchingGraph(fusedNodeInputDefs[index]->Name()));
|
||||
if (iter != transferredInitializerMap.end())
|
||||
{
|
||||
return true;
|
||||
|
|
@ -154,7 +154,7 @@ namespace GraphKernelHelper
|
|||
std::map<const onnx::TensorProto*, uint32_t> initializerToLastInputIndexMap;
|
||||
for (uint32_t i = 0; i < graphInputCount; i++)
|
||||
{
|
||||
auto iter = transferredInitializerMap.find(fusedNodeInputDefs[i]->Name());
|
||||
auto iter = transferredInitializerMap.find(GetFusedNodeArgNameMatchingGraph(fusedNodeInputDefs[i]->Name()));
|
||||
if (iter != transferredInitializerMap.end()) {
|
||||
initializerToLastInputIndexMap[&iter->second] = i;
|
||||
}
|
||||
|
|
@ -172,7 +172,7 @@ namespace GraphKernelHelper
|
|||
// initialization or execution). So just throw away the transferred initializer and skip this input.
|
||||
if (!inputsUsed[i])
|
||||
{
|
||||
transferredInitializerMap.erase(fusedNodeInputDefs[i]->Name());
|
||||
transferredInitializerMap.erase(GetFusedNodeArgNameMatchingGraph(fusedNodeInputDefs[i]->Name()));
|
||||
|
||||
if (inputRawData)
|
||||
{
|
||||
|
|
@ -183,7 +183,7 @@ namespace GraphKernelHelper
|
|||
}
|
||||
|
||||
// Look for the initializer among those transferred from the graph during partitioning
|
||||
auto iter = transferredInitializerMap.find(fusedNodeInputDefs[i]->Name());
|
||||
auto iter = transferredInitializerMap.find(GetFusedNodeArgNameMatchingGraph(fusedNodeInputDefs[i]->Name()));
|
||||
if (iter != transferredInitializerMap.end())
|
||||
{
|
||||
std::byte* tensorPtr = nullptr;
|
||||
|
|
@ -320,5 +320,43 @@ namespace GraphKernelHelper
|
|||
dmlGraphDesc.IntermediateEdgeCount = gsl::narrow_cast<uint32_t>(dmlIntermediateEdges.size());
|
||||
dmlGraphDesc.IntermediateEdges = dmlIntermediateEdges.data();
|
||||
}
|
||||
|
||||
// TODO: This is a hack which strips the suffix added within Lotus transforms that insert mem copies.
|
||||
// This shouldn't be necessary if Lotus exposes the inputs/ouputs in the same order between the kernel
|
||||
// for a function, and the graph for that function exposed as a kernel property. When the ordering
|
||||
// mismatch is fixed (WindowsAI: 21114358, Lotus: 1953), this workaround should be removed.
|
||||
std::string GetFusedNodeArgNameMatchingGraph(const std::string& fusedNodeArgeName)
|
||||
{
|
||||
const char* suffix = nullptr;
|
||||
|
||||
// The suffix used when inserting mem copies is equal to the below, probably followed by an incrementing number.
|
||||
if (!suffix)
|
||||
{
|
||||
suffix = strstr(fusedNodeArgeName.c_str(), "_DmlExecutionProvider_");
|
||||
}
|
||||
|
||||
// The suffix used when inserting mem copies is equal to the below, not followed by an incrementing number.
|
||||
if (!suffix)
|
||||
{
|
||||
suffix = strstr(fusedNodeArgeName.c_str(), "_DmlExecutionProvider");
|
||||
}
|
||||
|
||||
if (!suffix)
|
||||
{
|
||||
suffix = strstr(fusedNodeArgeName.c_str(), "_token_");
|
||||
}
|
||||
|
||||
if (suffix)
|
||||
{
|
||||
return std::string(
|
||||
fusedNodeArgeName.begin(),
|
||||
fusedNodeArgeName.begin() + (suffix - fusedNodeArgeName.c_str())
|
||||
);
|
||||
}
|
||||
else
|
||||
{
|
||||
return fusedNodeArgeName;
|
||||
}
|
||||
}
|
||||
} // namespace GraphKernelHelper
|
||||
} // namespace Dml
|
||||
|
|
@ -63,6 +63,8 @@ namespace GraphKernelHelper
|
|||
_Out_ std::vector<DML_GRAPH_EDGE_DESC>& dmlInputEdges,
|
||||
_Out_ std::vector<DML_GRAPH_EDGE_DESC>& dmlOutputEdges,
|
||||
_Out_ std::vector<DML_GRAPH_EDGE_DESC>& dmlIntermediateEdges);
|
||||
|
||||
std::string GetFusedNodeArgNameMatchingGraph(const std::string& fusedNodeArgeName);
|
||||
|
||||
} // namespace GraphKernelHelper
|
||||
} // namespace Dml
|
||||
|
|
@ -36,12 +36,14 @@ public:
|
|||
|
||||
m_inputTensorDescs[0] = CreateTensorDescFromInput(kernelCreationContext, 0, TensorAxis::DoNotCoerce, TensorAxis::N, TensorAxis::LeftAligned);
|
||||
|
||||
// Massage each of these 1D tensors (of length C) into 4D tensors of the form [1,C,1,1].
|
||||
// Massage each of these 1D tensors (of length C) into ND tensors of the form [1,C,1,1,...].
|
||||
for (uint32_t i = Scale; i < OnnxInputIndex::Count; ++i)
|
||||
{
|
||||
m_inputTensorDescs[i] = CreateTensorDescFromInput(kernelCreationContext, i, TensorAxis::DoNotCoerce, TensorAxis::C, TensorAxis::LeftAligned);
|
||||
m_inputTensorDescs[i] = CreateTensorDescFromInput(kernelCreationContext, i, TensorAxis::DoNotCoerce, TensorAxis::C, TensorAxis::LeftAligned, std::nullopt, m_inputTensorDescs[0].GetDimensionCount());
|
||||
}
|
||||
|
||||
m_outputTensorDescs[0] = CreateTensorDescFromOutput(kernelCreationContext, 0, TensorAxis::DoNotCoerce, TensorAxis::N, TensorAxis::LeftAligned, std::nullopt, m_inputTensorDescs[0].GetDimensionCount());
|
||||
|
||||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
|
||||
|
|
|
|||
|
|
@ -550,6 +550,13 @@ namespace OperatorHelper
|
|||
++inDim1Iter;
|
||||
}
|
||||
|
||||
// 0-sized dimensions indicate an empty tensor and shouldn't be broadcasted to higher dimensions
|
||||
if (inDimension0 == 0 || inDimension1 == 0)
|
||||
{
|
||||
inDimension0 = 0;
|
||||
inDimension1 = 0;
|
||||
}
|
||||
|
||||
ML_CHECK_VALID_ARGUMENT((inDimension0 == inDimension1) || (inDimension0 == 1) || (inDimension1 == 1));
|
||||
*outDimIter = std::max(inDimension0, inDimension1);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1273,7 +1273,7 @@ class TileHelper {
|
|||
ML_CHECK_VALID_ARGUMENT(repeatsTensor.IsCpuData(), "Tile's repeats tensor must be CPU Tensor.");
|
||||
|
||||
for (size_t i = 0; i < dimCount; ++i) {
|
||||
ML_CHECK_VALID_ARGUMENT(repeatsData[i] > 0, "Repeat values should be > 0.");
|
||||
ML_CHECK_VALID_ARGUMENT(repeatsData[i] >= 0, "Repeat values should be >= 0.");
|
||||
m_repeatsData.push_back(gsl::narrow_cast<uint32_t>(repeatsData[i]));
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -111,7 +111,7 @@ std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_DML(in
|
|||
|
||||
D3D12_COMMAND_QUEUE_DESC cmd_queue_desc = {};
|
||||
cmd_queue_desc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT;
|
||||
cmd_queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_NONE;
|
||||
cmd_queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_DISABLE_GPU_TIMEOUT;
|
||||
|
||||
ComPtr<ID3D12CommandQueue> cmd_queue;
|
||||
THROW_IF_FAILED(d3d12_device->CreateCommandQueue(&cmd_queue_desc, IID_PPV_ARGS(&cmd_queue)));
|
||||
|
|
|
|||
Loading…
Reference in a new issue