From cbdd0bb7292b6344ebca8f484446ffd9af63aec4 Mon Sep 17 00:00:00 2001 From: Xiang Zhang Date: Wed, 26 Jul 2023 15:31:09 -0700 Subject: [PATCH] QAttention calls into MatMulIntToFloat instead of Dequantize+GEMM (#16851) ### Description Update QAttention to call into MatMulIntToFloat instead of Dequantize+GEMM, enabling more metacommand paths. --- .../DmlOperatorDynamicQuantizeMatMul.cpp | 2 +- .../src/Operators/DmlOperatorQAttention.cpp | 224 +++++++----------- 2 files changed, 81 insertions(+), 145 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorDynamicQuantizeMatMul.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorDynamicQuantizeMatMul.cpp index 77b01f35f6..fa8d0076cb 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorDynamicQuantizeMatMul.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorDynamicQuantizeMatMul.cpp @@ -105,7 +105,7 @@ public: matrixMultiplyIntergerToFloatOperatorDesc.BiasTensor = hasBias? 
&inputDescs[OnnxInputIndex::Bias] : nullptr; matrixMultiplyIntergerToFloatOperatorDesc.OutputTensor = &outputDescs[0]; - const DML_OPERATOR_DESC opDesc2{ (DML_OPERATOR_TYPE)DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT, &matrixMultiplyIntergerToFloatOperatorDesc}; + const DML_OPERATOR_DESC opDesc2{ static_cast(DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT), &matrixMultiplyIntergerToFloatOperatorDesc}; MLOperatorGraphDesc operatorGraphDesc = {}; std::vector opDescs{&opDesc1, &opDesc2}; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp index 92839e5373..6e0785c91a 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp @@ -10,9 +10,7 @@ Abbreviations: B is batch_size, S is sequence_length, W is hidden_size M A B C // M, A, B, and C are Inputs | | | / - | Dequantize / - | \ | / - | Gemm + | MatMulIntToFloat | / | \ | / | \ | / | \ @@ -133,23 +131,9 @@ public: kernelCreationContext.GetInputEdgeDescription(weightsIndex).tensorDataType, desiredWeightTensorShape, weightTensorShape); - m_inputTensorDescs[inputScaleIndex] = TensorDesc::ConstructBroadcastedTensorDesc( - kernelCreationContext.GetInputEdgeDescription(inputScaleIndex).tensorDataType, - inputTensorShape, - m_inputTensorDescs[inputScaleIndex].GetSizes()); - m_inputTensorDescs[inputZeroPointIndex] = TensorDesc::ConstructBroadcastedTensorDesc( - kernelCreationContext.GetInputEdgeDescription(inputZeroPointIndex).tensorDataType, - inputTensorShape, - m_inputTensorDescs[inputZeroPointIndex].GetSizes()); - m_inputTensorDescs[weightScaleIndex] = TensorDesc::ConstructBroadcastedTensorDesc( - kernelCreationContext.GetInputEdgeDescription(weightScaleIndex).tensorDataType, - desiredWeightTensorShape, - 
m_inputTensorDescs[weightScaleIndex].GetSizes()); - m_inputTensorDescs[weightZeroPointIndex] = TensorDesc::ConstructBroadcastedTensorDesc( - kernelCreationContext.GetInputEdgeDescription(weightZeroPointIndex).tensorDataType, - desiredWeightTensorShape, - m_inputTensorDescs[weightZeroPointIndex].GetSizes()); + uint32_t desiredBiasTensorShape[3] = {batchSize, sequenceLength, hiddenSize + hiddenSize + vHiddenSize}; + if (hasBias) { auto biasTensorShape = m_inputTensorDescs[biasIndex].GetSizes(); @@ -190,7 +174,7 @@ public: hasMaxSequenceMask = true; ML_CHECK_VALID_ARGUMENT(maskIndexTensorShape[2] == maskIndexTensorShape[3]); const uint32_t maxSequenceLength = maskIndexTensorShape[2]; - uint32_t desiredMaskIndexShape[4] {batchSize, numHeads, maxSequenceLength, maxSequenceLength}; + uint32_t desiredMaskIndexShape[4] = {batchSize, numHeads, maxSequenceLength, maxSequenceLength}; maskTensorDataType = kernelCreationContext.GetInputEdgeDescription(maskIndex).tensorDataType; m_inputTensorDescs[maskIndex] = TensorDesc::ConstructBroadcastedTensorDesc(maskTensorDataType, desiredMaskIndexShape, reshapedMaskIndexTensorShape); } @@ -198,65 +182,36 @@ public: { uint32_t maskIndexDimensionCount = gsl::narrow_cast(maskIndexTensorShape.size()); reshapedMaskIndexTensorShape.insert(reshapedMaskIndexTensorShape.begin() + 1, 4 - maskIndexDimensionCount, 1); - uint32_t desiredMaskIndexShape[4] {batchSize, numHeads, sequenceLength, sequenceLength}; + uint32_t desiredMaskIndexShape[4] = {batchSize, numHeads, sequenceLength, sequenceLength}; maskTensorDataType = kernelCreationContext.GetInputEdgeDescription(maskIndex).tensorDataType; m_inputTensorDescs[maskIndex] = TensorDesc::ConstructBroadcastedTensorDesc(maskTensorDataType, desiredMaskIndexShape, reshapedMaskIndexTensorShape); } } } - TensorDesc firstGemmOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, desiredBiasTensorShape); - DML_TENSOR_DESC namedFirstGemmOutputTensorDesc = firstGemmOutputTensorDesc.GetDmlDesc(); + 
TensorDesc matMulIntToFloatOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, desiredBiasTensorShape); + DML_TENSOR_DESC namedMatMulIntToFloatOutputTensorDesc = matMulIntToFloatOutputTensorDesc.GetDmlDesc(); std::vector inputDescs = GetDmlInputDescs(); std::vector outputDescs = GetDmlOutputDescs(); - // output edge between Dequantize and first GEMM node - TensorDesc intermediateInputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputTensorShape); + DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC matMulIntToFloatOperatorDesc = {}; + matMulIntToFloatOperatorDesc.ATensor = &inputDescs[InputIndex::inputIndex]; + matMulIntToFloatOperatorDesc.AScaleTensor = &inputDescs[InputIndex::inputScaleIndex]; + matMulIntToFloatOperatorDesc.AZeroPointTensor = &inputDescs[InputIndex::inputZeroPointIndex]; + matMulIntToFloatOperatorDesc.BTensor = &inputDescs[InputIndex::weightsIndex]; + matMulIntToFloatOperatorDesc.BScaleTensor = &inputDescs[InputIndex::weightScaleIndex]; + matMulIntToFloatOperatorDesc.BZeroPointTensor = &inputDescs[InputIndex::weightZeroPointIndex]; + matMulIntToFloatOperatorDesc.BiasTensor = hasBias ? 
&inputDescs[InputIndex::biasIndex] : nullptr; + matMulIntToFloatOperatorDesc.OutputTensor = &namedMatMulIntToFloatOutputTensorDesc; - TensorDesc intermediateWeightTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, desiredWeightTensorShape); + const DML_OPERATOR_DESC matMulIntToFloatDesc = { static_cast(DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT), &matMulIntToFloatOperatorDesc}; - DML_TENSOR_DESC namedIntermediateInputTensorDesc = intermediateInputTensorDesc.GetDmlDesc(); - DML_TENSOR_DESC namedIntermediateWeightTensorDesc = intermediateWeightTensorDesc.GetDmlDesc(); - - DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_DESC inputDequantizeOperatorDesc = {}; - inputDequantizeOperatorDesc.InputTensor = &inputDescs[InputIndex::inputIndex]; - inputDequantizeOperatorDesc.ScaleTensor = &inputDescs[InputIndex::inputScaleIndex]; - inputDequantizeOperatorDesc.ZeroPointTensor = &inputDescs[InputIndex::inputZeroPointIndex]; - inputDequantizeOperatorDesc.OutputTensor = &namedIntermediateInputTensorDesc; - - const DML_OPERATOR_DESC inputDequantizeOpDesc{DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR, &inputDequantizeOperatorDesc}; - - DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_DESC weightDequantizeOperatorDesc = {}; - weightDequantizeOperatorDesc.InputTensor = &inputDescs[InputIndex::weightsIndex]; - weightDequantizeOperatorDesc.ScaleTensor = &inputDescs[InputIndex::weightScaleIndex]; - weightDequantizeOperatorDesc.ZeroPointTensor = &inputDescs[InputIndex::weightZeroPointIndex]; - weightDequantizeOperatorDesc.OutputTensor = &namedIntermediateWeightTensorDesc; - - const DML_OPERATOR_DESC weightDequantizeOpDesc{DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR, &weightDequantizeOperatorDesc}; - - DML_GEMM_OPERATOR_DESC gemmOperatorDesc = {}; - gemmOperatorDesc.ATensor = inputDequantizeOperatorDesc.OutputTensor; - gemmOperatorDesc.BTensor = weightDequantizeOperatorDesc.OutputTensor; - - if (hasBias) - { - gemmOperatorDesc.CTensor = &inputDescs[2]; - } - - 
gemmOperatorDesc.OutputTensor = &namedFirstGemmOutputTensorDesc; - gemmOperatorDesc.TransA = DML_MATRIX_TRANSFORM_NONE; - gemmOperatorDesc.TransB = DML_MATRIX_TRANSFORM_NONE; - gemmOperatorDesc.Alpha = 1.0f; - gemmOperatorDesc.Beta = 1.0f; - gemmOperatorDesc.FusedActivation = nullptr; - const DML_OPERATOR_DESC gemmDesc {DML_OPERATOR_GEMM, &gemmOperatorDesc}; - - std::array queryKeySlicedTensorShape {batchSize, sequenceLength, hiddenSize + hiddenSize}; + std::array queryKeySlicedTensorShape = {batchSize, sequenceLength, hiddenSize + hiddenSize}; TensorDesc queryKeySlicedInputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, queryKeySlicedTensorShape); DML_TENSOR_DESC namedQueryKeySlicedInputTensorDesc = queryKeySlicedInputTensorDesc.GetDmlDesc(); - std::array valueSlicedTensorShape {batchSize, sequenceLength, vHiddenSize}; + std::array valueSlicedTensorShape = {batchSize, sequenceLength, vHiddenSize}; TensorDesc valueSlicedInputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, valueSlicedTensorShape); DML_TENSOR_DESC namedValueSlicedInputTensorDesc = valueSlicedInputTensorDesc.GetDmlDesc(); @@ -282,7 +237,7 @@ public: DML_TENSOR_DESC namedQueryKeyTransposedOutputTensorDesc = queryKeyTransposedOutputTensorDesc.GetDmlDesc(); // Transpose QKV from [batchSize, sequenceLength, 3, numHeads, headSize] to [batchSize, sequenceLength, numHeads, 3, headSize] - std::array queryKeyValueTransposedTensorShape {batchSize, sequenceLength, numHeads, 3, headSize}; + std::array queryKeyValueTransposedTensorShape = {batchSize, sequenceLength, numHeads, 3, headSize}; std::array queryKeyValueTransposedStrides = { sequenceLength * numHeads * 3 * headSize, numHeads * 3 * headSize, @@ -317,14 +272,14 @@ public: DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC transposeOperatorDesc = {}; if (hasSlicedValue) { - queryKeySlicedOperatorDesc.InputTensor = &namedFirstGemmOutputTensorDesc; + queryKeySlicedOperatorDesc.InputTensor = &namedMatMulIntToFloatOutputTensorDesc; 
queryKeySlicedOperatorDesc.OutputTensor = &namedQueryKeySlicedInputTensorDesc; queryKeySlicedOperatorDesc.DimensionCount = gsl::narrow_cast(queryKeySlicedTensorShape.size()); queryKeySlicedOperatorDesc.InputWindowOffsets = queryKeySliceOffset.data(); queryKeySlicedOperatorDesc.InputWindowSizes = queryKeySliceSize.data(); queryKeySlicedOperatorDesc.InputWindowStrides = queryKeySliceStrides.data(); - valueSlicedOperatorDesc.InputTensor = &namedFirstGemmOutputTensorDesc; + valueSlicedOperatorDesc.InputTensor = &namedMatMulIntToFloatOutputTensorDesc; valueSlicedOperatorDesc.OutputTensor = &namedValueSlicedInputTensorDesc; valueSlicedOperatorDesc.DimensionCount = gsl::narrow_cast(valueSlicedTensorShape.size()); valueSlicedOperatorDesc.InputWindowOffsets = valueSliceOffset.data(); @@ -378,7 +333,6 @@ public: mhaOperatorDesc.MaskTensor = hasMask ? &inputDescs[maskIndex] : nullptr; } - // mhaOperatorDesc.RelativePositionBiasTensor = hasRelativePositionBias ? &inputDescs[dmlRelativePositionBiasIndex] : nullptr; mhaOperatorDesc.RelativePositionBiasTensor = nullptr; mhaOperatorDesc.OutputTensor = &outputDescs[outputIndex]; mhaOperatorDesc.Scale = kernelCreationContext.GetOptionalAttribute(AttrName::Scale, gsl::narrow_cast(1.0f / std::sqrt(headSize))); @@ -393,16 +347,12 @@ public: std::vector outputEdges; std::vector opDescs = { - &inputDequantizeOpDesc, - &weightDequantizeOpDesc, - &gemmDesc, + &matMulIntToFloatDesc, &mhaDesc, }; uint32_t currentNodeIndex = 0; - const uint32_t inputDequantizeNodeIndex = currentNodeIndex++; - const uint32_t weightDequantizeNodeIndex = currentNodeIndex++; - const uint32_t gemmNodeIndex = currentNodeIndex++; + const uint32_t matMulIntToFloatNodeIndex = currentNodeIndex++; const uint32_t mhaNodeIndex = currentNodeIndex++; uint32_t valueSliceNodeIndex = 0; @@ -433,63 +383,49 @@ public: maskSliceNodeIndex = currentNodeIndex++; } - DML_INPUT_GRAPH_EDGE_DESC inputToDequantizeEdge = {}; - inputToDequantizeEdge.GraphInputIndex = 
InputIndex::inputIndex; - inputToDequantizeEdge.ToNodeIndex = inputDequantizeNodeIndex; - inputToDequantizeEdge.ToNodeInputIndex = 0; - inputEdges.push_back(inputToDequantizeEdge); + DML_INPUT_GRAPH_EDGE_DESC inputToMatMulIntToFloatEdge = {}; + inputToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::inputIndex; + inputToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex; + inputToMatMulIntToFloatEdge.ToNodeInputIndex = 0; + inputEdges.push_back(inputToMatMulIntToFloatEdge); - DML_INPUT_GRAPH_EDGE_DESC inputScaleToDequantizeEdge = {}; - inputScaleToDequantizeEdge.GraphInputIndex = InputIndex::inputScaleIndex; - inputScaleToDequantizeEdge.ToNodeIndex = inputDequantizeNodeIndex; - inputScaleToDequantizeEdge.ToNodeInputIndex = 1; - inputEdges.push_back(inputScaleToDequantizeEdge); + DML_INPUT_GRAPH_EDGE_DESC inputScaleToMatMulIntToFloatEdge = {}; + inputScaleToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::inputScaleIndex; + inputScaleToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex; + inputScaleToMatMulIntToFloatEdge.ToNodeInputIndex = 1; + inputEdges.push_back(inputScaleToMatMulIntToFloatEdge); - DML_INPUT_GRAPH_EDGE_DESC inputZeroPointToDequantizeEdge = {}; - inputZeroPointToDequantizeEdge.GraphInputIndex = InputIndex::inputZeroPointIndex; - inputZeroPointToDequantizeEdge.ToNodeIndex = inputDequantizeNodeIndex; - inputZeroPointToDequantizeEdge.ToNodeInputIndex = 2; - inputEdges.push_back(inputZeroPointToDequantizeEdge); + DML_INPUT_GRAPH_EDGE_DESC inputZeroPointToMatMulIntToFloatEdge = {}; + inputZeroPointToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::inputZeroPointIndex; + inputZeroPointToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex; + inputZeroPointToMatMulIntToFloatEdge.ToNodeInputIndex = 2; + inputEdges.push_back(inputZeroPointToMatMulIntToFloatEdge); - DML_INPUT_GRAPH_EDGE_DESC weightToDequantizeEdge = {}; - weightToDequantizeEdge.GraphInputIndex = InputIndex::weightsIndex; - weightToDequantizeEdge.ToNodeIndex 
= weightDequantizeNodeIndex; - weightToDequantizeEdge.ToNodeInputIndex = 0; - inputEdges.push_back(weightToDequantizeEdge); + DML_INPUT_GRAPH_EDGE_DESC weightToMatMulIntToFloatEdge = {}; + weightToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::weightsIndex; + weightToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex; + weightToMatMulIntToFloatEdge.ToNodeInputIndex = 3; + inputEdges.push_back(weightToMatMulIntToFloatEdge); - DML_INPUT_GRAPH_EDGE_DESC weightScaleToDequantizeEdge = {}; - weightScaleToDequantizeEdge.GraphInputIndex = InputIndex::weightScaleIndex; - weightScaleToDequantizeEdge.ToNodeIndex = weightDequantizeNodeIndex; - weightScaleToDequantizeEdge.ToNodeInputIndex = 1; - inputEdges.push_back(weightScaleToDequantizeEdge); + DML_INPUT_GRAPH_EDGE_DESC weightScaleToMatMulIntToFloatEdge = {}; + weightScaleToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::weightScaleIndex; + weightScaleToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex; + weightScaleToMatMulIntToFloatEdge.ToNodeInputIndex = 4; + inputEdges.push_back(weightScaleToMatMulIntToFloatEdge); - DML_INPUT_GRAPH_EDGE_DESC weightZeroPointToDequantizeEdge = {}; - weightZeroPointToDequantizeEdge.GraphInputIndex = InputIndex::weightZeroPointIndex; - weightZeroPointToDequantizeEdge.ToNodeIndex = weightDequantizeNodeIndex; - weightZeroPointToDequantizeEdge.ToNodeInputIndex = 2; - inputEdges.push_back(weightZeroPointToDequantizeEdge); - - DML_INTERMEDIATE_GRAPH_EDGE_DESC inputQuantizeToGemmEdge = {}; - inputQuantizeToGemmEdge.FromNodeIndex = inputDequantizeNodeIndex; - inputQuantizeToGemmEdge.FromNodeOutputIndex = 0; - inputQuantizeToGemmEdge.ToNodeIndex = gemmNodeIndex; - inputQuantizeToGemmEdge.ToNodeInputIndex = 0; - intermediateEdges.push_back(inputQuantizeToGemmEdge); - - DML_INTERMEDIATE_GRAPH_EDGE_DESC weightQuantizeToGemmEdge = {}; - weightQuantizeToGemmEdge.FromNodeIndex = weightDequantizeNodeIndex; - weightQuantizeToGemmEdge.FromNodeOutputIndex = 0; - 
weightQuantizeToGemmEdge.ToNodeIndex = gemmNodeIndex; - weightQuantizeToGemmEdge.ToNodeInputIndex = 1; - intermediateEdges.push_back(weightQuantizeToGemmEdge); + DML_INPUT_GRAPH_EDGE_DESC weightZeroPointToMatMulIntToFloatEdge = {}; + weightZeroPointToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::weightZeroPointIndex; + weightZeroPointToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex; + weightZeroPointToMatMulIntToFloatEdge.ToNodeInputIndex = 5; + inputEdges.push_back(weightZeroPointToMatMulIntToFloatEdge); if (hasBias) { - DML_INPUT_GRAPH_EDGE_DESC biasToGemmEdge = {}; - biasToGemmEdge.GraphInputIndex = biasIndex; - biasToGemmEdge.ToNodeIndex = gemmNodeIndex; - biasToGemmEdge.ToNodeInputIndex = 2; - inputEdges.push_back(biasToGemmEdge); + DML_INPUT_GRAPH_EDGE_DESC biasToMatMulIntToFloatEdge = {}; + biasToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::biasIndex; + biasToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex; + biasToMatMulIntToFloatEdge.ToNodeInputIndex = 6; + inputEdges.push_back(biasToMatMulIntToFloatEdge); } if (hasMask) @@ -497,7 +433,7 @@ public: if (hasUnpaddedBounds) { DML_INPUT_GRAPH_EDGE_DESC maskToMhaEdge = {}; - maskToMhaEdge.GraphInputIndex = maskIndex; + maskToMhaEdge.GraphInputIndex = InputIndex::maskIndex; maskToMhaEdge.ToNodeIndex = mhaNodeIndex; maskToMhaEdge.ToNodeInputIndex = mhaMaskIndex; inputEdges.push_back(maskToMhaEdge); @@ -505,7 +441,7 @@ public: else if (hasMaxSequenceMask) { DML_INPUT_GRAPH_EDGE_DESC maskToMaskSliceEdge = {}; - maskToMaskSliceEdge.GraphInputIndex = maskIndex; + maskToMaskSliceEdge.GraphInputIndex = InputIndex::maskIndex; maskToMaskSliceEdge.ToNodeIndex = maskSliceNodeIndex; maskToMaskSliceEdge.ToNodeInputIndex = 0; inputEdges.push_back(maskToMaskSliceEdge); @@ -520,7 +456,7 @@ public: else { DML_INPUT_GRAPH_EDGE_DESC maskToMhaEdge = {}; - maskToMhaEdge.GraphInputIndex = maskIndex; + maskToMhaEdge.GraphInputIndex = InputIndex::maskIndex; maskToMhaEdge.ToNodeIndex = mhaNodeIndex; 
maskToMhaEdge.ToNodeInputIndex = mhaMaskIndex; inputEdges.push_back(maskToMhaEdge); @@ -530,12 +466,12 @@ public: if (hasSlicedValue) { // We need to slice QK and V, and transpose QK - DML_INTERMEDIATE_GRAPH_EDGE_DESC gemmToQueryKeySliceEdge = {}; - gemmToQueryKeySliceEdge.FromNodeIndex = gemmNodeIndex; - gemmToQueryKeySliceEdge.FromNodeOutputIndex = 0; - gemmToQueryKeySliceEdge.ToNodeIndex = queryKeySliceNodeIndex; - gemmToQueryKeySliceEdge.ToNodeInputIndex = 0; - intermediateEdges.push_back(gemmToQueryKeySliceEdge); + DML_INTERMEDIATE_GRAPH_EDGE_DESC matMulIntToFloatToQueryKeySliceEdge = {}; + matMulIntToFloatToQueryKeySliceEdge.FromNodeIndex = matMulIntToFloatNodeIndex; + matMulIntToFloatToQueryKeySliceEdge.FromNodeOutputIndex = 0; + matMulIntToFloatToQueryKeySliceEdge.ToNodeIndex = queryKeySliceNodeIndex; + matMulIntToFloatToQueryKeySliceEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(matMulIntToFloatToQueryKeySliceEdge); DML_INTERMEDIATE_GRAPH_EDGE_DESC queryKeySliceToTransposeEdge = {}; queryKeySliceToTransposeEdge.FromNodeIndex = queryKeySliceNodeIndex; @@ -551,12 +487,12 @@ public: queryKeyTransposedToMhaEdge.ToNodeInputIndex = mhaStackedQueryKeyIndex; intermediateEdges.push_back(queryKeyTransposedToMhaEdge); - DML_INTERMEDIATE_GRAPH_EDGE_DESC gemmToValueSliceEdge = {}; - gemmToValueSliceEdge.FromNodeIndex = gemmNodeIndex; - gemmToValueSliceEdge.FromNodeOutputIndex = 0; - gemmToValueSliceEdge.ToNodeIndex = valueSliceNodeIndex; - gemmToValueSliceEdge.ToNodeInputIndex = 0; - intermediateEdges.push_back(gemmToValueSliceEdge); + DML_INTERMEDIATE_GRAPH_EDGE_DESC matMulIntToFloatToValueSliceEdge = {}; + matMulIntToFloatToValueSliceEdge.FromNodeIndex = matMulIntToFloatNodeIndex; + matMulIntToFloatToValueSliceEdge.FromNodeOutputIndex = 0; + matMulIntToFloatToValueSliceEdge.ToNodeIndex = valueSliceNodeIndex; + matMulIntToFloatToValueSliceEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(matMulIntToFloatToValueSliceEdge); 
DML_INTERMEDIATE_GRAPH_EDGE_DESC valueSliceToMhaEdge = {}; valueSliceToMhaEdge.FromNodeIndex = valueSliceNodeIndex; @@ -567,12 +503,12 @@ public: } else { - DML_INTERMEDIATE_GRAPH_EDGE_DESC gemmToQueryKeyValueTransposeEdge = {}; - gemmToQueryKeyValueTransposeEdge.FromNodeIndex = gemmNodeIndex; - gemmToQueryKeyValueTransposeEdge.FromNodeOutputIndex = 0; - gemmToQueryKeyValueTransposeEdge.ToNodeIndex = queryKeyValueTransposedNodeIndex; - gemmToQueryKeyValueTransposeEdge.ToNodeInputIndex = 0; - intermediateEdges.push_back(gemmToQueryKeyValueTransposeEdge); + DML_INTERMEDIATE_GRAPH_EDGE_DESC matMulIntToFloatToQueryKeyValueTransposeEdge = {}; + matMulIntToFloatToQueryKeyValueTransposeEdge.FromNodeIndex = matMulIntToFloatNodeIndex; + matMulIntToFloatToQueryKeyValueTransposeEdge.FromNodeOutputIndex = 0; + matMulIntToFloatToQueryKeyValueTransposeEdge.ToNodeIndex = queryKeyValueTransposedNodeIndex; + matMulIntToFloatToQueryKeyValueTransposeEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(matMulIntToFloatToQueryKeyValueTransposeEdge); // All we need to do here is transpose the stacked QKV tensor into something DML supports DML_INTERMEDIATE_GRAPH_EDGE_DESC queryKeyValueTransposedToMhaEdge = {};