From f87527c0dfaec649cd4f4c4ee8de9a57e986770a Mon Sep 17 00:00:00 2001 From: Jeff Bloomfield Date: Wed, 31 Mar 2021 19:06:08 +0000 Subject: [PATCH] Merged PR 5861108: Allow nodes in DML graph partitions with empty shapes on constant CPU inputs Resize is spec'd to ignore the "roi" tensor in certain modes. For some reason, converters are specifying an arbitrary value for this tensor, even though it's optional. This makes the graph partitioner skip a check for empty shape dimensions for tensors such as this, which the DML kernel registers as consuming as CPU inputs. Otherwise, the node is not included in DML graph partitions, because the DML graph doesn't handle empty dimensions. Related work items: #32221164 --- .../dml/DmlExecutionProvider/src/GraphPartitioner.cpp | 4 ++-- .../dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp | 7 ++++--- .../dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h | 2 +- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp index 1340ff7c8a..32ba2abe50 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp @@ -477,9 +477,9 @@ namespace Dml std::optional requiredInputCount = internalRegInfo->graphNodeFactoryRegistration->requiredInputCount; if (requiredCpuInputsConstant && TryGetStaticInputShapes( node, graphNodeProperty.first->second.inputShapes) && - !ContainsEmptyDimensions(graphNodeProperty.first->second.inputShapes) && + !ContainsEmptyDimensions(graphNodeProperty.first->second.inputShapes, internalRegInfo->requiredConstantCpuInputs) && TryGetStaticOutputShapes(node, graphNodeProperty.first->second.outputShapes) && - !ContainsEmptyDimensions(graphNodeProperty.first->second.outputShapes) && + !ContainsEmptyDimensions(graphNodeProperty.first->second.outputShapes, internalRegInfo->requiredConstantCpuInputs) && (requiredInputCount == std::nullopt || *requiredInputCount == node.InputDefs().size())) { *isDmlGraphNode = true; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp index 47cbcae00f..d49bcfb370 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp @@ -1876,12 +1876,13 @@ bool TryGetStaticOutputShapes(const onnxruntime::Node& node, EdgeShapes& outputS return true; } -bool ContainsEmptyDimensions(const EdgeShapes& shapes) { +bool ContainsEmptyDimensions(const EdgeShapes& shapes, gsl::span ignoredShapeIndices) { for (size_t i = 0; i < shapes.EdgeCount(); i++) { const std::vector& shape = shapes.GetShape(i); - if (std::find(shape.begin(), shape.end(), 0) != shape.end()) { - return true; + if (std::find(shape.begin(), shape.end(), 0) != shape.end() && + std::find(ignoredShapeIndices.begin(), ignoredShapeIndices.end(), i) == ignoredShapeIndices.end()) { + return true; } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h index 0168da24ef..216faead4c 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h @@ -637,7 +637,7 @@ onnx::AttributeProto_AttributeType ToProto(MLOperatorAttributeType type); bool TryGetStaticInputShapes(const onnxruntime::Node& node, EdgeShapes& inputShapes); bool TryGetStaticOutputShapes(const onnxruntime::Node& node, EdgeShapes& outputShapes); -bool ContainsEmptyDimensions(const EdgeShapes& shapes); +bool ContainsEmptyDimensions(const EdgeShapes& shapes, gsl::span ignoredShapeIndices = gsl::span()); std::tuple, size_t> UnpackTensor(const onnx::TensorProto& initializer); } // namespace Windows::AI::MachineLearning::Adapter