diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h index 2e9217cf3f..3870a58450 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h @@ -848,7 +848,7 @@ constexpr DML_SCHEMA_FIELD DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_ constexpr DML_OPERATOR_SCHEMA DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA { "DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING", - static_cast(DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING), + DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING, DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, 13, DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS, @@ -1883,7 +1883,7 @@ constexpr DML_SCHEMA_FIELD DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_ constexpr DML_OPERATOR_SCHEMA DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA { "DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT", - static_cast(DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT), + DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT, DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, 8, DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_FIELDS, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp index bbebb4a333..c8ca6806e7 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp @@ -571,6 +571,12 @@ void CALLBACK QueryAttention(IMLOperatorSupportQueryContextPrivate* context, /*o return; } + // `past_present_share_buffer == 1` is not supported yet + if (attributes.GetOptionalAttribute(AttrName::PastPresentShareBuffer, 0) != 0) + { + return; + } + *isSupported = true; } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorDynamicQuantizeMatMul.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorDynamicQuantizeMatMul.cpp index fa8d0076cb..22444d1f19 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorDynamicQuantizeMatMul.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorDynamicQuantizeMatMul.cpp @@ -105,7 +105,7 @@ public: matrixMultiplyIntergerToFloatOperatorDesc.BiasTensor = hasBias? &inputDescs[OnnxInputIndex::Bias] : nullptr; matrixMultiplyIntergerToFloatOperatorDesc.OutputTensor = &outputDescs[0]; - const DML_OPERATOR_DESC opDesc2{ static_cast(DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT), &matrixMultiplyIntergerToFloatOperatorDesc}; + const DML_OPERATOR_DESC opDesc2{ DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT, &matrixMultiplyIntergerToFloatOperatorDesc}; MLOperatorGraphDesc operatorGraphDesc = {}; std::vector opDescs{&opDesc1, &opDesc2}; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp index 6e0785c91a..f19c0116fc 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp @@ -8,7 +8,9 @@ Abbreviations: B is batch_size, S is sequence_length, W is hidden_size N is number of attention heads, H is head size, and W=N*H M is mask_index tensor - M A B C // M, A, B, and C are Inputs +M, A, B, C and P are Inputs + + M A B C | | | / | MatMulIntToFloat | / | \ @@ -20,10 +22,27 @@ Abbreviations: B is batch_size, S is sequence_length, W is hidden_size | Identity Identity Identity // The identities are used to transpose NCHW -> NHCW while | | | | // keeping the GEMM strides as NCHW to better target metacommands | | | | - ----------------- MHA ----- - | - | - Output // Final output + | | | | P + | | | | / \ + | | | | / \ + | | | | Slice Slice + | | | | | | + | | | | | | + | | | | | | + --------------------------MHA ----------- + / | \ + / | \ + / | \ + / | \ + / | \ + / | \ + / presentKey presentValue + / \ / + / \ / + / \ / + / Concat + / | + Output1 Output2 (present) This kernel creates a DML_GRAPH, as mentioned above. For reference, refer to this Doc: @@ -39,22 +58,6 @@ public: : DmlOperator(kernelCreationContext) { - enum DmlInputIndex : uint32_t - { - mhaQueryIndex, - mhaKeyIndex, - mhaValueIndex, - mhaStackedQueryKeyIndex, - mhaStackedKeyValueIndex, - mhaStackedQueryKeyValueIndex, - mhaBiasIndex, - mhaMaskIndex, - mhaRelativePositionBiasIndex, - mhaPastKeyIndex, - mhaPastValueIndex, - mhaInputCount, - }; - enum InputIndex : uint32_t { inputIndex, @@ -72,18 +75,45 @@ public: enum OutputIndex : uint32_t { outputIndex, + presentIndex, outputCount, }; - ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() >= 2); + enum MhaInputIndex : uint32_t + { + mhaQueryIndex, + mhaKeyIndex, + mhaValueIndex, + mhaStackedQueryKeyIndex, + mhaStackedKeyValueIndex, + mhaStackedQueryKeyValueIndex, + mhaBiasIndex, + mhaMaskIndex, + mhaRelativePositionBiasIndex, + mhaPastKeyIndex, + mhaPastValueIndex, + mhaInputCount, + }; + + enum MhaOutputIndex : uint32_t + { + mhaOutputIndex, + mhaPresentKeyIndex, + mhaPresentValueIndex, + mhaOutputCount, + }; + + ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() >= 5); ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() >= 1); const bool hasBias = kernelCreationContext.IsInputValid(biasIndex); const bool hasMask = kernelCreationContext.IsInputValid(maskIndex); const bool hasUnpaddedBounds = hasMask && kernelCreationContext.GetInputTensorDimensionCount(maskIndex) == 1; + const bool hasPast = kernelCreationContext.IsInputValid(pastIndex); DmlOperator::Initialize(kernelCreationContext, std::nullopt, std::nullopt, std::nullopt, std::nullopt, 1); + const bool unidirectional = gsl::narrow_cast(kernelCreationContext.GetAttribute(AttrName::Unidirectional)); const uint32_t numHeads = gsl::narrow_cast(kernelCreationContext.GetAttribute(AttrName::NumHeads)); ML_CHECK_VALID_ARGUMENT(numHeads > 0); // to avoid process crash because of division by zero. @@ -93,38 +123,28 @@ public: auto weightTensorShape = m_inputTensorDescs[weightsIndex].GetSizes(); ML_CHECK_VALID_ARGUMENT(weightTensorShape.size() == 2); ML_CHECK_VALID_ARGUMENT(weightTensorShape[0] == inputTensorShape[2]); + ML_CHECK_VALID_ARGUMENT(weightTensorShape[1] % 3 == 0); - const auto qkvHiddenSizes = kernelCreationContext.GetOptionalAttributeVectorInt32(AttrName::QkvHiddenSizes); if (hasBias) { auto biasTensorShape = m_inputTensorDescs[biasIndex].GetSizes(); ML_CHECK_VALID_ARGUMENT(biasTensorShape.size() == 1); + ML_CHECK_VALID_ARGUMENT(biasTensorShape[0] % 3 == 0); ML_CHECK_VALID_ARGUMENT(weightTensorShape[1] == biasTensorShape[0]); - - if (qkvHiddenSizes.empty()) - { - ML_CHECK_VALID_ARGUMENT(biasTensorShape[0] % 3 == 0); - } } - if (!qkvHiddenSizes.empty()) + if (hasPast) { - ML_CHECK_VALID_ARGUMENT(qkvHiddenSizes.size() == 3); - ML_CHECK_VALID_ARGUMENT(qkvHiddenSizes[0] == qkvHiddenSizes[1]); - } - else - { - ML_CHECK_VALID_ARGUMENT(weightTensorShape[1] % 3 == 0); + ML_CHECK_VALID_ARGUMENT(kernelCreationContext.IsOutputValid(presentIndex)); } - const uint32_t hiddenSize = qkvHiddenSizes.empty() ? weightTensorShape[1] / 3 : qkvHiddenSizes[0]; - const uint32_t vHiddenSize = qkvHiddenSizes.empty() ? weightTensorShape[1] / 3 : qkvHiddenSizes[2]; + const uint32_t hiddenSize = weightTensorShape[1] / 3; const uint32_t headSize = hiddenSize / numHeads; - const uint32_t vHeadSize = vHiddenSize / numHeads; const uint32_t batchSize = inputTensorShape[0]; const uint32_t sequenceLength = inputTensorShape[1]; + const uint32_t pastSequenceLength = hasPast ? m_inputTensorDescs[pastIndex].GetSizes()[3] : 0; - uint32_t desiredWeightTensorShape[3] = {batchSize, weightTensorShape[0], hiddenSize + hiddenSize + vHiddenSize}; + uint32_t desiredWeightTensorShape[3] = {batchSize, weightTensorShape[0], 3 * hiddenSize}; MLOperatorTensorDataType dataType = kernelCreationContext.GetOutputEdgeDescription(outputIndex).tensorDataType; m_inputTensorDescs[weightsIndex] = TensorDesc::ConstructBroadcastedTensorDesc( @@ -132,7 +152,7 @@ public: desiredWeightTensorShape, weightTensorShape); - uint32_t desiredBiasTensorShape[3] = {batchSize, sequenceLength, hiddenSize + hiddenSize + vHiddenSize}; + uint32_t desiredBiasTensorShape[3] = {batchSize, sequenceLength, 3 * hiddenSize}; if (hasBias) { @@ -189,6 +209,14 @@ public: } } + MLOperatorTensorDataType pastTensorDataType = MLOperatorTensorDataType::Undefined; + MLOperatorTensorDataType presentTensorDataType = MLOperatorTensorDataType::Undefined; + if (hasPast) + { + pastTensorDataType = kernelCreationContext.GetInputEdgeDescription(pastIndex).tensorDataType; + presentTensorDataType = kernelCreationContext.GetOutputEdgeDescription(presentIndex).tensorDataType; + } + TensorDesc matMulIntToFloatOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, desiredBiasTensorShape); DML_TENSOR_DESC namedMatMulIntToFloatOutputTensorDesc = matMulIntToFloatOutputTensorDesc.GetDmlDesc(); @@ -205,13 +233,13 @@ public: matMulIntToFloatOperatorDesc.BiasTensor = hasBias ? &inputDescs[InputIndex::biasIndex] : nullptr; matMulIntToFloatOperatorDesc.OutputTensor = &namedMatMulIntToFloatOutputTensorDesc; - const DML_OPERATOR_DESC matMulIntToFloatDesc = { static_cast(DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT), &matMulIntToFloatOperatorDesc}; + const DML_OPERATOR_DESC matMulIntToFloatDesc = { DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT, &matMulIntToFloatOperatorDesc}; std::array queryKeySlicedTensorShape = {batchSize, sequenceLength, hiddenSize + hiddenSize}; TensorDesc queryKeySlicedInputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, queryKeySlicedTensorShape); DML_TENSOR_DESC namedQueryKeySlicedInputTensorDesc = queryKeySlicedInputTensorDesc.GetDmlDesc(); - std::array valueSlicedTensorShape = {batchSize, sequenceLength, vHiddenSize}; + std::array valueSlicedTensorShape = {batchSize, sequenceLength, hiddenSize}; TensorDesc valueSlicedInputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, valueSlicedTensorShape); DML_TENSOR_DESC namedValueSlicedInputTensorDesc = valueSlicedInputTensorDesc.GetDmlDesc(); @@ -262,44 +290,18 @@ public: std::array queryKeySliceStrides = {1, 1, 1}; std::array valueSliceOffset = {0, 0, 2 * hiddenSize}; - std::array valueSliceSize = {batchSize, sequenceLength, vHiddenSize}; + std::array valueSliceSize = {batchSize, sequenceLength, hiddenSize}; std::array valueSliceStrides = {1, 1, 1}; - const bool hasSlicedValue = hiddenSize != vHiddenSize; - // We need to slice the value tensor when its hidden size is different from the query and key - DML_SLICE1_OPERATOR_DESC queryKeySlicedOperatorDesc = {}; - DML_SLICE1_OPERATOR_DESC valueSlicedOperatorDesc = {}; + // When Q/K/V all have the same hidden size, we just have to transpose it before sending it to MHA DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC transposeOperatorDesc = {}; - if (hasSlicedValue) - { - queryKeySlicedOperatorDesc.InputTensor = &namedMatMulIntToFloatOutputTensorDesc; - queryKeySlicedOperatorDesc.OutputTensor = &namedQueryKeySlicedInputTensorDesc; - queryKeySlicedOperatorDesc.DimensionCount = gsl::narrow_cast(queryKeySlicedTensorShape.size()); - queryKeySlicedOperatorDesc.InputWindowOffsets = queryKeySliceOffset.data(); - queryKeySlicedOperatorDesc.InputWindowSizes = queryKeySliceSize.data(); - queryKeySlicedOperatorDesc.InputWindowStrides = queryKeySliceStrides.data(); - valueSlicedOperatorDesc.InputTensor = &namedMatMulIntToFloatOutputTensorDesc; - valueSlicedOperatorDesc.OutputTensor = &namedValueSlicedInputTensorDesc; - valueSlicedOperatorDesc.DimensionCount = gsl::narrow_cast(valueSlicedTensorShape.size()); - valueSlicedOperatorDesc.InputWindowOffsets = valueSliceOffset.data(); - valueSlicedOperatorDesc.InputWindowSizes = valueSliceSize.data(); - valueSlicedOperatorDesc.InputWindowStrides = valueSliceStrides.data(); + transposeOperatorDesc.InputTensor = &namedQueryKeyValueTransposedInputTensorDesc; + transposeOperatorDesc.OutputTensor = &namedQueryKeyValueTransposedOutputTensorDesc; - transposeOperatorDesc.InputTensor = &namedQueryKeyTransposedInputTensorDesc; - transposeOperatorDesc.OutputTensor = &namedQueryKeyTransposedOutputTensorDesc; - } - else - { - // When Q/K/V all have the same hidden size, we just have to transpose it before sending it to MHA - transposeOperatorDesc.InputTensor = &namedQueryKeyValueTransposedInputTensorDesc; - transposeOperatorDesc.OutputTensor = &namedQueryKeyValueTransposedOutputTensorDesc; - } - const DML_OPERATOR_DESC queryKeySlicedDesc = { DML_OPERATOR_SLICE1, &queryKeySlicedOperatorDesc}; - const DML_OPERATOR_DESC valueSlicedDesc = { DML_OPERATOR_SLICE1, &valueSlicedOperatorDesc}; const DML_OPERATOR_DESC transposedDesc = { DML_OPERATOR_ELEMENT_WISE_IDENTITY, &transposeOperatorDesc}; - std::array maskSliceOutputShape {batchSize, numHeads, sequenceLength, sequenceLength}; + std::array maskSliceOutputShape = {batchSize, numHeads, sequenceLength, sequenceLength}; std::array maskSliceStrides = {1, 1, 1, 1}; std::array maskSliceOffsets = {0, 0, 0, 0}; TensorDesc maskSliceOutputTensorDesc; @@ -319,12 +321,81 @@ public: } const DML_OPERATOR_DESC maskSlicedDesc = { DML_OPERATOR_SLICE1, &maskSlicedOperatorDesc}; - DML_MULTIHEAD_ATTENTION_OPERATOR_DESC mhaOperatorDesc = {}; - mhaOperatorDesc.ValueTensor = hasSlicedValue ? &namedValueSlicedInputTensorDesc : nullptr; - mhaOperatorDesc.StackedQueryKeyTensor = hasSlicedValue ? &namedQueryKeyTransposedOutputTensorDesc : nullptr; - mhaOperatorDesc.StackedQueryKeyValueTensor = hasSlicedValue ? nullptr : &namedQueryKeyValueTransposedOutputTensorDesc; + // We need to slice Past to get PastValue and PastKey tensors for MHA + std::array pastKeyOutputShape = {1, batchSize, numHeads, pastSequenceLength, headSize}; + std::array pastKeyStrides = {1, 1, 1, 1, 1}; + std::array pastKeyOffsets = {0, 0, 0, 0, 0}; + TensorDesc pastKeyOutputTensorDesc; + DML_TENSOR_DESC namedPastKeyOutputTensorDesc; + + std::array pastValueOutputShape = {1, batchSize, numHeads, pastSequenceLength, headSize}; + std::array pastValueStrides = {1, 1, 1, 1, 1}; + std::array pastValueOffsets = {1, 0, 0, 0, 0}; + TensorDesc pastValueOutputTensorDesc; + DML_TENSOR_DESC namedPastValueOutputTensorDesc; - if (hasMaxSequenceMask) + DML_SLICE1_OPERATOR_DESC pastKeySlicedOperatorDesc = {}; + DML_SLICE1_OPERATOR_DESC pastValueSlicedOperatorDesc = {}; + + if (hasPast) + { + pastKeyOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(pastTensorDataType, pastKeyOutputShape); + namedPastKeyOutputTensorDesc = pastKeyOutputTensorDesc.GetDmlDesc(); + pastKeySlicedOperatorDesc.InputTensor = &inputDescs[pastIndex]; + pastKeySlicedOperatorDesc.OutputTensor = &namedPastKeyOutputTensorDesc; + pastKeySlicedOperatorDesc.DimensionCount = gsl::narrow_cast(pastKeyOutputShape.size()); + pastKeySlicedOperatorDesc.InputWindowOffsets = pastKeyOffsets.data(); + pastKeySlicedOperatorDesc.InputWindowSizes = pastKeyOutputShape.data(); + pastKeySlicedOperatorDesc.InputWindowStrides = pastKeyStrides.data(); + + pastValueOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(pastTensorDataType, pastValueOutputShape); + namedPastValueOutputTensorDesc = pastValueOutputTensorDesc.GetDmlDesc(); + pastValueSlicedOperatorDesc.InputTensor = &inputDescs[pastIndex]; + pastValueSlicedOperatorDesc.OutputTensor = &namedPastValueOutputTensorDesc; + pastValueSlicedOperatorDesc.DimensionCount = gsl::narrow_cast(pastValueOutputShape.size()); + pastValueSlicedOperatorDesc.InputWindowOffsets = pastValueOffsets.data(); + pastValueSlicedOperatorDesc.InputWindowSizes = pastValueOutputShape.data(); + pastValueSlicedOperatorDesc.InputWindowStrides = pastValueStrides.data(); + } + + const DML_OPERATOR_DESC pastKeySlicedDesc = { DML_OPERATOR_SLICE1, &pastKeySlicedOperatorDesc}; + const DML_OPERATOR_DESC pastValueSlicedDesc = { DML_OPERATOR_SLICE1, &pastValueSlicedOperatorDesc}; + + // Causal Mask: [pastSequenceLength, pastSequenceLength + 1 ... pastSequenceLength + batchSize -1] + // passed to MHA as maskIndex Tensor when unidirectional == 1 + std::array causalMaskOutputShape = {1, batchSize}; + TensorDesc causalMaskTensorDesc; + DML_FILL_VALUE_SEQUENCE_OPERATOR_DESC causalMaskOperatorDesc = {}; + DML_TENSOR_DESC namedcausalMaskTensorDesc; + + if (unidirectional && !hasMask) + { + causalMaskTensorDesc = TensorDesc::ConstructDefaultTensorDesc(MLOperatorTensorDataType::Int32, causalMaskOutputShape); + namedcausalMaskTensorDesc = causalMaskTensorDesc.GetDmlDesc(); + causalMaskOperatorDesc.ValueDataType = DML_TENSOR_DATA_TYPE_INT32; + causalMaskOperatorDesc.ValueStart.Int32 = pastSequenceLength; + causalMaskOperatorDesc.ValueDelta.Int32 = 1; + causalMaskOperatorDesc.OutputTensor = &namedcausalMaskTensorDesc; + + maskType = DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH; + } + DML_OPERATOR_DESC causalMaskDesc = { DML_OPERATOR_FILL_VALUE_SEQUENCE, &causalMaskOperatorDesc }; + + DML_MULTIHEAD_ATTENTION_OPERATOR_DESC mhaOperatorDesc = {}; + std::array presentKeyOutputShape = {1, batchSize, numHeads, pastSequenceLength + sequenceLength, headSize}; + std::array presentValueOutputShape = {1, batchSize, numHeads, pastSequenceLength + sequenceLength, headSize}; + TensorDesc presentKeyTensorDesc; + TensorDesc presentValueTensorDesc; + DML_TENSOR_DESC namedPresentKeyOutputTensorDesc; + DML_TENSOR_DESC namedPresentValueOutputTensorDesc; + + mhaOperatorDesc.StackedQueryKeyValueTensor = &namedQueryKeyValueTransposedOutputTensorDesc; + + if (unidirectional && !hasMask) + { + mhaOperatorDesc.MaskTensor = &namedcausalMaskTensorDesc; + } + else if (hasMaxSequenceMask) { mhaOperatorDesc.MaskTensor = &namedMaskSliceOutputTensorDesc; } @@ -339,8 +410,35 @@ public: mhaOperatorDesc.MaskFilterValue = kernelCreationContext.GetOptionalAttribute(AttrName::MaskFilterValue, -10'000.0f); mhaOperatorDesc.HeadCount = numHeads; mhaOperatorDesc.MaskType = maskType; + if (hasPast) + { + presentKeyTensorDesc = TensorDesc::ConstructDefaultTensorDesc(presentTensorDataType, presentKeyOutputShape); + namedPresentKeyOutputTensorDesc = presentKeyTensorDesc.GetDmlDesc(); + presentValueTensorDesc = TensorDesc::ConstructDefaultTensorDesc(presentTensorDataType, presentValueOutputShape); + namedPresentValueOutputTensorDesc = presentValueTensorDesc.GetDmlDesc(); + mhaOperatorDesc.PastKeyTensor = &namedPastKeyOutputTensorDesc; + mhaOperatorDesc.PastValueTensor = &namedPastValueOutputTensorDesc; + mhaOperatorDesc.OutputPresentKeyTensor = &namedPresentKeyOutputTensorDesc; + mhaOperatorDesc.OutputPresentValueTensor = &namedPresentValueOutputTensorDesc; + } + const DML_OPERATOR_DESC mhaDesc = { DML_OPERATOR_MULTIHEAD_ATTENTION, &mhaOperatorDesc }; + DML_JOIN_OPERATOR_DESC presentKeyValueJoinOperatorDesc = {}; + std::vector joinInputDesc; + + if (hasPast) + { + joinInputDesc.push_back(namedPresentKeyOutputTensorDesc); + joinInputDesc.push_back(namedPresentValueOutputTensorDesc); + presentKeyValueJoinOperatorDesc.InputCount = gsl::narrow_cast(joinInputDesc.size()); + presentKeyValueJoinOperatorDesc.InputTensors = joinInputDesc.data(); + presentKeyValueJoinOperatorDesc.OutputTensor = &outputDescs[presentIndex]; + presentKeyValueJoinOperatorDesc.Axis = gsl::narrow_cast(0); + } + + DML_OPERATOR_DESC presentKeyValueJoinDesc = { DML_OPERATOR_JOIN, &presentKeyValueJoinOperatorDesc }; + // Construct the graph std::vector inputEdges; std::vector intermediateEdges; @@ -355,26 +453,10 @@ public: const uint32_t matMulIntToFloatNodeIndex = currentNodeIndex++; const uint32_t mhaNodeIndex = currentNodeIndex++; - uint32_t valueSliceNodeIndex = 0; - uint32_t queryKeySliceNodeIndex = 0; - uint32_t queryKeyTransposedNodeIndex = 0; uint32_t queryKeyValueTransposedNodeIndex = 0; - if (hasSlicedValue) - { - opDescs.push_back(&queryKeySlicedDesc); - queryKeySliceNodeIndex = currentNodeIndex++; - opDescs.push_back(&valueSlicedDesc); - valueSliceNodeIndex = currentNodeIndex++; - - opDescs.push_back(&transposedDesc); - queryKeyTransposedNodeIndex = currentNodeIndex++; - } - else - { - opDescs.push_back(&transposedDesc); - queryKeyValueTransposedNodeIndex = currentNodeIndex++; - } + opDescs.push_back(&transposedDesc); + queryKeyValueTransposedNodeIndex = currentNodeIndex++; uint32_t maskSliceNodeIndex = 0; if (hasMaxSequenceMask) @@ -383,6 +465,26 @@ public: maskSliceNodeIndex = currentNodeIndex++; } + uint32_t pastKeySliceNodeIndex = 0; + uint32_t pastValueSliceNodeIndex = 0; + uint32_t concatNodeIndex = 0; + if (hasPast) + { + opDescs.push_back(&pastKeySlicedDesc); + pastKeySliceNodeIndex = currentNodeIndex++; + opDescs.push_back(&pastValueSlicedDesc); + pastValueSliceNodeIndex = currentNodeIndex++; + opDescs.push_back(&presentKeyValueJoinDesc); + concatNodeIndex = currentNodeIndex++; + } + + uint32_t causalMaskNodeIndex = 0; + if (unidirectional && !hasMask) + { + opDescs.push_back(&causalMaskDesc); + causalMaskNodeIndex = currentNodeIndex++; + } + DML_INPUT_GRAPH_EDGE_DESC inputToMatMulIntToFloatEdge = {}; inputToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::inputIndex; inputToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex; @@ -462,69 +564,89 @@ public: inputEdges.push_back(maskToMhaEdge); } } - - if (hasSlicedValue) + else if (unidirectional) { - // We need to slice QK and V, and transpose QK - DML_INTERMEDIATE_GRAPH_EDGE_DESC matMulIntToFloatToQueryKeySliceEdge = {}; - matMulIntToFloatToQueryKeySliceEdge.FromNodeIndex = matMulIntToFloatNodeIndex; - matMulIntToFloatToQueryKeySliceEdge.FromNodeOutputIndex = 0; - matMulIntToFloatToQueryKeySliceEdge.ToNodeIndex = queryKeySliceNodeIndex; - matMulIntToFloatToQueryKeySliceEdge.ToNodeInputIndex = 0; - intermediateEdges.push_back(matMulIntToFloatToQueryKeySliceEdge); - - DML_INTERMEDIATE_GRAPH_EDGE_DESC queryKeySliceToTransposeEdge = {}; - queryKeySliceToTransposeEdge.FromNodeIndex = queryKeySliceNodeIndex; - queryKeySliceToTransposeEdge.FromNodeOutputIndex = 0; - queryKeySliceToTransposeEdge.ToNodeIndex = queryKeyTransposedNodeIndex; - queryKeySliceToTransposeEdge.ToNodeInputIndex = 0; - intermediateEdges.push_back(queryKeySliceToTransposeEdge); - - DML_INTERMEDIATE_GRAPH_EDGE_DESC queryKeyTransposedToMhaEdge = {}; - queryKeyTransposedToMhaEdge.FromNodeIndex = queryKeyTransposedNodeIndex; - queryKeyTransposedToMhaEdge.FromNodeOutputIndex = 0; - queryKeyTransposedToMhaEdge.ToNodeIndex = mhaNodeIndex; - queryKeyTransposedToMhaEdge.ToNodeInputIndex = mhaStackedQueryKeyIndex; - intermediateEdges.push_back(queryKeyTransposedToMhaEdge); - - DML_INTERMEDIATE_GRAPH_EDGE_DESC matMulIntToFloatToValueSliceEdge = {}; - matMulIntToFloatToValueSliceEdge.FromNodeIndex = matMulIntToFloatNodeIndex; - matMulIntToFloatToValueSliceEdge.FromNodeOutputIndex = 0; - matMulIntToFloatToValueSliceEdge.ToNodeIndex = valueSliceNodeIndex; - matMulIntToFloatToValueSliceEdge.ToNodeInputIndex = 0; - intermediateEdges.push_back(matMulIntToFloatToValueSliceEdge); - - DML_INTERMEDIATE_GRAPH_EDGE_DESC valueSliceToMhaEdge = {}; - valueSliceToMhaEdge.FromNodeIndex = valueSliceNodeIndex; - valueSliceToMhaEdge.FromNodeOutputIndex = 0; - valueSliceToMhaEdge.ToNodeIndex = mhaNodeIndex; - valueSliceToMhaEdge.ToNodeInputIndex = mhaValueIndex; - intermediateEdges.push_back(valueSliceToMhaEdge); + DML_INTERMEDIATE_GRAPH_EDGE_DESC causalMaskToMhaEdge = {}; + causalMaskToMhaEdge.FromNodeIndex = causalMaskNodeIndex; + causalMaskToMhaEdge.FromNodeOutputIndex = 0; + causalMaskToMhaEdge.ToNodeIndex = mhaNodeIndex; + causalMaskToMhaEdge.ToNodeInputIndex = mhaMaskIndex ; + intermediateEdges.push_back(causalMaskToMhaEdge); } - else + + if (hasPast) { - DML_INTERMEDIATE_GRAPH_EDGE_DESC matMulIntToFloatToQueryKeyValueTransposeEdge = {}; - matMulIntToFloatToQueryKeyValueTransposeEdge.FromNodeIndex = matMulIntToFloatNodeIndex; - matMulIntToFloatToQueryKeyValueTransposeEdge.FromNodeOutputIndex = 0; - matMulIntToFloatToQueryKeyValueTransposeEdge.ToNodeIndex = queryKeyValueTransposedNodeIndex; - matMulIntToFloatToQueryKeyValueTransposeEdge.ToNodeInputIndex = 0; - intermediateEdges.push_back(matMulIntToFloatToQueryKeyValueTransposeEdge); + DML_INPUT_GRAPH_EDGE_DESC pastToPastKeySliceEdge = {}; + pastToPastKeySliceEdge.GraphInputIndex = InputIndex::pastIndex; + pastToPastKeySliceEdge.ToNodeIndex = pastKeySliceNodeIndex; + pastToPastKeySliceEdge.ToNodeInputIndex = 0; + inputEdges.push_back(pastToPastKeySliceEdge); - // All we need to do here is transpose the stacked QKV tensor into something DML supports - DML_INTERMEDIATE_GRAPH_EDGE_DESC queryKeyValueTransposedToMhaEdge = {}; - queryKeyValueTransposedToMhaEdge.FromNodeIndex = queryKeyValueTransposedNodeIndex; - queryKeyValueTransposedToMhaEdge.FromNodeOutputIndex = 0; - queryKeyValueTransposedToMhaEdge.ToNodeIndex = mhaNodeIndex; - queryKeyValueTransposedToMhaEdge.ToNodeInputIndex = mhaStackedQueryKeyValueIndex; - intermediateEdges.push_back(queryKeyValueTransposedToMhaEdge); + DML_INPUT_GRAPH_EDGE_DESC pastToPastValueSliceEdge = {}; + pastToPastValueSliceEdge.GraphInputIndex = InputIndex::pastIndex; + pastToPastValueSliceEdge.ToNodeIndex = pastValueSliceNodeIndex; + pastToPastValueSliceEdge.ToNodeInputIndex = 0; + inputEdges.push_back(pastToPastValueSliceEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC pastKeyToMhaEdge = {}; + pastKeyToMhaEdge.FromNodeIndex = pastKeySliceNodeIndex; + pastKeyToMhaEdge.FromNodeOutputIndex = 0; + pastKeyToMhaEdge.ToNodeIndex = mhaNodeIndex; + pastKeyToMhaEdge.ToNodeInputIndex = mhaPastKeyIndex; + intermediateEdges.push_back(pastKeyToMhaEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC pastValueToMhaEdge = {}; + pastValueToMhaEdge.FromNodeIndex = pastValueSliceNodeIndex; + pastValueToMhaEdge.FromNodeOutputIndex = 0; + pastValueToMhaEdge.ToNodeIndex = mhaNodeIndex; + pastValueToMhaEdge.ToNodeInputIndex = mhaPastValueIndex; + intermediateEdges.push_back(pastValueToMhaEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC presentKeyToConcatEdge = {}; + presentKeyToConcatEdge.FromNodeIndex = mhaNodeIndex; + presentKeyToConcatEdge.FromNodeOutputIndex = mhaPresentKeyIndex; + presentKeyToConcatEdge.ToNodeIndex = concatNodeIndex; + presentKeyToConcatEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(presentKeyToConcatEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC presentValueToConcatEdge = {}; + presentValueToConcatEdge.FromNodeIndex = mhaNodeIndex; + presentValueToConcatEdge.FromNodeOutputIndex = mhaPresentValueIndex; + presentValueToConcatEdge.ToNodeIndex = concatNodeIndex; + presentValueToConcatEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(presentValueToConcatEdge); } + DML_INTERMEDIATE_GRAPH_EDGE_DESC matMulIntToFloatToQueryKeyValueTransposeEdge = {}; + matMulIntToFloatToQueryKeyValueTransposeEdge.FromNodeIndex = matMulIntToFloatNodeIndex; + matMulIntToFloatToQueryKeyValueTransposeEdge.FromNodeOutputIndex = 0; + matMulIntToFloatToQueryKeyValueTransposeEdge.ToNodeIndex = queryKeyValueTransposedNodeIndex; + matMulIntToFloatToQueryKeyValueTransposeEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(matMulIntToFloatToQueryKeyValueTransposeEdge); + + // All we need to do here is transpose the stacked QKV tensor into something DML supports + DML_INTERMEDIATE_GRAPH_EDGE_DESC queryKeyValueTransposedToMhaEdge = {}; + queryKeyValueTransposedToMhaEdge.FromNodeIndex = queryKeyValueTransposedNodeIndex; + queryKeyValueTransposedToMhaEdge.FromNodeOutputIndex = 0; + queryKeyValueTransposedToMhaEdge.ToNodeIndex = mhaNodeIndex; + queryKeyValueTransposedToMhaEdge.ToNodeInputIndex = mhaStackedQueryKeyValueIndex; + intermediateEdges.push_back(queryKeyValueTransposedToMhaEdge); + DML_OUTPUT_GRAPH_EDGE_DESC mhaToOutputEdge = {}; mhaToOutputEdge.FromNodeIndex = mhaNodeIndex; - mhaToOutputEdge.FromNodeOutputIndex = 0; - mhaToOutputEdge.GraphOutputIndex = 0; + mhaToOutputEdge.FromNodeOutputIndex = mhaOutputIndex; + mhaToOutputEdge.GraphOutputIndex = OutputIndex::outputIndex; outputEdges.push_back(mhaToOutputEdge); + if (hasPast) + { + DML_OUTPUT_GRAPH_EDGE_DESC concatToOutputEdge = {}; + concatToOutputEdge.FromNodeIndex = concatNodeIndex; + concatToOutputEdge.FromNodeOutputIndex = 0; + concatToOutputEdge.GraphOutputIndex = OutputIndex::presentIndex; + outputEdges.push_back(concatToOutputEdge); + } + MLOperatorGraphDesc operatorGraphDesc = {}; operatorGraphDesc.inputEdgeCount = gsl::narrow_cast(inputEdges.size()); operatorGraphDesc.inputEdges = inputEdges.data(); @@ -542,21 +664,10 @@ public: void CALLBACK QueryQAttention(IMLOperatorSupportQueryContextPrivate* context, /*out*/ bool* isSupported) { *isSupported = false; - // `past` input tensor is not supported yet - if (context->IsInputValid(8)) - { - return; - } - // `present` output tensor is not supported yet - if (context->IsOutputValid(1)) - { - return; - } - - // `unidirectional == 1` is not supported yet + // `unidirectional == 1` with Mask Tensor is not supported yet MLOperatorAttributes attributes(context); - if (attributes.GetOptionalAttribute(AttrName::Unidirectional, 0) != 0) + if (attributes.GetOptionalAttribute(AttrName::Unidirectional, 0) != 0 && context->IsInputValid(5)) { return; } @@ -567,6 +678,12 @@ void CALLBACK QueryQAttention(IMLOperatorSupportQueryContextPrivate* context, /* return; } + // `past_present_share_buffer == 1` is not supported yet + if (attributes.GetOptionalAttribute(AttrName::PastPresentShareBuffer, 0) != 0) + { + return; + } + *isSupported = true; } diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h index 543e30fcd9..9d00ea0079 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h @@ -107,6 +107,7 @@ namespace AttrName static constexpr const char* QkvHiddenSizes = "qkv_hidden_sizes"; static constexpr const char* Unidirectional = "unidirectional"; static constexpr const char* NumHeads = "num_heads"; + static constexpr const char* PastPresentShareBuffer = "past_present_share_buffer"; static constexpr const char* FusedActivation = "fused_activation"; static constexpr const char* FusedActivationDomain = "fused_activation_domain"; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp index 1fcd3b0430..61e13f7bca 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp @@ -2744,6 +2744,48 @@ namespace OperatorHelper m_qkvHiddenSizes = kernelInformation.GetAttributes().GetOptionalAttributeVectorInt32(AttrName::QkvHiddenSizes); } + std::vector QAttentionHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const + { + ML_CHECK_VALID_ARGUMENT(shapeInfo.GetInputCount() >= 5); + + auto queryShape = shapeInfo.GetInputTensorShape(0); + ML_CHECK_VALID_ARGUMENT(queryShape.size() == 3); + + auto weightShape = shapeInfo.GetInputTensorShape(1); + ML_CHECK_VALID_ARGUMENT(weightShape.size() == 2); + ML_CHECK_VALID_ARGUMENT(weightShape[1] % 3 == 0); + + const uint32_t batchSize = queryShape[0]; + const uint32_t sequenceLength = queryShape[1]; + const uint32_t hiddenSize = weightShape[1] / 3; + const uint32_t headSize = hiddenSize / m_numHeads; + + std::vector outputShapes(2); + + outputShapes[0] = EdgeShapes({batchSize, sequenceLength, hiddenSize}); + + uint32_t totalSequenceLength = sequenceLength; + if (shapeInfo.IsInputValid(8)) + { + ML_CHECK_VALID_ARGUMENT(shapeInfo.GetInputTensorDimensionCount(8) == 5); + const uint32_t pastSequenceLength = shapeInfo.GetInputTensorShape(8)[3]; + totalSequenceLength += pastSequenceLength; + } + + if (shapeInfo.IsOutputValid(1)) + { + ML_CHECK_VALID_ARGUMENT(shapeInfo.IsInputValid(8)); + outputShapes[1] = EdgeShapes({2, batchSize, m_numHeads, totalSequenceLength, headSize}); + } + + return outputShapes; + } + + void QAttentionHelper::Initialize(const IKernelInformationAdapter& kernelInformation) + { + m_numHeads = gsl::narrow_cast(kernelInformation.GetAttributes().GetAttribute(AttrName::NumHeads)); + } + std::vector SkipLayerNormHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const { ML_CHECK_VALID_ARGUMENT(shapeInfo.GetInputCount() >= 3); diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index 8d7f0b5b04..fe6e53ef42 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -1490,6 +1490,22 @@ private: std::vector m_qkvHiddenSizes; }; +class QAttentionHelper +{ +public: + template + QAttentionHelper(const Info_t& info, const Shape_t& shapeInfo) + { + Initialize(KernelInformationAdapter(info)); + } + + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + +private: + void Initialize(const IKernelInformationAdapter& kernelInformation); + uint32_t m_numHeads; +}; + class SkipLayerNormHelper { public: @@ -1630,7 +1646,7 @@ using ShapeInferenceHelper_Affine = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_QuantizeLinear = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_DequantizeLinear = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_QLinearSigmoid = GetOutputShapeAsInputShapeHelper; -using ShapeInferenceHelper_QAttention = AttentionHelper; +using ShapeInferenceHelper_QAttention = QAttentionHelper; using ShapeInferenceHelper_Attention = AttentionHelper; using ShapeInferenceHelper_MultiHeadAttention = MultiHeadAttentionHelper; using ShapeInferenceHelper_Sign = GetBroadcastedOutputShapeHelper;