diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp index 1340ff7c8a..32ba2abe50 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp @@ -477,9 +477,9 @@ namespace Dml std::optional requiredInputCount = internalRegInfo->graphNodeFactoryRegistration->requiredInputCount; if (requiredCpuInputsConstant && TryGetStaticInputShapes( node, graphNodeProperty.first->second.inputShapes) && - !ContainsEmptyDimensions(graphNodeProperty.first->second.inputShapes) && + !ContainsEmptyDimensions(graphNodeProperty.first->second.inputShapes, internalRegInfo->requiredConstantCpuInputs) && TryGetStaticOutputShapes(node, graphNodeProperty.first->second.outputShapes) && - !ContainsEmptyDimensions(graphNodeProperty.first->second.outputShapes) && + !ContainsEmptyDimensions(graphNodeProperty.first->second.outputShapes, internalRegInfo->requiredConstantCpuInputs) && (requiredInputCount == std::nullopt || *requiredInputCount == node.InputDefs().size())) { *isDmlGraphNode = true; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp index 47cbcae00f..d49bcfb370 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp @@ -1876,12 +1876,13 @@ bool TryGetStaticOutputShapes(const onnxruntime::Node& node, EdgeShapes& outputS return true; } -bool ContainsEmptyDimensions(const EdgeShapes& shapes) { +bool ContainsEmptyDimensions(const EdgeShapes& shapes, gsl::span ignoredShapeIndices) { for (size_t i = 0; i < shapes.EdgeCount(); i++) { const std::vector& shape = shapes.GetShape(i); - if (std::find(shape.begin(), shape.end(), 0) != shape.end()) { - return true; + if (std::find(shape.begin(), shape.end(), 0) != shape.end() && + std::find(ignoredShapeIndices.begin(), ignoredShapeIndices.end(), i) == ignoredShapeIndices.end()) { + return true; } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h index 0168da24ef..216faead4c 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h @@ -637,7 +637,7 @@ onnx::AttributeProto_AttributeType ToProto(MLOperatorAttributeType type); bool TryGetStaticInputShapes(const onnxruntime::Node& node, EdgeShapes& inputShapes); bool TryGetStaticOutputShapes(const onnxruntime::Node& node, EdgeShapes& outputShapes); -bool ContainsEmptyDimensions(const EdgeShapes& shapes); +bool ContainsEmptyDimensions(const EdgeShapes& shapes, gsl::span ignoredShapeIndices = gsl::span()); std::tuple, size_t> UnpackTensor(const onnx::TensorProto& initializer); } // namespace Windows::AI::MachineLearning::Adapter