From 89aa4697b145b3879c8a0f40db353af1e5e5003a Mon Sep 17 00:00:00 2001 From: raoanag <127366241+raoanag@users.noreply.github.com> Date: Mon, 11 Mar 2024 09:44:34 -0800 Subject: [PATCH] [DML] QAttention (#19766) ### Description DML Implementation for [com.microsoft.QAttention](https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QAttention) ### Motivation and Context --------- Co-authored-by: Xiang Zhang --- docs/OperatorKernels.md | 1 + .../src/Operators/DmlOperatorQAttention.cpp | 704 ++++++++++++++++++ .../Operators/DmlOperatorQLinearSigmoid.cpp | 2 +- .../src/Operators/OperatorRegistration.cpp | 23 +- .../OperatorAuthorHelper/OperatorHelper.cpp | 42 ++ .../dml/OperatorAuthorHelper/OperatorHelper.h | 17 + .../OperatorAuthorHelper/OperatorVersions.h | 1 + .../contrib_ops/quantize_attention_op_test.cc | 60 +- 8 files changed, 839 insertions(+), 11 deletions(-) create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 955957f295..eddc3b7873 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -1277,6 +1277,7 @@ Do not modify directly.* |MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float), tensor(float16)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| |NhwcConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| +|QAttention|*in* input:**T1**
*in* weight:**T2**
*in* bias:**T3**
*in* input_scale:**T3**
*in* weight_scale:**T3**
*in* mask_index:**T4**
*in* input_zero_point:**T1**
*in* weight_zero_point:**T2**
*in* past:**T3**
*out* output:**T3**
*out* present:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float), tensor(float16)
**T4** = tensor(int32)| |QLinearAdd|*in* A:**T**
*in* A_scale:**tensor(float)**
*in* A_zero_point:**T**
*in* B:**T**
*in* B_scale:**tensor(float)**
*in* B_zero_point:**T**
*in* C_scale:**tensor(float)**
*in* C_zero_point:**T**
*out* C:**T**|1+|**T** = tensor(int8), tensor(uint8)| |QLinearAveragePool|*in* X:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)| |QLinearConcat|*in* Y_scale:**TF**
*in* Y_zero_point:**T8**
*in* inputs:**TV**
*out* Y:**T8**|1+|**T8** = tensor(int8), tensor(uint8)
**TF** = tensor(float)
**TV** = tensor(float), tensor(int8), tensor(uint8)| diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp new file mode 100644 index 0000000000..f9519b26bb --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp @@ -0,0 +1,704 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "precomp.h" + +/* +Abbreviations: B is batch_size, S is sequence_length, W is hidden_size + N is number of attention heads, H is head size, and W=N*H + +Input, Weight, Bias, Mask Index and Past are Inputs + +Mask Index/Causal Input Weight Bias + | \ | / + | \ | / + | \ | / + | MatMulIntToFloat + | / | \ + | / | \ + | / | \ + | Slice Slice Slice + | | | | + | | | | + | Identity Identity Identity // The identities are used to transpose NCHW -> NHCW while + | | | | // keeping the GEMM strides as NCHW to better target metacommands + | | | | + | | | | Past + | | | | / \ + | | | | / \ + | | | | Slice Slice + | | | | | | + | | | | | | + | | | | | | + --------------------------MHA ----------- + / | \ + / | \ + / | \ + / | \ + / | \ + / | \ + / presentKey presentValue + / \ / + / \ / + / \ / + / Concat + / | + Output1 Output2 (present) + + This kernel creates a DML_GRAPH, as mentioned above. + For reference, refer to this Doc: + https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftqattention + */ + +namespace Dml +{ +class DmlOperatorQAttention : public DmlOperator +{ +public: + DmlOperatorQAttention(const MLOperatorKernelCreationContext& kernelCreationContext) + : DmlOperator(kernelCreationContext) + { + enum InputIndex : uint32_t + { + inputIndex, + weightsIndex, + biasIndex, + inputScaleIndex, + weightScaleIndex, + maskIndex, + inputZeroPointIndex, + weightZeroPointIndex, + pastIndex, + inputCount, + }; + + enum OutputIndex : uint32_t + { + outputIndex, + presentIndex, + outputCount, + }; + + enum MhaInputIndex : uint32_t + { + mhaQueryIndex, + mhaKeyIndex, + mhaValueIndex, + mhaStackedQueryKeyIndex, + mhaStackedKeyValueIndex, + mhaStackedQueryKeyValueIndex, + mhaBiasIndex, + mhaMaskIndex, + mhaRelativePositionBiasIndex, + mhaPastKeyIndex, + mhaPastValueIndex, + mhaInputCount, + }; + + enum MhaOutputIndex : uint32_t + { + mhaOutputIndex, + mhaPresentKeyIndex, + mhaPresentValueIndex, + mhaOutputCount, + }; + + ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() >= 5); + ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() >= 1); + + const bool hasBias = kernelCreationContext.IsInputValid(biasIndex); + const bool hasMask = kernelCreationContext.IsInputValid(maskIndex); + const bool hasUnpaddedBounds = hasMask && kernelCreationContext.GetInputTensorDimensionCount(maskIndex) == 1; + const bool hasPast = kernelCreationContext.IsInputValid(pastIndex); + + DmlOperator::Initialize(kernelCreationContext, std::nullopt, std::nullopt, std::nullopt, std::nullopt, 1); + + const bool unidirectional = gsl::narrow_cast(kernelCreationContext.GetAttribute(AttrName::Unidirectional)); + const uint32_t numHeads = gsl::narrow_cast(kernelCreationContext.GetAttribute(AttrName::NumHeads)); + ML_CHECK_VALID_ARGUMENT(numHeads > 0); // to avoid process crash because of division by zero. + + auto inputTensorShape = m_inputTensorDescs[inputIndex].GetSizes(); + ML_CHECK_VALID_ARGUMENT(inputTensorShape.size() == 3); + + auto weightTensorShape = m_inputTensorDescs[weightsIndex].GetSizes(); + ML_CHECK_VALID_ARGUMENT(weightTensorShape.size() == 2); + ML_CHECK_VALID_ARGUMENT(weightTensorShape[0] == inputTensorShape[2]); + ML_CHECK_VALID_ARGUMENT(weightTensorShape[1] % 3 == 0); + + if (hasBias) + { + auto biasTensorShape = m_inputTensorDescs[biasIndex].GetSizes(); + ML_CHECK_VALID_ARGUMENT(biasTensorShape.size() == 1); + ML_CHECK_VALID_ARGUMENT(biasTensorShape[0] % 3 == 0); + ML_CHECK_VALID_ARGUMENT(weightTensorShape[1] == biasTensorShape[0]); + } + + if (hasPast) + { + ML_CHECK_VALID_ARGUMENT(kernelCreationContext.IsOutputValid(presentIndex)); + } + + const uint32_t hiddenSize = weightTensorShape[1] / 3; + const uint32_t headSize = hiddenSize / numHeads; + const uint32_t batchSize = inputTensorShape[0]; + const uint32_t sequenceLength = inputTensorShape[1]; + const uint32_t pastSequenceLength = hasPast ? m_inputTensorDescs[pastIndex].GetSizes()[3] : 0; + + uint32_t desiredWeightTensorShape[3] = {batchSize, weightTensorShape[0], 3 * hiddenSize}; + MLOperatorTensorDataType dataType = kernelCreationContext.GetOutputEdgeDescription(outputIndex).tensorDataType; + + m_inputTensorDescs[weightsIndex] = TensorDesc::ConstructBroadcastedTensorDesc( + kernelCreationContext.GetInputEdgeDescription(weightsIndex).tensorDataType, + desiredWeightTensorShape, + weightTensorShape); + + uint32_t desiredBiasTensorShape[3] = {batchSize, sequenceLength, 3 * hiddenSize}; + + if (hasBias) + { + auto biasTensorShape = m_inputTensorDescs[biasIndex].GetSizes(); + m_inputTensorDescs[biasIndex] = TensorDesc::ConstructBroadcastedTensorDesc(kernelCreationContext.GetInputEdgeDescription(biasIndex).tensorDataType, desiredBiasTensorShape, biasTensorShape); + } + + MLOperatorTensorDataType maskTensorDataType = MLOperatorTensorDataType::Undefined; + bool hasMaxSequenceMask = false; + DML_MULTIHEAD_ATTENTION_MASK_TYPE maskType = DML_MULTIHEAD_ATTENTION_MASK_TYPE_NONE; + if (hasMask) + { + if (hasUnpaddedBounds) + { + auto unpaddedKeyBoundsShape = m_inputTensorDescs[maskIndex].GetSizes(); + ML_CHECK_VALID_ARGUMENT(unpaddedKeyBoundsShape.size() == 1); + + const uint32_t batchGroupCount = unpaddedKeyBoundsShape[0] / batchSize; + ML_CHECK_VALID_ARGUMENT(batchGroupCount == 1 || batchGroupCount == 2); + + uint32_t desiredShape[2] = {batchGroupCount, batchSize}; + m_inputTensorDescs[maskIndex] = TensorDesc( + m_inputTensorDescs[maskIndex].GetDmlDataType(), + desiredShape); + + maskType = batchGroupCount == 1 + ? DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH + : DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_END_START; + } + else + { + auto maskIndexTensorShape = m_inputTensorDescs[maskIndex].GetSizes(); + ML_CHECK_VALID_ARGUMENT(maskIndexTensorShape.size() > 1 && maskIndexTensorShape.size() <= 4); + + maskType = DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN; + std::vector reshapedMaskIndexTensorShape(maskIndexTensorShape.begin(), maskIndexTensorShape.end()); + if (maskIndexTensorShape.size() == 4 && maskIndexTensorShape[2] != sequenceLength) + { + hasMaxSequenceMask = true; + ML_CHECK_VALID_ARGUMENT(maskIndexTensorShape[2] == maskIndexTensorShape[3]); + const uint32_t maxSequenceLength = maskIndexTensorShape[2]; + uint32_t desiredMaskIndexShape[4] = {batchSize, numHeads, maxSequenceLength, maxSequenceLength}; + maskTensorDataType = kernelCreationContext.GetInputEdgeDescription(maskIndex).tensorDataType; + m_inputTensorDescs[maskIndex] = TensorDesc::ConstructBroadcastedTensorDesc(maskTensorDataType, desiredMaskIndexShape, reshapedMaskIndexTensorShape); + } + else + { + uint32_t maskIndexDimensionCount = gsl::narrow_cast(maskIndexTensorShape.size()); + reshapedMaskIndexTensorShape.insert(reshapedMaskIndexTensorShape.begin() + 1, 4 - maskIndexDimensionCount, 1); + uint32_t desiredMaskIndexShape[4] = {batchSize, numHeads, sequenceLength, sequenceLength}; + maskTensorDataType = kernelCreationContext.GetInputEdgeDescription(maskIndex).tensorDataType; + m_inputTensorDescs[maskIndex] = TensorDesc::ConstructBroadcastedTensorDesc(maskTensorDataType, desiredMaskIndexShape, reshapedMaskIndexTensorShape); + } + } + } + + MLOperatorTensorDataType pastTensorDataType = MLOperatorTensorDataType::Undefined; + MLOperatorTensorDataType presentTensorDataType = MLOperatorTensorDataType::Undefined; + if (hasPast) + { + pastTensorDataType = kernelCreationContext.GetInputEdgeDescription(pastIndex).tensorDataType; + presentTensorDataType = kernelCreationContext.GetOutputEdgeDescription(presentIndex).tensorDataType; + } + + TensorDesc matMulIntToFloatOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, desiredBiasTensorShape); + DML_TENSOR_DESC namedMatMulIntToFloatOutputTensorDesc = matMulIntToFloatOutputTensorDesc.GetDmlDesc(); + + std::vector inputDescs = GetDmlInputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); + + DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC matMulIntToFloatOperatorDesc = {}; + matMulIntToFloatOperatorDesc.ATensor = &inputDescs[InputIndex::inputIndex]; + matMulIntToFloatOperatorDesc.AScaleTensor = &inputDescs[InputIndex::inputScaleIndex]; + matMulIntToFloatOperatorDesc.AZeroPointTensor = &inputDescs[InputIndex::inputZeroPointIndex]; + matMulIntToFloatOperatorDesc.BTensor = &inputDescs[InputIndex::weightsIndex]; + matMulIntToFloatOperatorDesc.BScaleTensor = &inputDescs[InputIndex::weightScaleIndex]; + matMulIntToFloatOperatorDesc.BZeroPointTensor = &inputDescs[InputIndex::weightZeroPointIndex]; + matMulIntToFloatOperatorDesc.BiasTensor = hasBias ? &inputDescs[InputIndex::biasIndex] : nullptr; + matMulIntToFloatOperatorDesc.OutputTensor = &namedMatMulIntToFloatOutputTensorDesc; + + const DML_OPERATOR_DESC matMulIntToFloatDesc = { DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT, &matMulIntToFloatOperatorDesc}; + + std::array queryKeySlicedTensorShape = {batchSize, sequenceLength, hiddenSize + hiddenSize}; + TensorDesc queryKeySlicedInputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, queryKeySlicedTensorShape); + DML_TENSOR_DESC namedQueryKeySlicedInputTensorDesc = queryKeySlicedInputTensorDesc.GetDmlDesc(); + + std::array valueSlicedTensorShape = {batchSize, sequenceLength, hiddenSize}; + TensorDesc valueSlicedInputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, valueSlicedTensorShape); + DML_TENSOR_DESC namedValueSlicedInputTensorDesc = valueSlicedInputTensorDesc.GetDmlDesc(); + + // Transpose slice QK from [batchSize, sequenceLength, 2, numHeads, headSize] to [batchSize, sequenceLength, numHeads, 2, headSize] + std::array queryKeyTransposedTensorShape = {batchSize, sequenceLength, numHeads, 2, headSize}; + std::array queryKeyTransposedStrides = { + sequenceLength * numHeads * 2 * headSize, + numHeads * 2 * headSize, + headSize, + numHeads * headSize, + 1, + }; + + TensorDesc queryKeyTransposedInputTensorDesc = TensorDesc( + GetDmlDataTypeFromMlDataType(dataType), + queryKeyTransposedTensorShape, + queryKeyTransposedStrides); + DML_TENSOR_DESC namedQueryKeyTransposedInputTensorDesc = queryKeyTransposedInputTensorDesc.GetDmlDesc(); + + TensorDesc queryKeyTransposedOutputTensorDesc = TensorDesc( + GetDmlDataTypeFromMlDataType(dataType), + queryKeyTransposedTensorShape); + DML_TENSOR_DESC namedQueryKeyTransposedOutputTensorDesc = queryKeyTransposedOutputTensorDesc.GetDmlDesc(); + + // Transpose QKV from [batchSize, sequenceLength, 3, numHeads, headSize] to [batchSize, sequenceLength, numHeads, 3, headSize] + std::array queryKeyValueTransposedTensorShape = {batchSize, sequenceLength, numHeads, 3, headSize}; + std::array queryKeyValueTransposedStrides = { + sequenceLength * numHeads * 3 * headSize, + numHeads * 3 * headSize, + headSize, + numHeads * headSize, + 1, + }; + + TensorDesc queryKeyValueTransposedInputTensorDesc = TensorDesc( + GetDmlDataTypeFromMlDataType(dataType), + queryKeyValueTransposedTensorShape, + queryKeyValueTransposedStrides); + DML_TENSOR_DESC namedQueryKeyValueTransposedInputTensorDesc = queryKeyValueTransposedInputTensorDesc.GetDmlDesc(); + + TensorDesc queryKeyValueTransposedOutputTensorDesc = TensorDesc( + GetDmlDataTypeFromMlDataType(dataType), + queryKeyValueTransposedTensorShape); + DML_TENSOR_DESC namedQueryKeyValueTransposedOutputTensorDesc = queryKeyValueTransposedOutputTensorDesc.GetDmlDesc(); + + std::array queryKeySliceOffset = {0, 0, 0}; + std::array queryKeySliceSize = {batchSize, sequenceLength, hiddenSize + hiddenSize}; + std::array queryKeySliceStrides = {1, 1, 1}; + + std::array valueSliceOffset = {0, 0, 2 * hiddenSize}; + std::array valueSliceSize = {batchSize, sequenceLength, hiddenSize}; + std::array valueSliceStrides = {1, 1, 1}; + + // When Q/K/V all have the same hidden size, we just have to transpose it before sending it to MHA + DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC transposeOperatorDesc = {}; + + transposeOperatorDesc.InputTensor = &namedQueryKeyValueTransposedInputTensorDesc; + transposeOperatorDesc.OutputTensor = &namedQueryKeyValueTransposedOutputTensorDesc; + + const DML_OPERATOR_DESC transposedDesc = { DML_OPERATOR_ELEMENT_WISE_IDENTITY, &transposeOperatorDesc}; + + std::array maskSliceOutputShape = {batchSize, numHeads, sequenceLength, sequenceLength}; + std::array maskSliceStrides = {1, 1, 1, 1}; + std::array maskSliceOffsets = {0, 0, 0, 0}; + TensorDesc maskSliceOutputTensorDesc; + DML_TENSOR_DESC namedMaskSliceOutputTensorDesc; + + DML_SLICE1_OPERATOR_DESC maskSlicedOperatorDesc = {}; + if (hasMaxSequenceMask) + { + maskSliceOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(maskTensorDataType, maskSliceOutputShape); + namedMaskSliceOutputTensorDesc = maskSliceOutputTensorDesc.GetDmlDesc(); + maskSlicedOperatorDesc.InputTensor = &inputDescs[maskIndex]; + maskSlicedOperatorDesc.OutputTensor = &namedMaskSliceOutputTensorDesc; + maskSlicedOperatorDesc.DimensionCount = gsl::narrow_cast(maskSliceOutputShape.size()); + maskSlicedOperatorDesc.InputWindowOffsets = maskSliceOffsets.data(); + maskSlicedOperatorDesc.InputWindowSizes = maskSliceOutputShape.data(); + maskSlicedOperatorDesc.InputWindowStrides = maskSliceStrides.data(); + } + const DML_OPERATOR_DESC maskSlicedDesc = { DML_OPERATOR_SLICE1, &maskSlicedOperatorDesc}; + + // We need to slice Past to get PastValue and PastKey tensors for MHA + std::array pastKeyOutputShape = {1, batchSize, numHeads, pastSequenceLength, headSize}; + std::array pastKeyStrides = {1, 1, 1, 1, 1}; + std::array pastKeyOffsets = {0, 0, 0, 0, 0}; + TensorDesc pastKeyOutputTensorDesc; + DML_TENSOR_DESC namedPastKeyOutputTensorDesc; + + std::array pastValueOutputShape = {1, batchSize, numHeads, pastSequenceLength, headSize}; + std::array pastValueStrides = {1, 1, 1, 1, 1}; + std::array pastValueOffsets = {1, 0, 0, 0, 0}; + TensorDesc pastValueOutputTensorDesc; + DML_TENSOR_DESC namedPastValueOutputTensorDesc; + + DML_SLICE1_OPERATOR_DESC pastKeySlicedOperatorDesc = {}; + DML_SLICE1_OPERATOR_DESC pastValueSlicedOperatorDesc = {}; + + if (hasPast) + { + pastKeyOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(pastTensorDataType, pastKeyOutputShape); + namedPastKeyOutputTensorDesc = pastKeyOutputTensorDesc.GetDmlDesc(); + pastKeySlicedOperatorDesc.InputTensor = &inputDescs[pastIndex]; + pastKeySlicedOperatorDesc.OutputTensor = &namedPastKeyOutputTensorDesc; + pastKeySlicedOperatorDesc.DimensionCount = gsl::narrow_cast(pastKeyOutputShape.size()); + pastKeySlicedOperatorDesc.InputWindowOffsets = pastKeyOffsets.data(); + pastKeySlicedOperatorDesc.InputWindowSizes = pastKeyOutputShape.data(); + pastKeySlicedOperatorDesc.InputWindowStrides = pastKeyStrides.data(); + + pastValueOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(pastTensorDataType, pastValueOutputShape); + namedPastValueOutputTensorDesc = pastValueOutputTensorDesc.GetDmlDesc(); + pastValueSlicedOperatorDesc.InputTensor = &inputDescs[pastIndex]; + pastValueSlicedOperatorDesc.OutputTensor = &namedPastValueOutputTensorDesc; + pastValueSlicedOperatorDesc.DimensionCount = gsl::narrow_cast(pastValueOutputShape.size()); + pastValueSlicedOperatorDesc.InputWindowOffsets = pastValueOffsets.data(); + pastValueSlicedOperatorDesc.InputWindowSizes = pastValueOutputShape.data(); + pastValueSlicedOperatorDesc.InputWindowStrides = pastValueStrides.data(); + } + + const DML_OPERATOR_DESC pastKeySlicedDesc = { DML_OPERATOR_SLICE1, &pastKeySlicedOperatorDesc}; + const DML_OPERATOR_DESC pastValueSlicedDesc = { DML_OPERATOR_SLICE1, &pastValueSlicedOperatorDesc}; + + // Causal Mask: Upper Triangular Boolean Matrix + // Example: [[1, 0, 0, 0, 0], + // [1, 1, 0, 0, 0], + // [1, 1, 1, 0, 0], + // [1, 1, 1, 1, 0]] + // DML adds maskFilterValue to the "off" bits in the mask and sets the "on" bits to 0 + // passed to MHA as maskIndex Tensor when unidirectional == 1 + std::array causalMaskOutputShape = {1, 1, sequenceLength, pastSequenceLength + sequenceLength}; + TensorDesc causalMaskTensorDesc; + DML_DIAGONAL_MATRIX1_OPERATOR_DESC causalMaskOperatorDesc = {}; + DML_TENSOR_DESC namedcausalMaskTensorDesc; + + if (unidirectional && !hasMask) + { + causalMaskTensorDesc = TensorDesc::ConstructDefaultTensorDesc(MLOperatorTensorDataType::Int32, causalMaskOutputShape); + namedcausalMaskTensorDesc = causalMaskTensorDesc.GetDmlDesc(); + causalMaskOperatorDesc.ValueDataType = DML_TENSOR_DATA_TYPE_INT32; + causalMaskOperatorDesc.DiagonalFillBegin = INT32_MIN; + causalMaskOperatorDesc.DiagonalFillEnd = pastSequenceLength + 1; + causalMaskOperatorDesc.Value.Int32 = 1; + causalMaskOperatorDesc.OutputTensor = &namedcausalMaskTensorDesc; + maskType = DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN; + } + DML_OPERATOR_DESC causalMaskDesc = { DML_OPERATOR_DIAGONAL_MATRIX1, &causalMaskOperatorDesc }; + + DML_MULTIHEAD_ATTENTION_OPERATOR_DESC mhaOperatorDesc = {}; + std::array presentKeyOutputShape = {1, batchSize, numHeads, pastSequenceLength + sequenceLength, headSize}; + std::array presentValueOutputShape = {1, batchSize, numHeads, pastSequenceLength + sequenceLength, headSize}; + TensorDesc presentKeyTensorDesc; + TensorDesc presentValueTensorDesc; + DML_TENSOR_DESC namedPresentKeyOutputTensorDesc; + DML_TENSOR_DESC namedPresentValueOutputTensorDesc; + + mhaOperatorDesc.StackedQueryKeyValueTensor = &namedQueryKeyValueTransposedOutputTensorDesc; + + // Broadcast to MHA MaskTensor Shape + std::array mhaMaskTensorShape = {batchSize, numHeads, sequenceLength, pastSequenceLength + sequenceLength}; + TensorDesc broadcastedcausalMaskTensorDesc; + DML_TENSOR_DESC namedbroadcastedcausalMaskTensorDesc; + if (unidirectional && !hasMask) + { + broadcastedcausalMaskTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(MLOperatorTensorDataType::Int32, mhaMaskTensorShape, causalMaskOutputShape); + namedbroadcastedcausalMaskTensorDesc = broadcastedcausalMaskTensorDesc.GetDmlDesc(); + mhaOperatorDesc.MaskTensor = &namedbroadcastedcausalMaskTensorDesc; + } + else if (hasMaxSequenceMask) + { + mhaOperatorDesc.MaskTensor = &namedMaskSliceOutputTensorDesc; + } + else + { + mhaOperatorDesc.MaskTensor = hasMask ? &inputDescs[maskIndex] : nullptr; + } + + mhaOperatorDesc.RelativePositionBiasTensor = nullptr; + mhaOperatorDesc.OutputTensor = &outputDescs[outputIndex]; + mhaOperatorDesc.Scale = kernelCreationContext.GetOptionalAttribute(AttrName::Scale, gsl::narrow_cast(1.0f / std::sqrt(headSize))); + // Set MaskFilterValue to lowest float for Causal Mask + mhaOperatorDesc.MaskFilterValue = unidirectional ? std::numeric_limits::lowest() : + kernelCreationContext.GetOptionalAttribute(AttrName::MaskFilterValue, -10'000.0f); + mhaOperatorDesc.HeadCount = numHeads; + mhaOperatorDesc.MaskType = maskType; + if (hasPast) + { + presentKeyTensorDesc = TensorDesc::ConstructDefaultTensorDesc(presentTensorDataType, presentKeyOutputShape); + namedPresentKeyOutputTensorDesc = presentKeyTensorDesc.GetDmlDesc(); + presentValueTensorDesc = TensorDesc::ConstructDefaultTensorDesc(presentTensorDataType, presentValueOutputShape); + namedPresentValueOutputTensorDesc = presentValueTensorDesc.GetDmlDesc(); + mhaOperatorDesc.PastKeyTensor = &namedPastKeyOutputTensorDesc; + mhaOperatorDesc.PastValueTensor = &namedPastValueOutputTensorDesc; + mhaOperatorDesc.OutputPresentKeyTensor = &namedPresentKeyOutputTensorDesc; + mhaOperatorDesc.OutputPresentValueTensor = &namedPresentValueOutputTensorDesc; + } + + const DML_OPERATOR_DESC mhaDesc = { DML_OPERATOR_MULTIHEAD_ATTENTION, &mhaOperatorDesc }; + + DML_JOIN_OPERATOR_DESC presentKeyValueJoinOperatorDesc = {}; + std::vector joinInputDesc; + + if (hasPast) + { + joinInputDesc.push_back(namedPresentKeyOutputTensorDesc); + joinInputDesc.push_back(namedPresentValueOutputTensorDesc); + presentKeyValueJoinOperatorDesc.InputCount = gsl::narrow_cast(joinInputDesc.size()); + presentKeyValueJoinOperatorDesc.InputTensors = joinInputDesc.data(); + presentKeyValueJoinOperatorDesc.OutputTensor = &outputDescs[presentIndex]; + presentKeyValueJoinOperatorDesc.Axis = gsl::narrow_cast(0); + } + + DML_OPERATOR_DESC presentKeyValueJoinDesc = { DML_OPERATOR_JOIN, &presentKeyValueJoinOperatorDesc }; + + // Construct the graph + std::vector inputEdges; + std::vector intermediateEdges; + std::vector outputEdges; + + std::vector opDescs = { + &matMulIntToFloatDesc, + &mhaDesc, + }; + + uint32_t currentNodeIndex = 0; + const uint32_t matMulIntToFloatNodeIndex = currentNodeIndex++; + const uint32_t mhaNodeIndex = currentNodeIndex++; + + uint32_t queryKeyValueTransposedNodeIndex = 0; + + opDescs.push_back(&transposedDesc); + queryKeyValueTransposedNodeIndex = currentNodeIndex++; + + uint32_t maskSliceNodeIndex = 0; + if (hasMaxSequenceMask) + { + opDescs.push_back(&maskSlicedDesc); + maskSliceNodeIndex = currentNodeIndex++; + } + + uint32_t pastKeySliceNodeIndex = 0; + uint32_t pastValueSliceNodeIndex = 0; + uint32_t concatNodeIndex = 0; + if (hasPast) + { + opDescs.push_back(&pastKeySlicedDesc); + pastKeySliceNodeIndex = currentNodeIndex++; + opDescs.push_back(&pastValueSlicedDesc); + pastValueSliceNodeIndex = currentNodeIndex++; + opDescs.push_back(&presentKeyValueJoinDesc); + concatNodeIndex = currentNodeIndex++; + } + + uint32_t causalMaskNodeIndex = 0; + if (unidirectional && !hasMask) + { + opDescs.push_back(&causalMaskDesc); + causalMaskNodeIndex = currentNodeIndex++; + } + + DML_INPUT_GRAPH_EDGE_DESC inputToMatMulIntToFloatEdge = {}; + inputToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::inputIndex; + inputToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex; + inputToMatMulIntToFloatEdge.ToNodeInputIndex = 0; + inputEdges.push_back(inputToMatMulIntToFloatEdge); + + DML_INPUT_GRAPH_EDGE_DESC inputScaleToMatMulIntToFloatEdge = {}; + inputScaleToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::inputScaleIndex; + inputScaleToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex; + inputScaleToMatMulIntToFloatEdge.ToNodeInputIndex = 1; + inputEdges.push_back(inputScaleToMatMulIntToFloatEdge); + + DML_INPUT_GRAPH_EDGE_DESC inputZeroPointToMatMulIntToFloatEdge = {}; + inputZeroPointToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::inputZeroPointIndex; + inputZeroPointToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex; + inputZeroPointToMatMulIntToFloatEdge.ToNodeInputIndex = 2; + inputEdges.push_back(inputZeroPointToMatMulIntToFloatEdge); + + DML_INPUT_GRAPH_EDGE_DESC weightToMatMulIntToFloatEdge = {}; + weightToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::weightsIndex; + weightToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex; + weightToMatMulIntToFloatEdge.ToNodeInputIndex = 3; + inputEdges.push_back(weightToMatMulIntToFloatEdge); + + DML_INPUT_GRAPH_EDGE_DESC weightScaleToMatMulIntToFloatEdge = {}; + weightScaleToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::weightScaleIndex; + weightScaleToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex; + weightScaleToMatMulIntToFloatEdge.ToNodeInputIndex = 4; + inputEdges.push_back(weightScaleToMatMulIntToFloatEdge); + + DML_INPUT_GRAPH_EDGE_DESC weightZeroPointToMatMulIntToFloatEdge = {}; + weightZeroPointToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::weightZeroPointIndex; + weightZeroPointToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex; + weightZeroPointToMatMulIntToFloatEdge.ToNodeInputIndex = 5; + inputEdges.push_back(weightZeroPointToMatMulIntToFloatEdge); + + if (hasBias) + { + DML_INPUT_GRAPH_EDGE_DESC biasToMatMulIntToFloatEdge = {}; + biasToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::biasIndex; + biasToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex; + biasToMatMulIntToFloatEdge.ToNodeInputIndex = 6; + inputEdges.push_back(biasToMatMulIntToFloatEdge); + } + + if (hasMask) + { + if (hasUnpaddedBounds) + { + DML_INPUT_GRAPH_EDGE_DESC maskToMhaEdge = {}; + maskToMhaEdge.GraphInputIndex = InputIndex::maskIndex; + maskToMhaEdge.ToNodeIndex = mhaNodeIndex; + maskToMhaEdge.ToNodeInputIndex = mhaMaskIndex; + inputEdges.push_back(maskToMhaEdge); + } + else if (hasMaxSequenceMask) + { + DML_INPUT_GRAPH_EDGE_DESC maskToMaskSliceEdge = {}; + maskToMaskSliceEdge.GraphInputIndex = InputIndex::maskIndex; + maskToMaskSliceEdge.ToNodeIndex = maskSliceNodeIndex; + maskToMaskSliceEdge.ToNodeInputIndex = 0; + inputEdges.push_back(maskToMaskSliceEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC maskSliceToMhaEdge = {}; + maskSliceToMhaEdge.FromNodeIndex = maskSliceNodeIndex; + maskSliceToMhaEdge.FromNodeOutputIndex = 0; + maskSliceToMhaEdge.ToNodeIndex = mhaNodeIndex; + maskSliceToMhaEdge.ToNodeInputIndex = mhaMaskIndex; + intermediateEdges.push_back(maskSliceToMhaEdge); + } + else + { + DML_INPUT_GRAPH_EDGE_DESC maskToMhaEdge = {}; + maskToMhaEdge.GraphInputIndex = InputIndex::maskIndex; + maskToMhaEdge.ToNodeIndex = mhaNodeIndex; + maskToMhaEdge.ToNodeInputIndex = mhaMaskIndex; + inputEdges.push_back(maskToMhaEdge); + } + } + else if (unidirectional) + { + DML_INTERMEDIATE_GRAPH_EDGE_DESC causalMaskToMhaEdge = {}; + causalMaskToMhaEdge.FromNodeIndex = causalMaskNodeIndex; + causalMaskToMhaEdge.FromNodeOutputIndex = 0; + causalMaskToMhaEdge.ToNodeIndex = mhaNodeIndex; + causalMaskToMhaEdge.ToNodeInputIndex = mhaMaskIndex; + intermediateEdges.push_back(causalMaskToMhaEdge); + } + + if (hasPast) + { + DML_INPUT_GRAPH_EDGE_DESC pastToPastKeySliceEdge = {}; + pastToPastKeySliceEdge.GraphInputIndex = InputIndex::pastIndex; + pastToPastKeySliceEdge.ToNodeIndex = pastKeySliceNodeIndex; + pastToPastKeySliceEdge.ToNodeInputIndex = 0; + inputEdges.push_back(pastToPastKeySliceEdge); + + DML_INPUT_GRAPH_EDGE_DESC pastToPastValueSliceEdge = {}; + pastToPastValueSliceEdge.GraphInputIndex = InputIndex::pastIndex; + pastToPastValueSliceEdge.ToNodeIndex = pastValueSliceNodeIndex; + pastToPastValueSliceEdge.ToNodeInputIndex = 0; + inputEdges.push_back(pastToPastValueSliceEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC pastKeyToMhaEdge = {}; + pastKeyToMhaEdge.FromNodeIndex = pastKeySliceNodeIndex; + pastKeyToMhaEdge.FromNodeOutputIndex = 0; + pastKeyToMhaEdge.ToNodeIndex = mhaNodeIndex; + pastKeyToMhaEdge.ToNodeInputIndex = mhaPastKeyIndex; + intermediateEdges.push_back(pastKeyToMhaEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC pastValueToMhaEdge = {}; + pastValueToMhaEdge.FromNodeIndex = pastValueSliceNodeIndex; + pastValueToMhaEdge.FromNodeOutputIndex = 0; + pastValueToMhaEdge.ToNodeIndex = mhaNodeIndex; + pastValueToMhaEdge.ToNodeInputIndex = mhaPastValueIndex; + intermediateEdges.push_back(pastValueToMhaEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC presentKeyToConcatEdge = {}; + presentKeyToConcatEdge.FromNodeIndex = mhaNodeIndex; + presentKeyToConcatEdge.FromNodeOutputIndex = mhaPresentKeyIndex; + presentKeyToConcatEdge.ToNodeIndex = concatNodeIndex; + presentKeyToConcatEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(presentKeyToConcatEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC presentValueToConcatEdge = {}; + presentValueToConcatEdge.FromNodeIndex = mhaNodeIndex; + presentValueToConcatEdge.FromNodeOutputIndex = mhaPresentValueIndex; + presentValueToConcatEdge.ToNodeIndex = concatNodeIndex; + presentValueToConcatEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(presentValueToConcatEdge); + } + + DML_INTERMEDIATE_GRAPH_EDGE_DESC matMulIntToFloatToQueryKeyValueTransposeEdge = {}; + matMulIntToFloatToQueryKeyValueTransposeEdge.FromNodeIndex = matMulIntToFloatNodeIndex; + matMulIntToFloatToQueryKeyValueTransposeEdge.FromNodeOutputIndex = 0; + matMulIntToFloatToQueryKeyValueTransposeEdge.ToNodeIndex = queryKeyValueTransposedNodeIndex; + matMulIntToFloatToQueryKeyValueTransposeEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(matMulIntToFloatToQueryKeyValueTransposeEdge); + + // All we need to do here is transpose the stacked QKV tensor into something DML supports + DML_INTERMEDIATE_GRAPH_EDGE_DESC queryKeyValueTransposedToMhaEdge = {}; + queryKeyValueTransposedToMhaEdge.FromNodeIndex = queryKeyValueTransposedNodeIndex; + queryKeyValueTransposedToMhaEdge.FromNodeOutputIndex = 0; + queryKeyValueTransposedToMhaEdge.ToNodeIndex = mhaNodeIndex; + queryKeyValueTransposedToMhaEdge.ToNodeInputIndex = mhaStackedQueryKeyValueIndex; + intermediateEdges.push_back(queryKeyValueTransposedToMhaEdge); + + DML_OUTPUT_GRAPH_EDGE_DESC mhaToOutputEdge = {}; + mhaToOutputEdge.FromNodeIndex = mhaNodeIndex; + mhaToOutputEdge.FromNodeOutputIndex = mhaOutputIndex; + mhaToOutputEdge.GraphOutputIndex = OutputIndex::outputIndex; + outputEdges.push_back(mhaToOutputEdge); + + if (hasPast) + { + DML_OUTPUT_GRAPH_EDGE_DESC concatToOutputEdge = {}; + concatToOutputEdge.FromNodeIndex = concatNodeIndex; + concatToOutputEdge.FromNodeOutputIndex = 0; + concatToOutputEdge.GraphOutputIndex = OutputIndex::presentIndex; + outputEdges.push_back(concatToOutputEdge); + } + + MLOperatorGraphDesc operatorGraphDesc = {}; + operatorGraphDesc.inputEdgeCount = gsl::narrow_cast(inputEdges.size()); + operatorGraphDesc.inputEdges = inputEdges.data(); + operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast(intermediateEdges.size()); + operatorGraphDesc.intermediateEdges = intermediateEdges.data(); + operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); + operatorGraphDesc.outputEdges = outputEdges.data(); + operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); + operatorGraphDesc.nodes = opDescs.data(); + + SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext); + } +}; + +void CALLBACK QueryQAttention(IMLOperatorSupportQueryContextPrivate* context, /*out*/ bool* isSupported) +{ + *isSupported = false; + + // `unidirectional == 1` with Mask Tensor is not supported yet + MLOperatorAttributes attributes(context); + if (attributes.GetOptionalAttribute(AttrName::Unidirectional, 0) != 0 && context->IsInputValid(5)) + { + return; + } + + // `do_rotary == 1` is not supported yet + if (attributes.GetOptionalAttribute(AttrName::DoRotary, 0) != 0) + { + return; + } + + // `past_present_share_buffer == 1` is not supported yet + if (attributes.GetOptionalAttribute(AttrName::PastPresentShareBuffer, 0) != 0) + { + return; + } + + *isSupported = true; +} + +DML_OP_DEFINE_CREATION_FUNCTION(QAttention, DmlOperatorQAttention); +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp index f658e7c7da..bc0082fef3 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp @@ -178,4 +178,4 @@ void CALLBACK QueryQLinearSigmoid(IMLOperatorSupportQueryContextPrivate* context } DML_OP_DEFINE_CREATION_FUNCTION(QLinearSigmoid, DmlOperatorQLinearSigmoid); -} // namespace Dml +} // namespace Dml \ No newline at end of file diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 38cf80b381..71fc8741bf 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -516,6 +516,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(Resize19); DML_OP_EXTERN_CREATION_FUNCTION(Shape); DML_OP_EXTERN_CREATION_FUNCTION(Size); +DML_OP_EXTERN_CREATION_FUNCTION(QAttention); DML_OP_EXTERN_CREATION_FUNCTION(Attention); DML_OP_EXTERN_CREATION_FUNCTION(MultiHeadAttention); DML_OP_EXTERN_CREATION_FUNCTION(NonZero); @@ -537,6 +538,7 @@ DML_OP_EXTERN_QUERY_FUNCTION(Pad); DML_OP_EXTERN_QUERY_FUNCTION(LayerNormalization); DML_OP_EXTERN_QUERY_FUNCTION(SkipLayerNormalization); DML_OP_EXTERN_QUERY_FUNCTION(QLinearSigmoid); +DML_OP_EXTERN_QUERY_FUNCTION(QAttention); DML_OP_EXTERN_QUERY_FUNCTION(Attention); constexpr static std::array typeNameListDefault = {"T"}; @@ -614,15 +616,23 @@ constexpr static std::array supportedTypeListLayerN constexpr static std::array supportedTypeListShape = {SupportedTensorDataTypes::All, SupportedTensorDataTypes::Int64}; constexpr static std::array supportedTypeListSize = {SupportedTensorDataTypes::All, SupportedTensorDataTypes::Int64}; constexpr static std::array supportedTypeListQLinearSigmoid = {SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8}; + +constexpr static std::array supportedTypeListQAttention = { + SupportedTensorDataTypes::Ints8Bit, + SupportedTensorDataTypes::Ints8Bit, + SupportedTensorDataTypes::Float16to32, + SupportedTensorDataTypes::Int32 +}; + constexpr static std::array supportedTypeListAttention = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int32}; constexpr static std::array supportedTypeListRotaryEmbedding = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int64}; constexpr static std::array supportedTypeListGroupNorm = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Float16to32}; constexpr static std::array supportedTypeListNonZero = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit | SupportedTensorDataTypes::Ints16Bit | SupportedTensorDataTypes::Ints32Bit | SupportedTensorDataTypes::Bool}; constexpr static std::array supportedTypeListQLinearMatMul = { - SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, - SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, - SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8 + SupportedTensorDataTypes::Ints8Bit, + SupportedTensorDataTypes::Ints8Bit, + SupportedTensorDataTypes::Ints8Bit }; constexpr static std::array supportedTypeListMatMulIntegerToFloat = { @@ -632,9 +642,9 @@ constexpr static std::array supportedTypeListMatMul }; constexpr static std::array supportedTypeListQLinearConv = { - SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, - SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, - SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, + SupportedTensorDataTypes::Ints8Bit, + SupportedTensorDataTypes::Ints8Bit, + SupportedTensorDataTypes::Ints8Bit, SupportedTensorDataTypes::Int32 }; @@ -1069,6 +1079,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO_MS( 1, DynamicQuantizeMatMul, typeNameListTwo, supportedTypeListDynamicQuantizeLinear, DmlGraphSupport::Supported)}, {REG_INFO_MS( 1, FusedMatMulActivation, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO_MS( 1, QLinearSigmoid, typeNameListDefault, supportedTypeListQLinearSigmoid, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryQLinearSigmoid)}, + {REG_INFO_MS( 1, QAttention, typeNameListFour, supportedTypeListQAttention, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryQAttention)}, {REG_INFO_MS( 1, Attention, typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryAttention)}, {REG_INFO_MS( 1, MultiHeadAttention, typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported)}, {REG_INFO_MS( 1, RotaryEmbedding, typeNameListRotaryEmbedding, supportedTypeListRotaryEmbedding, DmlGraphSupport::Supported)}, diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp index 317f5ebcbc..acda1a516b 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp @@ -2802,6 +2802,48 @@ namespace OperatorHelper m_qkvHiddenSizes = kernelInformation.GetAttributes().GetOptionalAttributeVectorInt32(AttrName::QkvHiddenSizes); } + std::vector QAttentionHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const + { + ML_CHECK_VALID_ARGUMENT(shapeInfo.GetInputCount() >= 5); + + auto queryShape = shapeInfo.GetInputTensorShape(0); + ML_CHECK_VALID_ARGUMENT(queryShape.size() == 3); + + auto weightShape = shapeInfo.GetInputTensorShape(1); + ML_CHECK_VALID_ARGUMENT(weightShape.size() == 2); + ML_CHECK_VALID_ARGUMENT(weightShape[1] % 3 == 0); + + const uint32_t batchSize = queryShape[0]; + const uint32_t sequenceLength = queryShape[1]; + const uint32_t hiddenSize = weightShape[1] / 3; + const uint32_t headSize = hiddenSize / m_numHeads; + + std::vector outputShapes(2); + + outputShapes[0] = EdgeShapes({batchSize, sequenceLength, hiddenSize}); + + uint32_t totalSequenceLength = sequenceLength; + if (shapeInfo.IsInputValid(8)) + { + ML_CHECK_VALID_ARGUMENT(shapeInfo.GetInputTensorDimensionCount(8) == 5); + const uint32_t pastSequenceLength = shapeInfo.GetInputTensorShape(8)[3]; + totalSequenceLength += pastSequenceLength; + } + + if (shapeInfo.IsOutputValid(1)) + { + ML_CHECK_VALID_ARGUMENT(shapeInfo.IsInputValid(8)); + outputShapes[1] = EdgeShapes({2, batchSize, m_numHeads, totalSequenceLength, headSize}); + } + + return outputShapes; + } + + void QAttentionHelper::Initialize(const IKernelInformationAdapter& kernelInformation) + { + m_numHeads = gsl::narrow_cast(kernelInformation.GetAttributes().GetAttribute(AttrName::NumHeads)); + } + std::vector SkipLayerNormHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const { ML_CHECK_VALID_ARGUMENT(shapeInfo.GetInputCount() >= 3); diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index 1f5daed6ea..aff31bb305 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -1554,6 +1554,22 @@ private: std::vector m_qkvHiddenSizes; }; +class QAttentionHelper +{ +public: + template + QAttentionHelper(const Info_t& info, const Shape_t& shapeInfo) + { + Initialize(KernelInformationAdapter(info)); + } + + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + +private: + void Initialize(const IKernelInformationAdapter& kernelInformation); + uint32_t m_numHeads; +}; + class SkipLayerNormHelper { public: @@ -1699,6 +1715,7 @@ using ShapeInferenceHelper_Affine = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_QuantizeLinear = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_DequantizeLinear = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_QLinearSigmoid = GetOutputShapeAsInputShapeHelper; +using ShapeInferenceHelper_QAttention = QAttentionHelper; using ShapeInferenceHelper_Attention = AttentionHelper; using ShapeInferenceHelper_MultiHeadAttention = MultiHeadAttentionHelper; using ShapeInferenceHelper_RotaryEmbedding = GetOutputShapeAsInputShapeHelper; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h index 8de43f2705..7492b72942 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h @@ -448,6 +448,7 @@ namespace OperatorHelper static const int sc_sinceVer_FusedMatMul = 1; static const int sc_sinceVer_FusedMatMulActivation = 1; static const int sc_sinceVer_QLinearSigmoid = 1; + static const int sc_sinceVer_QAttention = 1; static const int sc_sinceVer_Attention = 1; static const int sc_sinceVer_MatMulIntegerToFloat = 1; static const int sc_sinceVer_MultiHeadAttention = 1; diff --git a/onnxruntime/test/contrib_ops/quantize_attention_op_test.cc b/onnxruntime/test/contrib_ops/quantize_attention_op_test.cc index 3af334696a..fd222583ac 100644 --- a/onnxruntime/test/contrib_ops/quantize_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/quantize_attention_op_test.cc @@ -20,7 +20,8 @@ namespace test { enum class EP : char { CPU, CUDA, - DNNL + DNNL, + DML }; // input: [batch_size, sequence_length, hidden_size] @@ -111,7 +112,9 @@ void RunQAttention(const std::vector& input_data, execution_providers.push_back(DefaultCudaExecutionProvider()); } else if constexpr (ep == EP::CPU) { execution_providers.push_back(DefaultCpuExecutionProvider()); - } else { // onednn ep + } else if constexpr (ep == EP::DML) { + execution_providers.push_back(DefaultDmlExecutionProvider()); + } else { // onednn ep execution_providers.push_back(DefaultDnnlExecutionProvider()); } @@ -192,6 +195,52 @@ static void RunQAttentionDNNL( #endif } +static void RunQAttentionDML( + const std::vector& input_data, + const std::vector& weights_data, + const std::vector& bias_data, + const std::vector& mask_index_data, + const std::vector& output_data, + int batch_size, + int sequence_length, + int hidden_size, + int number_of_heads, + bool use_special_quantize_parameter = true, + bool is_unidirectional = false, + int input_hidden_size = 0) { + // Return without running code if USE_DML is not defined +#ifdef USE_DML + bool enable_dml = (nullptr != DefaultDmlExecutionProvider().get()); + if (enable_dml) { + quantization::Params input_quant_params(/*scale=*/0.0f, /*zero_point=*/0); + quantization::Params weights_quant_params(/*scale=*/0.0f, /*zero_point=*/0); + if (use_special_quantize_parameter) { + input_quant_params.scale = 0.1f; + weights_quant_params.scale = 0.1f; + input_quant_params.zero_point = 128; + weights_quant_params.zero_point = 1; + } + + RunQAttention( + input_data, weights_data, bias_data, mask_index_data, output_data, input_quant_params, weights_quant_params, + batch_size, sequence_length, hidden_size, number_of_heads, is_unidirectional, false, input_hidden_size); + } +#else + ORT_UNUSED_PARAMETER(input_data); + ORT_UNUSED_PARAMETER(weights_data); + ORT_UNUSED_PARAMETER(bias_data); + ORT_UNUSED_PARAMETER(mask_index_data); + ORT_UNUSED_PARAMETER(output_data); + ORT_UNUSED_PARAMETER(batch_size); + ORT_UNUSED_PARAMETER(sequence_length); + ORT_UNUSED_PARAMETER(hidden_size); + ORT_UNUSED_PARAMETER(number_of_heads); + ORT_UNUSED_PARAMETER(use_special_quantize_parameter); + ORT_UNUSED_PARAMETER(is_unidirectional); + ORT_UNUSED_PARAMETER(input_hidden_size); +#endif +} + static void RunQAttentionU8U8( const std::vector& input_data, const std::vector& weights_data, @@ -272,6 +321,9 @@ static void RunQAttentionAll( RunQAttentionDNNL(input_data, weight_data, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, use_special_quantize_parameter, is_unidirectional, input_hidden_size); + RunQAttentionDML(input_data, weight_data, bias_data, mask_index_data, output_data, + batch_size, sequence_length, hidden_size, number_of_heads, + use_special_quantize_parameter, is_unidirectional, input_hidden_size); } // ONEDNN EP only supports 2D raw mask @@ -859,8 +911,8 @@ void TestQuantizedAttentionPastState(int64_t batch, std::vector input_dims{batch, seq_len, hidden_size}; std::vector input_data = random.Gaussian(input_dims, input_mean, static_cast(input_range / 6), input_min, input_max); - constexpr WeightT weight_min = std::numeric_limits::min(); - constexpr WeightT weight_max = std::numeric_limits::max(); + constexpr WeightT weight_min = std::is_same_v ? std::numeric_limits::min() / 2 : std::numeric_limits::min(); + constexpr WeightT weight_max = std::numeric_limits::max() / 2; constexpr int32_t weight_range = weight_max - weight_min; std::vector weight_zero_point(weight_scale_zp_size);