From 67f6b7ce293abb7ef57e9d5dda1d86e839d10ff4 Mon Sep 17 00:00:00 2001 From: Sheil Kumar Date: Thu, 11 Aug 2022 16:01:27 -0700 Subject: [PATCH] DirectML GEMM broken in opset 11 and 13 when optional tensor C not provided (#12568) Set kernel input indices to be fixed to 0,1,2. C input is now optional, so last tensor must be specified. --- .../DmlExecutionProvider/src/Operators/DmlOperatorGemm.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGemm.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGemm.cpp index a559828d63..8c6d8d062f 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGemm.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGemm.cpp @@ -10,12 +10,13 @@ class DmlOperatorGemm : public DmlOperator, public GemmHelper { public: DmlOperatorGemm(const MLOperatorKernelCreationContext& kernelInfo) - : DmlOperator(kernelInfo), + : DmlOperator(kernelInfo), GemmHelper(kernelInfo, kernelInfo.GetTensorShapeDescription()) { ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() >= 2); ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1); - DmlOperator::Initialize(kernelInfo); + auto kernelInputIndices = std::vector> { 0, 1, 2 }; + DmlOperator::Initialize(kernelInfo, kernelInputIndices); bool containsBiasTensor = kernelInfo.IsInputValid(2);