diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp index 2c0fb0f374..eb8067397c 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp @@ -134,29 +134,6 @@ namespace Dml::GraphDescBuilder &graphNodeInfo ); - // Determine the number of valid inputs and outputs of this node. The graph currently supports opererators - // with unused inputs and outputs only at the end of each list. - uint32_t validOpInputCount = 0; - uint32_t validOpOutputCount = 0; - - for (uint32_t i = 0; i < graphNodeInfo.kernelInputIndices.size(); ++i) - { - if (graphNodeInfo.kernelInputIndices[i] != std::numeric_limits::max()) - { - assert(i - validOpInputCount == 0); - ++validOpInputCount; - } - } - - for (uint32_t i = 0; i < graphNodeInfo.kernelOutputIndices.size(); ++i) - { - if (graphNodeInfo.kernelOutputIndices[i] != std::numeric_limits::max()) - { - assert(i - validOpOutputCount == 0); - ++validOpOutputCount; - } - } - uint32_t nodeIndex = gsl::narrow_cast(graphNodes.size()); AbstractOperatorDesc opDesc = *graphNodeInfo.desc; // Make a copy @@ -166,8 +143,13 @@ namespace Dml::GraphDescBuilder std::vector outputTensorDescs = opDesc.GetOutputTensors(); // Set connections of the new node - for (uint32_t inputIndex = 0; inputIndex < validOpInputCount; ++inputIndex) + for (uint32_t inputIndex = 0; inputIndex < graphNodeInfo.kernelInputIndices.size(); ++inputIndex) { + if (graphNodeInfo.kernelInputIndices[inputIndex] == std::numeric_limits::max()) + { + continue; + } + uint32_t kernelInputIndex = graphNodeInfo.kernelInputIndices[inputIndex]; const onnxruntime::NodeArg* arg = node.InputDefs()[kernelInputIndex]; @@ -219,8 +201,13 @@ namespace Dml::GraphDescBuilder // Store the new node for lookup when downstream nodes consume it. - for (uint32_t outputIndex = 0; outputIndex < validOpOutputCount; ++outputIndex) + for (uint32_t outputIndex = 0; outputIndex < graphNodeInfo.kernelOutputIndices.size(); ++outputIndex) { + if (graphNodeInfo.kernelOutputIndices[outputIndex] == std::numeric_limits::max()) + { + continue; + } + uint32_t kernelOutputIndex = graphNodeInfo.kernelOutputIndices[outputIndex]; const onnxruntime::NodeArg* arg = node.OutputDefs()[kernelOutputIndex]; if (arg->Exists()) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 3292e9d6fb..413971166c 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -576,10 +576,10 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 9, MaxUnpool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp, requiredConstantCpuInputs(2))}, {REG_INFO( 11, MaxUnpool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp, requiredConstantCpuInputs(2))}, // 11 is identical to 9. - {REG_INFO( 10, QLinearConv, typeNameListFour, supportedTypeListQLinearConv, DmlGraphSupport::NotSupported)}, - {REG_INFO( 10, QLinearMatMul, typeNameListThree, supportedTypeListQLinearMatMul, DmlGraphSupport::NotSupported)}, - {REG_INFO( 10, MatMulInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::NotSupported)}, - {REG_INFO( 10, ConvInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::NotSupported)}, + {REG_INFO( 10, QLinearConv, typeNameListFour, supportedTypeListQLinearConv, DmlGraphSupport::Supported)}, + {REG_INFO( 10, QLinearMatMul, typeNameListThree, supportedTypeListQLinearMatMul, DmlGraphSupport::Supported)}, + {REG_INFO( 10, MatMulInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)}, + {REG_INFO( 10, ConvInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)}, }; template diff --git a/onnxruntime/python/tools/quantization/E2E_example_model/image_classification/cpu/test_images/daisy.jpg b/onnxruntime/python/tools/quantization/E2E_example_model/image_classification/cpu/test_images/daisy.jpg old mode 100644 new mode 100755 diff --git a/onnxruntime/python/tools/quantization/E2E_example_model/image_classification/cpu/test_images/rose.jpg b/onnxruntime/python/tools/quantization/E2E_example_model/image_classification/cpu/test_images/rose.jpg old mode 100644 new mode 100755 diff --git a/onnxruntime/python/tools/quantization/E2E_example_model/image_classification/cpu/test_images/tulip.jpg b/onnxruntime/python/tools/quantization/E2E_example_model/image_classification/cpu/test_images/tulip.jpg old mode 100644 new mode 100755