diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 955957f295..eddc3b7873 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -1277,6 +1277,7 @@ Do not modify directly.*
|MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float), tensor(float16)|
|MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)|
|NhwcConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
+|QAttention|*in* input:**T1**
*in* weight:**T2**
*in* bias:**T3**
*in* input_scale:**T3**
*in* weight_scale:**T3**
*in* mask_index:**T4**
*in* input_zero_point:**T1**
*in* weight_zero_point:**T2**
*in* past:**T3**
*out* output:**T3**
*out* present:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float), tensor(float16)
**T4** = tensor(int32)|
|QLinearAdd|*in* A:**T**
*in* A_scale:**tensor(float)**
*in* A_zero_point:**T**
*in* B:**T**
*in* B_scale:**tensor(float)**
*in* B_zero_point:**T**
*in* C_scale:**tensor(float)**
*in* C_zero_point:**T**
*out* C:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearAveragePool|*in* X:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearConcat|*in* Y_scale:**TF**
*in* Y_zero_point:**T8**
*in* inputs:**TV**
*out* Y:**T8**|1+|**T8** = tensor(int8), tensor(uint8)
**TF** = tensor(float)
**TV** = tensor(float), tensor(int8), tensor(uint8)|
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp
new file mode 100644
index 0000000000..f9519b26bb
--- /dev/null
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp
@@ -0,0 +1,704 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "precomp.h"
+
+/*
+Abbreviations: B is batch_size, S is sequence_length, W is hidden_size
+ N is number of attention heads, H is head size, and W=N*H
+
+Input, Weight, Bias, Mask Index and Past are Inputs
+
+Mask Index/Causal Input Weight Bias
+ | \ | /
+ | \ | /
+ | \ | /
+ | MatMulIntToFloat
+ | / | \
+ | / | \
+ | / | \
+ | Slice Slice Slice
+ | | | |
+ | | | |
+ | Identity Identity Identity // The identities are used to transpose NCHW -> NHCW while
+ | | | | // keeping the GEMM strides as NCHW to better target metacommands
+ | | | |
+ | | | | Past
+ | | | | / \
+ | | | | / \
+ | | | | Slice Slice
+ | | | | | |
+ | | | | | |
+ | | | | | |
+ --------------------------MHA -----------
+ / | \
+ / | \
+ / | \
+ / | \
+ / | \
+ / | \
+ / presentKey presentValue
+ / \ /
+ / \ /
+ / \ /
+ / Concat
+ / |
+ Output1 Output2 (present)
+
+ This kernel creates a DML_GRAPH, as mentioned above.
+ For reference, refer to this Doc:
+ https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftqattention
+ */
+
+namespace Dml
+{
+class DmlOperatorQAttention : public DmlOperator
+{
+public:
+ DmlOperatorQAttention(const MLOperatorKernelCreationContext& kernelCreationContext)
+ : DmlOperator(kernelCreationContext)
+ {
+ enum InputIndex : uint32_t
+ {
+ inputIndex,
+ weightsIndex,
+ biasIndex,
+ inputScaleIndex,
+ weightScaleIndex,
+ maskIndex,
+ inputZeroPointIndex,
+ weightZeroPointIndex,
+ pastIndex,
+ inputCount,
+ };
+
+ enum OutputIndex : uint32_t
+ {
+ outputIndex,
+ presentIndex,
+ outputCount,
+ };
+
+ enum MhaInputIndex : uint32_t
+ {
+ mhaQueryIndex,
+ mhaKeyIndex,
+ mhaValueIndex,
+ mhaStackedQueryKeyIndex,
+ mhaStackedKeyValueIndex,
+ mhaStackedQueryKeyValueIndex,
+ mhaBiasIndex,
+ mhaMaskIndex,
+ mhaRelativePositionBiasIndex,
+ mhaPastKeyIndex,
+ mhaPastValueIndex,
+ mhaInputCount,
+ };
+
+ enum MhaOutputIndex : uint32_t
+ {
+ mhaOutputIndex,
+ mhaPresentKeyIndex,
+ mhaPresentValueIndex,
+ mhaOutputCount,
+ };
+
+ ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() >= 5);
+ ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() >= 1);
+
+ const bool hasBias = kernelCreationContext.IsInputValid(biasIndex);
+ const bool hasMask = kernelCreationContext.IsInputValid(maskIndex);
+ const bool hasUnpaddedBounds = hasMask && kernelCreationContext.GetInputTensorDimensionCount(maskIndex) == 1;
+ const bool hasPast = kernelCreationContext.IsInputValid(pastIndex);
+
+ DmlOperator::Initialize(kernelCreationContext, std::nullopt, std::nullopt, std::nullopt, std::nullopt, 1);
+
+ const bool unidirectional = gsl::narrow_cast(kernelCreationContext.GetAttribute(AttrName::Unidirectional));
+ const uint32_t numHeads = gsl::narrow_cast(kernelCreationContext.GetAttribute(AttrName::NumHeads));
+ ML_CHECK_VALID_ARGUMENT(numHeads > 0); // to avoid process crash because of division by zero.
+
+ auto inputTensorShape = m_inputTensorDescs[inputIndex].GetSizes();
+ ML_CHECK_VALID_ARGUMENT(inputTensorShape.size() == 3);
+
+ auto weightTensorShape = m_inputTensorDescs[weightsIndex].GetSizes();
+ ML_CHECK_VALID_ARGUMENT(weightTensorShape.size() == 2);
+ ML_CHECK_VALID_ARGUMENT(weightTensorShape[0] == inputTensorShape[2]);
+ ML_CHECK_VALID_ARGUMENT(weightTensorShape[1] % 3 == 0);
+
+ if (hasBias)
+ {
+ auto biasTensorShape = m_inputTensorDescs[biasIndex].GetSizes();
+ ML_CHECK_VALID_ARGUMENT(biasTensorShape.size() == 1);
+ ML_CHECK_VALID_ARGUMENT(biasTensorShape[0] % 3 == 0);
+ ML_CHECK_VALID_ARGUMENT(weightTensorShape[1] == biasTensorShape[0]);
+ }
+
+ if (hasPast)
+ {
+ ML_CHECK_VALID_ARGUMENT(kernelCreationContext.IsOutputValid(presentIndex));
+ }
+
+ const uint32_t hiddenSize = weightTensorShape[1] / 3;
+ const uint32_t headSize = hiddenSize / numHeads;
+ const uint32_t batchSize = inputTensorShape[0];
+ const uint32_t sequenceLength = inputTensorShape[1];
+ const uint32_t pastSequenceLength = hasPast ? m_inputTensorDescs[pastIndex].GetSizes()[3] : 0;
+
+ uint32_t desiredWeightTensorShape[3] = {batchSize, weightTensorShape[0], 3 * hiddenSize};
+ MLOperatorTensorDataType dataType = kernelCreationContext.GetOutputEdgeDescription(outputIndex).tensorDataType;
+
+ m_inputTensorDescs[weightsIndex] = TensorDesc::ConstructBroadcastedTensorDesc(
+ kernelCreationContext.GetInputEdgeDescription(weightsIndex).tensorDataType,
+ desiredWeightTensorShape,
+ weightTensorShape);
+
+ uint32_t desiredBiasTensorShape[3] = {batchSize, sequenceLength, 3 * hiddenSize};
+
+ if (hasBias)
+ {
+ auto biasTensorShape = m_inputTensorDescs[biasIndex].GetSizes();
+ m_inputTensorDescs[biasIndex] = TensorDesc::ConstructBroadcastedTensorDesc(kernelCreationContext.GetInputEdgeDescription(biasIndex).tensorDataType, desiredBiasTensorShape, biasTensorShape);
+ }
+
+ MLOperatorTensorDataType maskTensorDataType = MLOperatorTensorDataType::Undefined;
+ bool hasMaxSequenceMask = false;
+ DML_MULTIHEAD_ATTENTION_MASK_TYPE maskType = DML_MULTIHEAD_ATTENTION_MASK_TYPE_NONE;
+ if (hasMask)
+ {
+ if (hasUnpaddedBounds)
+ {
+ auto unpaddedKeyBoundsShape = m_inputTensorDescs[maskIndex].GetSizes();
+ ML_CHECK_VALID_ARGUMENT(unpaddedKeyBoundsShape.size() == 1);
+
+ const uint32_t batchGroupCount = unpaddedKeyBoundsShape[0] / batchSize;
+ ML_CHECK_VALID_ARGUMENT(batchGroupCount == 1 || batchGroupCount == 2);
+
+ uint32_t desiredShape[2] = {batchGroupCount, batchSize};
+ m_inputTensorDescs[maskIndex] = TensorDesc(
+ m_inputTensorDescs[maskIndex].GetDmlDataType(),
+ desiredShape);
+
+ maskType = batchGroupCount == 1
+ ? DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH
+ : DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_END_START;
+ }
+ else
+ {
+ auto maskIndexTensorShape = m_inputTensorDescs[maskIndex].GetSizes();
+ ML_CHECK_VALID_ARGUMENT(maskIndexTensorShape.size() > 1 && maskIndexTensorShape.size() <= 4);
+
+ maskType = DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN;
+ std::vector reshapedMaskIndexTensorShape(maskIndexTensorShape.begin(), maskIndexTensorShape.end());
+ if (maskIndexTensorShape.size() == 4 && maskIndexTensorShape[2] != sequenceLength)
+ {
+ hasMaxSequenceMask = true;
+ ML_CHECK_VALID_ARGUMENT(maskIndexTensorShape[2] == maskIndexTensorShape[3]);
+ const uint32_t maxSequenceLength = maskIndexTensorShape[2];
+ uint32_t desiredMaskIndexShape[4] = {batchSize, numHeads, maxSequenceLength, maxSequenceLength};
+ maskTensorDataType = kernelCreationContext.GetInputEdgeDescription(maskIndex).tensorDataType;
+ m_inputTensorDescs[maskIndex] = TensorDesc::ConstructBroadcastedTensorDesc(maskTensorDataType, desiredMaskIndexShape, reshapedMaskIndexTensorShape);
+ }
+ else
+ {
+ uint32_t maskIndexDimensionCount = gsl::narrow_cast(maskIndexTensorShape.size());
+ reshapedMaskIndexTensorShape.insert(reshapedMaskIndexTensorShape.begin() + 1, 4 - maskIndexDimensionCount, 1);
+ uint32_t desiredMaskIndexShape[4] = {batchSize, numHeads, sequenceLength, sequenceLength};
+ maskTensorDataType = kernelCreationContext.GetInputEdgeDescription(maskIndex).tensorDataType;
+ m_inputTensorDescs[maskIndex] = TensorDesc::ConstructBroadcastedTensorDesc(maskTensorDataType, desiredMaskIndexShape, reshapedMaskIndexTensorShape);
+ }
+ }
+ }
+
+ MLOperatorTensorDataType pastTensorDataType = MLOperatorTensorDataType::Undefined;
+ MLOperatorTensorDataType presentTensorDataType = MLOperatorTensorDataType::Undefined;
+ if (hasPast)
+ {
+ pastTensorDataType = kernelCreationContext.GetInputEdgeDescription(pastIndex).tensorDataType;
+ presentTensorDataType = kernelCreationContext.GetOutputEdgeDescription(presentIndex).tensorDataType;
+ }
+
+ TensorDesc matMulIntToFloatOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, desiredBiasTensorShape);
+ DML_TENSOR_DESC namedMatMulIntToFloatOutputTensorDesc = matMulIntToFloatOutputTensorDesc.GetDmlDesc();
+
+ std::vector inputDescs = GetDmlInputDescs();
+ std::vector outputDescs = GetDmlOutputDescs();
+
+ DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC matMulIntToFloatOperatorDesc = {};
+ matMulIntToFloatOperatorDesc.ATensor = &inputDescs[InputIndex::inputIndex];
+ matMulIntToFloatOperatorDesc.AScaleTensor = &inputDescs[InputIndex::inputScaleIndex];
+ matMulIntToFloatOperatorDesc.AZeroPointTensor = &inputDescs[InputIndex::inputZeroPointIndex];
+ matMulIntToFloatOperatorDesc.BTensor = &inputDescs[InputIndex::weightsIndex];
+ matMulIntToFloatOperatorDesc.BScaleTensor = &inputDescs[InputIndex::weightScaleIndex];
+ matMulIntToFloatOperatorDesc.BZeroPointTensor = &inputDescs[InputIndex::weightZeroPointIndex];
+ matMulIntToFloatOperatorDesc.BiasTensor = hasBias ? &inputDescs[InputIndex::biasIndex] : nullptr;
+ matMulIntToFloatOperatorDesc.OutputTensor = &namedMatMulIntToFloatOutputTensorDesc;
+
+ const DML_OPERATOR_DESC matMulIntToFloatDesc = { DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT, &matMulIntToFloatOperatorDesc};
+
+ std::array queryKeySlicedTensorShape = {batchSize, sequenceLength, hiddenSize + hiddenSize};
+ TensorDesc queryKeySlicedInputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, queryKeySlicedTensorShape);
+ DML_TENSOR_DESC namedQueryKeySlicedInputTensorDesc = queryKeySlicedInputTensorDesc.GetDmlDesc();
+
+ std::array valueSlicedTensorShape = {batchSize, sequenceLength, hiddenSize};
+ TensorDesc valueSlicedInputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, valueSlicedTensorShape);
+ DML_TENSOR_DESC namedValueSlicedInputTensorDesc = valueSlicedInputTensorDesc.GetDmlDesc();
+
+ // Transpose slice QK from [batchSize, sequenceLength, 2, numHeads, headSize] to [batchSize, sequenceLength, numHeads, 2, headSize]
+ std::array queryKeyTransposedTensorShape = {batchSize, sequenceLength, numHeads, 2, headSize};
+ std::array queryKeyTransposedStrides = {
+ sequenceLength * numHeads * 2 * headSize,
+ numHeads * 2 * headSize,
+ headSize,
+ numHeads * headSize,
+ 1,
+ };
+
+ TensorDesc queryKeyTransposedInputTensorDesc = TensorDesc(
+ GetDmlDataTypeFromMlDataType(dataType),
+ queryKeyTransposedTensorShape,
+ queryKeyTransposedStrides);
+ DML_TENSOR_DESC namedQueryKeyTransposedInputTensorDesc = queryKeyTransposedInputTensorDesc.GetDmlDesc();
+
+ TensorDesc queryKeyTransposedOutputTensorDesc = TensorDesc(
+ GetDmlDataTypeFromMlDataType(dataType),
+ queryKeyTransposedTensorShape);
+ DML_TENSOR_DESC namedQueryKeyTransposedOutputTensorDesc = queryKeyTransposedOutputTensorDesc.GetDmlDesc();
+
+ // Transpose QKV from [batchSize, sequenceLength, 3, numHeads, headSize] to [batchSize, sequenceLength, numHeads, 3, headSize]
+ std::array queryKeyValueTransposedTensorShape = {batchSize, sequenceLength, numHeads, 3, headSize};
+ std::array queryKeyValueTransposedStrides = {
+ sequenceLength * numHeads * 3 * headSize,
+ numHeads * 3 * headSize,
+ headSize,
+ numHeads * headSize,
+ 1,
+ };
+
+ TensorDesc queryKeyValueTransposedInputTensorDesc = TensorDesc(
+ GetDmlDataTypeFromMlDataType(dataType),
+ queryKeyValueTransposedTensorShape,
+ queryKeyValueTransposedStrides);
+ DML_TENSOR_DESC namedQueryKeyValueTransposedInputTensorDesc = queryKeyValueTransposedInputTensorDesc.GetDmlDesc();
+
+ TensorDesc queryKeyValueTransposedOutputTensorDesc = TensorDesc(
+ GetDmlDataTypeFromMlDataType(dataType),
+ queryKeyValueTransposedTensorShape);
+ DML_TENSOR_DESC namedQueryKeyValueTransposedOutputTensorDesc = queryKeyValueTransposedOutputTensorDesc.GetDmlDesc();
+
+ std::array queryKeySliceOffset = {0, 0, 0};
+ std::array queryKeySliceSize = {batchSize, sequenceLength, hiddenSize + hiddenSize};
+ std::array queryKeySliceStrides = {1, 1, 1};
+
+ std::array valueSliceOffset = {0, 0, 2 * hiddenSize};
+ std::array valueSliceSize = {batchSize, sequenceLength, hiddenSize};
+ std::array valueSliceStrides = {1, 1, 1};
+
+ // When Q/K/V all have the same hidden size, we just have to transpose it before sending it to MHA
+ DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC transposeOperatorDesc = {};
+
+ transposeOperatorDesc.InputTensor = &namedQueryKeyValueTransposedInputTensorDesc;
+ transposeOperatorDesc.OutputTensor = &namedQueryKeyValueTransposedOutputTensorDesc;
+
+ const DML_OPERATOR_DESC transposedDesc = { DML_OPERATOR_ELEMENT_WISE_IDENTITY, &transposeOperatorDesc};
+
+ std::array maskSliceOutputShape = {batchSize, numHeads, sequenceLength, sequenceLength};
+ std::array maskSliceStrides = {1, 1, 1, 1};
+ std::array maskSliceOffsets = {0, 0, 0, 0};
+ TensorDesc maskSliceOutputTensorDesc;
+ DML_TENSOR_DESC namedMaskSliceOutputTensorDesc;
+
+ DML_SLICE1_OPERATOR_DESC maskSlicedOperatorDesc = {};
+ if (hasMaxSequenceMask)
+ {
+ maskSliceOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(maskTensorDataType, maskSliceOutputShape);
+ namedMaskSliceOutputTensorDesc = maskSliceOutputTensorDesc.GetDmlDesc();
+ maskSlicedOperatorDesc.InputTensor = &inputDescs[maskIndex];
+ maskSlicedOperatorDesc.OutputTensor = &namedMaskSliceOutputTensorDesc;
+ maskSlicedOperatorDesc.DimensionCount = gsl::narrow_cast(maskSliceOutputShape.size());
+ maskSlicedOperatorDesc.InputWindowOffsets = maskSliceOffsets.data();
+ maskSlicedOperatorDesc.InputWindowSizes = maskSliceOutputShape.data();
+ maskSlicedOperatorDesc.InputWindowStrides = maskSliceStrides.data();
+ }
+ const DML_OPERATOR_DESC maskSlicedDesc = { DML_OPERATOR_SLICE1, &maskSlicedOperatorDesc};
+
+ // We need to slice Past to get PastValue and PastKey tensors for MHA
+ std::array pastKeyOutputShape = {1, batchSize, numHeads, pastSequenceLength, headSize};
+ std::array pastKeyStrides = {1, 1, 1, 1, 1};
+ std::array pastKeyOffsets = {0, 0, 0, 0, 0};
+ TensorDesc pastKeyOutputTensorDesc;
+ DML_TENSOR_DESC namedPastKeyOutputTensorDesc;
+
+ std::array pastValueOutputShape = {1, batchSize, numHeads, pastSequenceLength, headSize};
+ std::array pastValueStrides = {1, 1, 1, 1, 1};
+ std::array pastValueOffsets = {1, 0, 0, 0, 0};
+ TensorDesc pastValueOutputTensorDesc;
+ DML_TENSOR_DESC namedPastValueOutputTensorDesc;
+
+ DML_SLICE1_OPERATOR_DESC pastKeySlicedOperatorDesc = {};
+ DML_SLICE1_OPERATOR_DESC pastValueSlicedOperatorDesc = {};
+
+ if (hasPast)
+ {
+ pastKeyOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(pastTensorDataType, pastKeyOutputShape);
+ namedPastKeyOutputTensorDesc = pastKeyOutputTensorDesc.GetDmlDesc();
+ pastKeySlicedOperatorDesc.InputTensor = &inputDescs[pastIndex];
+ pastKeySlicedOperatorDesc.OutputTensor = &namedPastKeyOutputTensorDesc;
+ pastKeySlicedOperatorDesc.DimensionCount = gsl::narrow_cast(pastKeyOutputShape.size());
+ pastKeySlicedOperatorDesc.InputWindowOffsets = pastKeyOffsets.data();
+ pastKeySlicedOperatorDesc.InputWindowSizes = pastKeyOutputShape.data();
+ pastKeySlicedOperatorDesc.InputWindowStrides = pastKeyStrides.data();
+
+ pastValueOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(pastTensorDataType, pastValueOutputShape);
+ namedPastValueOutputTensorDesc = pastValueOutputTensorDesc.GetDmlDesc();
+ pastValueSlicedOperatorDesc.InputTensor = &inputDescs[pastIndex];
+ pastValueSlicedOperatorDesc.OutputTensor = &namedPastValueOutputTensorDesc;
+ pastValueSlicedOperatorDesc.DimensionCount = gsl::narrow_cast(pastValueOutputShape.size());
+ pastValueSlicedOperatorDesc.InputWindowOffsets = pastValueOffsets.data();
+ pastValueSlicedOperatorDesc.InputWindowSizes = pastValueOutputShape.data();
+ pastValueSlicedOperatorDesc.InputWindowStrides = pastValueStrides.data();
+ }
+
+ const DML_OPERATOR_DESC pastKeySlicedDesc = { DML_OPERATOR_SLICE1, &pastKeySlicedOperatorDesc};
+ const DML_OPERATOR_DESC pastValueSlicedDesc = { DML_OPERATOR_SLICE1, &pastValueSlicedOperatorDesc};
+
+ // Causal Mask: Upper Triangular Boolean Matrix
+ // Example: [[1, 0, 0, 0, 0],
+ // [1, 1, 0, 0, 0],
+ // [1, 1, 1, 0, 0],
+ // [1, 1, 1, 1, 0]]
+ // DML adds maskFilterValue to the "off" bits in the mask and sets the "on" bits to 0
+ // passed to MHA as maskIndex Tensor when unidirectional == 1
+ std::array causalMaskOutputShape = {1, 1, sequenceLength, pastSequenceLength + sequenceLength};
+ TensorDesc causalMaskTensorDesc;
+ DML_DIAGONAL_MATRIX1_OPERATOR_DESC causalMaskOperatorDesc = {};
+ DML_TENSOR_DESC namedcausalMaskTensorDesc;
+
+ if (unidirectional && !hasMask)
+ {
+ causalMaskTensorDesc = TensorDesc::ConstructDefaultTensorDesc(MLOperatorTensorDataType::Int32, causalMaskOutputShape);
+ namedcausalMaskTensorDesc = causalMaskTensorDesc.GetDmlDesc();
+ causalMaskOperatorDesc.ValueDataType = DML_TENSOR_DATA_TYPE_INT32;
+ causalMaskOperatorDesc.DiagonalFillBegin = INT32_MIN;
+ causalMaskOperatorDesc.DiagonalFillEnd = pastSequenceLength + 1;
+ causalMaskOperatorDesc.Value.Int32 = 1;
+ causalMaskOperatorDesc.OutputTensor = &namedcausalMaskTensorDesc;
+ maskType = DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN;
+ }
+ DML_OPERATOR_DESC causalMaskDesc = { DML_OPERATOR_DIAGONAL_MATRIX1, &causalMaskOperatorDesc };
+
+ DML_MULTIHEAD_ATTENTION_OPERATOR_DESC mhaOperatorDesc = {};
+ std::array presentKeyOutputShape = {1, batchSize, numHeads, pastSequenceLength + sequenceLength, headSize};
+ std::array presentValueOutputShape = {1, batchSize, numHeads, pastSequenceLength + sequenceLength, headSize};
+ TensorDesc presentKeyTensorDesc;
+ TensorDesc presentValueTensorDesc;
+ DML_TENSOR_DESC namedPresentKeyOutputTensorDesc;
+ DML_TENSOR_DESC namedPresentValueOutputTensorDesc;
+
+ mhaOperatorDesc.StackedQueryKeyValueTensor = &namedQueryKeyValueTransposedOutputTensorDesc;
+
+ // Broadcast to MHA MaskTensor Shape
+ std::array mhaMaskTensorShape = {batchSize, numHeads, sequenceLength, pastSequenceLength + sequenceLength};
+ TensorDesc broadcastedcausalMaskTensorDesc;
+ DML_TENSOR_DESC namedbroadcastedcausalMaskTensorDesc;
+ if (unidirectional && !hasMask)
+ {
+ broadcastedcausalMaskTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(MLOperatorTensorDataType::Int32, mhaMaskTensorShape, causalMaskOutputShape);
+ namedbroadcastedcausalMaskTensorDesc = broadcastedcausalMaskTensorDesc.GetDmlDesc();
+ mhaOperatorDesc.MaskTensor = &namedbroadcastedcausalMaskTensorDesc;
+ }
+ else if (hasMaxSequenceMask)
+ {
+ mhaOperatorDesc.MaskTensor = &namedMaskSliceOutputTensorDesc;
+ }
+ else
+ {
+ mhaOperatorDesc.MaskTensor = hasMask ? &inputDescs[maskIndex] : nullptr;
+ }
+
+ mhaOperatorDesc.RelativePositionBiasTensor = nullptr;
+ mhaOperatorDesc.OutputTensor = &outputDescs[outputIndex];
+ mhaOperatorDesc.Scale = kernelCreationContext.GetOptionalAttribute(AttrName::Scale, gsl::narrow_cast(1.0f / std::sqrt(headSize)));
+ // Set MaskFilterValue to lowest float for Causal Mask
+ mhaOperatorDesc.MaskFilterValue = unidirectional ? std::numeric_limits::lowest() :
+ kernelCreationContext.GetOptionalAttribute(AttrName::MaskFilterValue, -10'000.0f);
+ mhaOperatorDesc.HeadCount = numHeads;
+ mhaOperatorDesc.MaskType = maskType;
+ if (hasPast)
+ {
+ presentKeyTensorDesc = TensorDesc::ConstructDefaultTensorDesc(presentTensorDataType, presentKeyOutputShape);
+ namedPresentKeyOutputTensorDesc = presentKeyTensorDesc.GetDmlDesc();
+ presentValueTensorDesc = TensorDesc::ConstructDefaultTensorDesc(presentTensorDataType, presentValueOutputShape);
+ namedPresentValueOutputTensorDesc = presentValueTensorDesc.GetDmlDesc();
+ mhaOperatorDesc.PastKeyTensor = &namedPastKeyOutputTensorDesc;
+ mhaOperatorDesc.PastValueTensor = &namedPastValueOutputTensorDesc;
+ mhaOperatorDesc.OutputPresentKeyTensor = &namedPresentKeyOutputTensorDesc;
+ mhaOperatorDesc.OutputPresentValueTensor = &namedPresentValueOutputTensorDesc;
+ }
+
+ const DML_OPERATOR_DESC mhaDesc = { DML_OPERATOR_MULTIHEAD_ATTENTION, &mhaOperatorDesc };
+
+ DML_JOIN_OPERATOR_DESC presentKeyValueJoinOperatorDesc = {};
+ std::vector joinInputDesc;
+
+ if (hasPast)
+ {
+ joinInputDesc.push_back(namedPresentKeyOutputTensorDesc);
+ joinInputDesc.push_back(namedPresentValueOutputTensorDesc);
+ presentKeyValueJoinOperatorDesc.InputCount = gsl::narrow_cast(joinInputDesc.size());
+ presentKeyValueJoinOperatorDesc.InputTensors = joinInputDesc.data();
+ presentKeyValueJoinOperatorDesc.OutputTensor = &outputDescs[presentIndex];
+ presentKeyValueJoinOperatorDesc.Axis = gsl::narrow_cast(0);
+ }
+
+ DML_OPERATOR_DESC presentKeyValueJoinDesc = { DML_OPERATOR_JOIN, &presentKeyValueJoinOperatorDesc };
+
+ // Construct the graph
+ std::vector inputEdges;
+ std::vector intermediateEdges;
+ std::vector outputEdges;
+
+ std::vector opDescs = {
+ &matMulIntToFloatDesc,
+ &mhaDesc,
+ };
+
+ uint32_t currentNodeIndex = 0;
+ const uint32_t matMulIntToFloatNodeIndex = currentNodeIndex++;
+ const uint32_t mhaNodeIndex = currentNodeIndex++;
+
+ uint32_t queryKeyValueTransposedNodeIndex = 0;
+
+ opDescs.push_back(&transposedDesc);
+ queryKeyValueTransposedNodeIndex = currentNodeIndex++;
+
+ uint32_t maskSliceNodeIndex = 0;
+ if (hasMaxSequenceMask)
+ {
+ opDescs.push_back(&maskSlicedDesc);
+ maskSliceNodeIndex = currentNodeIndex++;
+ }
+
+ uint32_t pastKeySliceNodeIndex = 0;
+ uint32_t pastValueSliceNodeIndex = 0;
+ uint32_t concatNodeIndex = 0;
+ if (hasPast)
+ {
+ opDescs.push_back(&pastKeySlicedDesc);
+ pastKeySliceNodeIndex = currentNodeIndex++;
+ opDescs.push_back(&pastValueSlicedDesc);
+ pastValueSliceNodeIndex = currentNodeIndex++;
+ opDescs.push_back(&presentKeyValueJoinDesc);
+ concatNodeIndex = currentNodeIndex++;
+ }
+
+ uint32_t causalMaskNodeIndex = 0;
+ if (unidirectional && !hasMask)
+ {
+ opDescs.push_back(&causalMaskDesc);
+ causalMaskNodeIndex = currentNodeIndex++;
+ }
+
+ DML_INPUT_GRAPH_EDGE_DESC inputToMatMulIntToFloatEdge = {};
+ inputToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::inputIndex;
+ inputToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex;
+ inputToMatMulIntToFloatEdge.ToNodeInputIndex = 0;
+ inputEdges.push_back(inputToMatMulIntToFloatEdge);
+
+ DML_INPUT_GRAPH_EDGE_DESC inputScaleToMatMulIntToFloatEdge = {};
+ inputScaleToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::inputScaleIndex;
+ inputScaleToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex;
+ inputScaleToMatMulIntToFloatEdge.ToNodeInputIndex = 1;
+ inputEdges.push_back(inputScaleToMatMulIntToFloatEdge);
+
+ DML_INPUT_GRAPH_EDGE_DESC inputZeroPointToMatMulIntToFloatEdge = {};
+ inputZeroPointToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::inputZeroPointIndex;
+ inputZeroPointToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex;
+ inputZeroPointToMatMulIntToFloatEdge.ToNodeInputIndex = 2;
+ inputEdges.push_back(inputZeroPointToMatMulIntToFloatEdge);
+
+ DML_INPUT_GRAPH_EDGE_DESC weightToMatMulIntToFloatEdge = {};
+ weightToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::weightsIndex;
+ weightToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex;
+ weightToMatMulIntToFloatEdge.ToNodeInputIndex = 3;
+ inputEdges.push_back(weightToMatMulIntToFloatEdge);
+
+ DML_INPUT_GRAPH_EDGE_DESC weightScaleToMatMulIntToFloatEdge = {};
+ weightScaleToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::weightScaleIndex;
+ weightScaleToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex;
+ weightScaleToMatMulIntToFloatEdge.ToNodeInputIndex = 4;
+ inputEdges.push_back(weightScaleToMatMulIntToFloatEdge);
+
+ DML_INPUT_GRAPH_EDGE_DESC weightZeroPointToMatMulIntToFloatEdge = {};
+ weightZeroPointToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::weightZeroPointIndex;
+ weightZeroPointToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex;
+ weightZeroPointToMatMulIntToFloatEdge.ToNodeInputIndex = 5;
+ inputEdges.push_back(weightZeroPointToMatMulIntToFloatEdge);
+
+ if (hasBias)
+ {
+ DML_INPUT_GRAPH_EDGE_DESC biasToMatMulIntToFloatEdge = {};
+ biasToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::biasIndex;
+ biasToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex;
+ biasToMatMulIntToFloatEdge.ToNodeInputIndex = 6;
+ inputEdges.push_back(biasToMatMulIntToFloatEdge);
+ }
+
+ if (hasMask)
+ {
+ if (hasUnpaddedBounds)
+ {
+ DML_INPUT_GRAPH_EDGE_DESC maskToMhaEdge = {};
+ maskToMhaEdge.GraphInputIndex = InputIndex::maskIndex;
+ maskToMhaEdge.ToNodeIndex = mhaNodeIndex;
+ maskToMhaEdge.ToNodeInputIndex = mhaMaskIndex;
+ inputEdges.push_back(maskToMhaEdge);
+ }
+ else if (hasMaxSequenceMask)
+ {
+ DML_INPUT_GRAPH_EDGE_DESC maskToMaskSliceEdge = {};
+ maskToMaskSliceEdge.GraphInputIndex = InputIndex::maskIndex;
+ maskToMaskSliceEdge.ToNodeIndex = maskSliceNodeIndex;
+ maskToMaskSliceEdge.ToNodeInputIndex = 0;
+ inputEdges.push_back(maskToMaskSliceEdge);
+
+ DML_INTERMEDIATE_GRAPH_EDGE_DESC maskSliceToMhaEdge = {};
+ maskSliceToMhaEdge.FromNodeIndex = maskSliceNodeIndex;
+ maskSliceToMhaEdge.FromNodeOutputIndex = 0;
+ maskSliceToMhaEdge.ToNodeIndex = mhaNodeIndex;
+ maskSliceToMhaEdge.ToNodeInputIndex = mhaMaskIndex;
+ intermediateEdges.push_back(maskSliceToMhaEdge);
+ }
+ else
+ {
+ DML_INPUT_GRAPH_EDGE_DESC maskToMhaEdge = {};
+ maskToMhaEdge.GraphInputIndex = InputIndex::maskIndex;
+ maskToMhaEdge.ToNodeIndex = mhaNodeIndex;
+ maskToMhaEdge.ToNodeInputIndex = mhaMaskIndex;
+ inputEdges.push_back(maskToMhaEdge);
+ }
+ }
+ else if (unidirectional)
+ {
+ DML_INTERMEDIATE_GRAPH_EDGE_DESC causalMaskToMhaEdge = {};
+ causalMaskToMhaEdge.FromNodeIndex = causalMaskNodeIndex;
+ causalMaskToMhaEdge.FromNodeOutputIndex = 0;
+ causalMaskToMhaEdge.ToNodeIndex = mhaNodeIndex;
+ causalMaskToMhaEdge.ToNodeInputIndex = mhaMaskIndex;
+ intermediateEdges.push_back(causalMaskToMhaEdge);
+ }
+
+ if (hasPast)
+ {
+ DML_INPUT_GRAPH_EDGE_DESC pastToPastKeySliceEdge = {};
+ pastToPastKeySliceEdge.GraphInputIndex = InputIndex::pastIndex;
+ pastToPastKeySliceEdge.ToNodeIndex = pastKeySliceNodeIndex;
+ pastToPastKeySliceEdge.ToNodeInputIndex = 0;
+ inputEdges.push_back(pastToPastKeySliceEdge);
+
+ DML_INPUT_GRAPH_EDGE_DESC pastToPastValueSliceEdge = {};
+ pastToPastValueSliceEdge.GraphInputIndex = InputIndex::pastIndex;
+ pastToPastValueSliceEdge.ToNodeIndex = pastValueSliceNodeIndex;
+ pastToPastValueSliceEdge.ToNodeInputIndex = 0;
+ inputEdges.push_back(pastToPastValueSliceEdge);
+
+ DML_INTERMEDIATE_GRAPH_EDGE_DESC pastKeyToMhaEdge = {};
+ pastKeyToMhaEdge.FromNodeIndex = pastKeySliceNodeIndex;
+ pastKeyToMhaEdge.FromNodeOutputIndex = 0;
+ pastKeyToMhaEdge.ToNodeIndex = mhaNodeIndex;
+ pastKeyToMhaEdge.ToNodeInputIndex = mhaPastKeyIndex;
+ intermediateEdges.push_back(pastKeyToMhaEdge);
+
+ DML_INTERMEDIATE_GRAPH_EDGE_DESC pastValueToMhaEdge = {};
+ pastValueToMhaEdge.FromNodeIndex = pastValueSliceNodeIndex;
+ pastValueToMhaEdge.FromNodeOutputIndex = 0;
+ pastValueToMhaEdge.ToNodeIndex = mhaNodeIndex;
+ pastValueToMhaEdge.ToNodeInputIndex = mhaPastValueIndex;
+ intermediateEdges.push_back(pastValueToMhaEdge);
+
+ DML_INTERMEDIATE_GRAPH_EDGE_DESC presentKeyToConcatEdge = {};
+ presentKeyToConcatEdge.FromNodeIndex = mhaNodeIndex;
+ presentKeyToConcatEdge.FromNodeOutputIndex = mhaPresentKeyIndex;
+ presentKeyToConcatEdge.ToNodeIndex = concatNodeIndex;
+ presentKeyToConcatEdge.ToNodeInputIndex = 0;
+ intermediateEdges.push_back(presentKeyToConcatEdge);
+
+ DML_INTERMEDIATE_GRAPH_EDGE_DESC presentValueToConcatEdge = {};
+ presentValueToConcatEdge.FromNodeIndex = mhaNodeIndex;
+ presentValueToConcatEdge.FromNodeOutputIndex = mhaPresentValueIndex;
+ presentValueToConcatEdge.ToNodeIndex = concatNodeIndex;
+ presentValueToConcatEdge.ToNodeInputIndex = 1;
+ intermediateEdges.push_back(presentValueToConcatEdge);
+ }
+
+ DML_INTERMEDIATE_GRAPH_EDGE_DESC matMulIntToFloatToQueryKeyValueTransposeEdge = {};
+ matMulIntToFloatToQueryKeyValueTransposeEdge.FromNodeIndex = matMulIntToFloatNodeIndex;
+ matMulIntToFloatToQueryKeyValueTransposeEdge.FromNodeOutputIndex = 0;
+ matMulIntToFloatToQueryKeyValueTransposeEdge.ToNodeIndex = queryKeyValueTransposedNodeIndex;
+ matMulIntToFloatToQueryKeyValueTransposeEdge.ToNodeInputIndex = 0;
+ intermediateEdges.push_back(matMulIntToFloatToQueryKeyValueTransposeEdge);
+
+ // All we need to do here is transpose the stacked QKV tensor into something DML supports
+ DML_INTERMEDIATE_GRAPH_EDGE_DESC queryKeyValueTransposedToMhaEdge = {};
+ queryKeyValueTransposedToMhaEdge.FromNodeIndex = queryKeyValueTransposedNodeIndex;
+ queryKeyValueTransposedToMhaEdge.FromNodeOutputIndex = 0;
+ queryKeyValueTransposedToMhaEdge.ToNodeIndex = mhaNodeIndex;
+ queryKeyValueTransposedToMhaEdge.ToNodeInputIndex = mhaStackedQueryKeyValueIndex;
+ intermediateEdges.push_back(queryKeyValueTransposedToMhaEdge);
+
+ DML_OUTPUT_GRAPH_EDGE_DESC mhaToOutputEdge = {};
+ mhaToOutputEdge.FromNodeIndex = mhaNodeIndex;
+ mhaToOutputEdge.FromNodeOutputIndex = mhaOutputIndex;
+ mhaToOutputEdge.GraphOutputIndex = OutputIndex::outputIndex;
+ outputEdges.push_back(mhaToOutputEdge);
+
+ if (hasPast)
+ {
+ DML_OUTPUT_GRAPH_EDGE_DESC concatToOutputEdge = {};
+ concatToOutputEdge.FromNodeIndex = concatNodeIndex;
+ concatToOutputEdge.FromNodeOutputIndex = 0;
+ concatToOutputEdge.GraphOutputIndex = OutputIndex::presentIndex;
+ outputEdges.push_back(concatToOutputEdge);
+ }
+
+ MLOperatorGraphDesc operatorGraphDesc = {};
+ operatorGraphDesc.inputEdgeCount = gsl::narrow_cast(inputEdges.size());
+ operatorGraphDesc.inputEdges = inputEdges.data();
+ operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast(intermediateEdges.size());
+ operatorGraphDesc.intermediateEdges = intermediateEdges.data();
+ operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size());
+ operatorGraphDesc.outputEdges = outputEdges.data();
+ operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size());
+ operatorGraphDesc.nodes = opDescs.data();
+
+ SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext);
+ }
+};
+
+void CALLBACK QueryQAttention(IMLOperatorSupportQueryContextPrivate* context, /*out*/ bool* isSupported)
+{
+ *isSupported = false;
+
+ // `unidirectional == 1` with Mask Tensor is not supported yet
+ MLOperatorAttributes attributes(context);
+ if (attributes.GetOptionalAttribute(AttrName::Unidirectional, 0) != 0 && context->IsInputValid(5))
+ {
+ return;
+ }
+
+ // `do_rotary == 1` is not supported yet
+ if (attributes.GetOptionalAttribute(AttrName::DoRotary, 0) != 0)
+ {
+ return;
+ }
+
+ // `past_present_share_buffer == 1` is not supported yet
+ if (attributes.GetOptionalAttribute(AttrName::PastPresentShareBuffer, 0) != 0)
+ {
+ return;
+ }
+
+ *isSupported = true;
+}
+
+DML_OP_DEFINE_CREATION_FUNCTION(QAttention, DmlOperatorQAttention);
+} // namespace Dml
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp
index f658e7c7da..bc0082fef3 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp
@@ -178,4 +178,4 @@ void CALLBACK QueryQLinearSigmoid(IMLOperatorSupportQueryContextPrivate* context
}
DML_OP_DEFINE_CREATION_FUNCTION(QLinearSigmoid, DmlOperatorQLinearSigmoid);
-} // namespace Dml
+} // namespace Dml
\ No newline at end of file
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
index 38cf80b381..71fc8741bf 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
@@ -516,6 +516,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(Resize19);
DML_OP_EXTERN_CREATION_FUNCTION(Shape);
DML_OP_EXTERN_CREATION_FUNCTION(Size);
+DML_OP_EXTERN_CREATION_FUNCTION(QAttention);
DML_OP_EXTERN_CREATION_FUNCTION(Attention);
DML_OP_EXTERN_CREATION_FUNCTION(MultiHeadAttention);
DML_OP_EXTERN_CREATION_FUNCTION(NonZero);
@@ -537,6 +538,7 @@ DML_OP_EXTERN_QUERY_FUNCTION(Pad);
DML_OP_EXTERN_QUERY_FUNCTION(LayerNormalization);
DML_OP_EXTERN_QUERY_FUNCTION(SkipLayerNormalization);
DML_OP_EXTERN_QUERY_FUNCTION(QLinearSigmoid);
+DML_OP_EXTERN_QUERY_FUNCTION(QAttention);
DML_OP_EXTERN_QUERY_FUNCTION(Attention);
constexpr static std::array typeNameListDefault = {"T"};
@@ -614,15 +616,23 @@ constexpr static std::array supportedTypeListLayerN
constexpr static std::array supportedTypeListShape = {SupportedTensorDataTypes::All, SupportedTensorDataTypes::Int64};
constexpr static std::array supportedTypeListSize = {SupportedTensorDataTypes::All, SupportedTensorDataTypes::Int64};
constexpr static std::array supportedTypeListQLinearSigmoid = {SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8};
+
+constexpr static std::array supportedTypeListQAttention = {
+ SupportedTensorDataTypes::Ints8Bit,
+ SupportedTensorDataTypes::Ints8Bit,
+ SupportedTensorDataTypes::Float16to32,
+ SupportedTensorDataTypes::Int32
+};
+
constexpr static std::array supportedTypeListAttention = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int32};
constexpr static std::array supportedTypeListRotaryEmbedding = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int64};
constexpr static std::array supportedTypeListGroupNorm = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Float16to32};
constexpr static std::array supportedTypeListNonZero = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit | SupportedTensorDataTypes::Ints16Bit | SupportedTensorDataTypes::Ints32Bit | SupportedTensorDataTypes::Bool};
constexpr static std::array supportedTypeListQLinearMatMul = {
- SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
- SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
- SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8
+ SupportedTensorDataTypes::Ints8Bit,
+ SupportedTensorDataTypes::Ints8Bit,
+ SupportedTensorDataTypes::Ints8Bit
};
constexpr static std::array supportedTypeListMatMulIntegerToFloat = {
@@ -632,9 +642,9 @@ constexpr static std::array supportedTypeListMatMul
};
constexpr static std::array supportedTypeListQLinearConv = {
- SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
- SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
- SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
+ SupportedTensorDataTypes::Ints8Bit,
+ SupportedTensorDataTypes::Ints8Bit,
+ SupportedTensorDataTypes::Ints8Bit,
SupportedTensorDataTypes::Int32
};
@@ -1069,6 +1079,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO_MS( 1, DynamicQuantizeMatMul, typeNameListTwo, supportedTypeListDynamicQuantizeLinear, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, FusedMatMulActivation, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, QLinearSigmoid, typeNameListDefault, supportedTypeListQLinearSigmoid, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryQLinearSigmoid)},
+ {REG_INFO_MS( 1, QAttention, typeNameListFour, supportedTypeListQAttention, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryQAttention)},
{REG_INFO_MS( 1, Attention, typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryAttention)},
{REG_INFO_MS( 1, MultiHeadAttention, typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, RotaryEmbedding, typeNameListRotaryEmbedding, supportedTypeListRotaryEmbedding, DmlGraphSupport::Supported)},
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp
index 317f5ebcbc..acda1a516b 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp
@@ -2802,6 +2802,48 @@ namespace OperatorHelper
m_qkvHiddenSizes = kernelInformation.GetAttributes().GetOptionalAttributeVectorInt32(AttrName::QkvHiddenSizes);
}
+ std::vector QAttentionHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
+ {
+ ML_CHECK_VALID_ARGUMENT(shapeInfo.GetInputCount() >= 5);
+
+ auto queryShape = shapeInfo.GetInputTensorShape(0);
+ ML_CHECK_VALID_ARGUMENT(queryShape.size() == 3);
+
+ auto weightShape = shapeInfo.GetInputTensorShape(1);
+ ML_CHECK_VALID_ARGUMENT(weightShape.size() == 2);
+ ML_CHECK_VALID_ARGUMENT(weightShape[1] % 3 == 0);
+
+ const uint32_t batchSize = queryShape[0];
+ const uint32_t sequenceLength = queryShape[1];
+ const uint32_t hiddenSize = weightShape[1] / 3;
+ const uint32_t headSize = hiddenSize / m_numHeads;
+
+ std::vector outputShapes(2);
+
+ outputShapes[0] = EdgeShapes({batchSize, sequenceLength, hiddenSize});
+
+ uint32_t totalSequenceLength = sequenceLength;
+ if (shapeInfo.IsInputValid(8))
+ {
+ ML_CHECK_VALID_ARGUMENT(shapeInfo.GetInputTensorDimensionCount(8) == 5);
+ const uint32_t pastSequenceLength = shapeInfo.GetInputTensorShape(8)[3];
+ totalSequenceLength += pastSequenceLength;
+ }
+
+ if (shapeInfo.IsOutputValid(1))
+ {
+ ML_CHECK_VALID_ARGUMENT(shapeInfo.IsInputValid(8));
+ outputShapes[1] = EdgeShapes({2, batchSize, m_numHeads, totalSequenceLength, headSize});
+ }
+
+ return outputShapes;
+ }
+
+ void QAttentionHelper::Initialize(const IKernelInformationAdapter& kernelInformation)
+ {
+ m_numHeads = gsl::narrow_cast(kernelInformation.GetAttributes().GetAttribute(AttrName::NumHeads));
+ }
+
std::vector SkipLayerNormHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
{
ML_CHECK_VALID_ARGUMENT(shapeInfo.GetInputCount() >= 3);
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
index 1f5daed6ea..aff31bb305 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
@@ -1554,6 +1554,22 @@ private:
std::vector m_qkvHiddenSizes;
};
+class QAttentionHelper
+{
+public:
+ template
+ QAttentionHelper(const Info_t& info, const Shape_t& shapeInfo)
+ {
+ Initialize(KernelInformationAdapter(info));
+ }
+
+ std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
+
+private:
+ void Initialize(const IKernelInformationAdapter& kernelInformation);
+ uint32_t m_numHeads;
+};
+
class SkipLayerNormHelper
{
public:
@@ -1699,6 +1715,7 @@ using ShapeInferenceHelper_Affine = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_QuantizeLinear = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_DequantizeLinear = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_QLinearSigmoid = GetOutputShapeAsInputShapeHelper;
+using ShapeInferenceHelper_QAttention = QAttentionHelper;
using ShapeInferenceHelper_Attention = AttentionHelper;
using ShapeInferenceHelper_MultiHeadAttention = MultiHeadAttentionHelper;
using ShapeInferenceHelper_RotaryEmbedding = GetOutputShapeAsInputShapeHelper;
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
index 8de43f2705..7492b72942 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
@@ -448,6 +448,7 @@ namespace OperatorHelper
static const int sc_sinceVer_FusedMatMul = 1;
static const int sc_sinceVer_FusedMatMulActivation = 1;
static const int sc_sinceVer_QLinearSigmoid = 1;
+ static const int sc_sinceVer_QAttention = 1;
static const int sc_sinceVer_Attention = 1;
static const int sc_sinceVer_MatMulIntegerToFloat = 1;
static const int sc_sinceVer_MultiHeadAttention = 1;
diff --git a/onnxruntime/test/contrib_ops/quantize_attention_op_test.cc b/onnxruntime/test/contrib_ops/quantize_attention_op_test.cc
index 3af334696a..fd222583ac 100644
--- a/onnxruntime/test/contrib_ops/quantize_attention_op_test.cc
+++ b/onnxruntime/test/contrib_ops/quantize_attention_op_test.cc
@@ -20,7 +20,8 @@ namespace test {
enum class EP : char {
CPU,
CUDA,
- DNNL
+ DNNL,
+ DML
};
// input: [batch_size, sequence_length, hidden_size]
@@ -111,7 +112,9 @@ void RunQAttention(const std::vector& input_data,
execution_providers.push_back(DefaultCudaExecutionProvider());
} else if constexpr (ep == EP::CPU) {
execution_providers.push_back(DefaultCpuExecutionProvider());
- } else { // onednn ep
+ } else if constexpr (ep == EP::DML) {
+ execution_providers.push_back(DefaultDmlExecutionProvider());
+ } else { // onednn ep
execution_providers.push_back(DefaultDnnlExecutionProvider());
}
@@ -192,6 +195,52 @@ static void RunQAttentionDNNL(
#endif
}
+static void RunQAttentionDML(
+ const std::vector& input_data,
+ const std::vector& weights_data,
+ const std::vector& bias_data,
+ const std::vector& mask_index_data,
+ const std::vector& output_data,
+ int batch_size,
+ int sequence_length,
+ int hidden_size,
+ int number_of_heads,
+ bool use_special_quantize_parameter = true,
+ bool is_unidirectional = false,
+ int input_hidden_size = 0) {
+ // Return without running code if USE_DML is not defined
+#ifdef USE_DML
+ bool enable_dml = (nullptr != DefaultDmlExecutionProvider().get());
+ if (enable_dml) {
+ quantization::Params input_quant_params(/*scale=*/0.0f, /*zero_point=*/0);
+ quantization::Params weights_quant_params(/*scale=*/0.0f, /*zero_point=*/0);
+ if (use_special_quantize_parameter) {
+ input_quant_params.scale = 0.1f;
+ weights_quant_params.scale = 0.1f;
+ input_quant_params.zero_point = 128;
+ weights_quant_params.zero_point = 1;
+ }
+
+ RunQAttention(
+ input_data, weights_data, bias_data, mask_index_data, output_data, input_quant_params, weights_quant_params,
+ batch_size, sequence_length, hidden_size, number_of_heads, is_unidirectional, false, input_hidden_size);
+ }
+#else
+ ORT_UNUSED_PARAMETER(input_data);
+ ORT_UNUSED_PARAMETER(weights_data);
+ ORT_UNUSED_PARAMETER(bias_data);
+ ORT_UNUSED_PARAMETER(mask_index_data);
+ ORT_UNUSED_PARAMETER(output_data);
+ ORT_UNUSED_PARAMETER(batch_size);
+ ORT_UNUSED_PARAMETER(sequence_length);
+ ORT_UNUSED_PARAMETER(hidden_size);
+ ORT_UNUSED_PARAMETER(number_of_heads);
+ ORT_UNUSED_PARAMETER(use_special_quantize_parameter);
+ ORT_UNUSED_PARAMETER(is_unidirectional);
+ ORT_UNUSED_PARAMETER(input_hidden_size);
+#endif
+}
+
static void RunQAttentionU8U8(
const std::vector& input_data,
const std::vector& weights_data,
@@ -272,6 +321,9 @@ static void RunQAttentionAll(
RunQAttentionDNNL(input_data, weight_data, bias_data, mask_index_data, output_data,
batch_size, sequence_length, hidden_size, number_of_heads,
use_special_quantize_parameter, is_unidirectional, input_hidden_size);
+ RunQAttentionDML(input_data, weight_data, bias_data, mask_index_data, output_data,
+ batch_size, sequence_length, hidden_size, number_of_heads,
+ use_special_quantize_parameter, is_unidirectional, input_hidden_size);
}
// ONEDNN EP only supports 2D raw mask
@@ -859,8 +911,8 @@ void TestQuantizedAttentionPastState(int64_t batch,
std::vector input_dims{batch, seq_len, hidden_size};
std::vector input_data = random.Gaussian(input_dims, input_mean, static_cast(input_range / 6), input_min, input_max);
- constexpr WeightT weight_min = std::numeric_limits::min();
- constexpr WeightT weight_max = std::numeric_limits::max();
+ constexpr WeightT weight_min = std::is_same_v ? std::numeric_limits::min() / 2 : std::numeric_limits::min();
+ constexpr WeightT weight_max = std::numeric_limits::max() / 2;
constexpr int32_t weight_range = weight_max - weight_min;
std::vector weight_zero_point(weight_scale_zp_size);