[DML] QAttention (#19766)

### Description
DML Implementation for
[com.microsoft.QAttention](https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QAttention)



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

---------

Co-authored-by: Xiang Zhang <xianz@microsoft.com>
This commit is contained in:
raoanag 2024-03-11 09:44:34 -08:00 committed by GitHub
parent 5479124834
commit 89aa4697b1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 839 additions and 11 deletions

View file

@ -1277,6 +1277,7 @@ Do not modify directly.*
|MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float), tensor(float16)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
|NhwcConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|QAttention|*in* input:**T1**<br> *in* weight:**T2**<br> *in* bias:**T3**<br> *in* input_scale:**T3**<br> *in* weight_scale:**T3**<br> *in* mask_index:**T4**<br> *in* input_zero_point:**T1**<br> *in* weight_zero_point:**T2**<br> *in* past:**T3**<br> *out* output:**T3**<br> *out* present:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float), tensor(float16)<br/> **T4** = tensor(int32)|
|QLinearAdd|*in* A:**T**<br> *in* A_scale:**tensor(float)**<br> *in* A_zero_point:**T**<br> *in* B:**T**<br> *in* B_scale:**tensor(float)**<br> *in* B_zero_point:**T**<br> *in* C_scale:**tensor(float)**<br> *in* C_zero_point:**T**<br> *out* C:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearAveragePool|*in* X:**T**<br> *in* x_scale:**tensor(float)**<br> *in* x_zero_point:**T**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T**<br> *out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearConcat|*in* Y_scale:**TF**<br> *in* Y_zero_point:**T8**<br> *in* inputs:**TV**<br> *out* Y:**T8**|1+|**T8** = tensor(int8), tensor(uint8)<br/> **TF** = tensor(float)<br/> **TV** = tensor(float), tensor(int8), tensor(uint8)|

View file

@ -0,0 +1,704 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
/*
Abbreviations: B is batch_size, S is sequence_length, W is hidden_size
               N is number of attention heads, H is head size, and W=N*H

Input, Weight, Bias, Mask Index and Past are Inputs

Mask Index/Causal  Input       Weight   Bias
         |             \         |      /
         |              \        |     /
         |               \       |    /
         |               MatMulIntToFloat
         |                /      |     \
         |               /       |      \
         |              /        |       \
         |          Slice      Slice     Slice
         |            |          |         |
         |            |          |         |
         |        Identity   Identity  Identity   // The identities are used to transpose NCHW -> NHCW while
         |            |          |         |      // keeping the GEMM strides as NCHW to better target metacommands
         |            |          |         |
         |            |          |         |        Past
         |            |          |         |        /  \
         |            |          |         |       /    \
         |            |          |         |    Slice   Slice
         |            |          |         |      |       |
         |            |          |         |      |       |
         |            |          |         |      |       |
         ----------------------------MHA-------------------
                                   /  |  \
                                  /   |   \
                                 /    |    \
                                /     |     \
                               /      |      \
                              /       |       \
                             /   presentKey  presentValue
                            /          \       /
                           /            \     /
                          /              \   /
                         /               Concat
                        /                  |
                    Output1          Output2 (present)

This kernel creates a DML_GRAPH, as mentioned above.
Note: in the implementation below, the three Slice/Identity pairs are realized as a single
Identity node that transposes the stacked QKV output of MatMulIntToFloat, which is then
passed to MHA as its stacked query/key/value input.
For reference, refer to this Doc:
https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftqattention
*/
namespace Dml
{
class DmlOperatorQAttention : public DmlOperator
{
public:
DmlOperatorQAttention(const MLOperatorKernelCreationContext& kernelCreationContext)
: DmlOperator(kernelCreationContext)
{
enum InputIndex : uint32_t
{
inputIndex,
weightsIndex,
biasIndex,
inputScaleIndex,
weightScaleIndex,
maskIndex,
inputZeroPointIndex,
weightZeroPointIndex,
pastIndex,
inputCount,
};
enum OutputIndex : uint32_t
{
outputIndex,
presentIndex,
outputCount,
};
enum MhaInputIndex : uint32_t
{
mhaQueryIndex,
mhaKeyIndex,
mhaValueIndex,
mhaStackedQueryKeyIndex,
mhaStackedKeyValueIndex,
mhaStackedQueryKeyValueIndex,
mhaBiasIndex,
mhaMaskIndex,
mhaRelativePositionBiasIndex,
mhaPastKeyIndex,
mhaPastValueIndex,
mhaInputCount,
};
enum MhaOutputIndex : uint32_t
{
mhaOutputIndex,
mhaPresentKeyIndex,
mhaPresentValueIndex,
mhaOutputCount,
};
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() >= 5);
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() >= 1);
const bool hasBias = kernelCreationContext.IsInputValid(biasIndex);
const bool hasMask = kernelCreationContext.IsInputValid(maskIndex);
const bool hasUnpaddedBounds = hasMask && kernelCreationContext.GetInputTensorDimensionCount(maskIndex) == 1;
const bool hasPast = kernelCreationContext.IsInputValid(pastIndex);
DmlOperator::Initialize(kernelCreationContext, std::nullopt, std::nullopt, std::nullopt, std::nullopt, 1);
const bool unidirectional = gsl::narrow_cast<uint32_t>(kernelCreationContext.GetAttribute<int64_t>(AttrName::Unidirectional));
const uint32_t numHeads = gsl::narrow_cast<uint32_t>(kernelCreationContext.GetAttribute<int64_t>(AttrName::NumHeads));
ML_CHECK_VALID_ARGUMENT(numHeads > 0); // to avoid process crash because of division by zero.
auto inputTensorShape = m_inputTensorDescs[inputIndex].GetSizes();
ML_CHECK_VALID_ARGUMENT(inputTensorShape.size() == 3);
auto weightTensorShape = m_inputTensorDescs[weightsIndex].GetSizes();
ML_CHECK_VALID_ARGUMENT(weightTensorShape.size() == 2);
ML_CHECK_VALID_ARGUMENT(weightTensorShape[0] == inputTensorShape[2]);
ML_CHECK_VALID_ARGUMENT(weightTensorShape[1] % 3 == 0);
if (hasBias)
{
auto biasTensorShape = m_inputTensorDescs[biasIndex].GetSizes();
ML_CHECK_VALID_ARGUMENT(biasTensorShape.size() == 1);
ML_CHECK_VALID_ARGUMENT(biasTensorShape[0] % 3 == 0);
ML_CHECK_VALID_ARGUMENT(weightTensorShape[1] == biasTensorShape[0]);
}
if (hasPast)
{
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.IsOutputValid(presentIndex));
}
const uint32_t hiddenSize = weightTensorShape[1] / 3;
const uint32_t headSize = hiddenSize / numHeads;
const uint32_t batchSize = inputTensorShape[0];
const uint32_t sequenceLength = inputTensorShape[1];
const uint32_t pastSequenceLength = hasPast ? m_inputTensorDescs[pastIndex].GetSizes()[3] : 0;
uint32_t desiredWeightTensorShape[3] = {batchSize, weightTensorShape[0], 3 * hiddenSize};
MLOperatorTensorDataType dataType = kernelCreationContext.GetOutputEdgeDescription(outputIndex).tensorDataType;
m_inputTensorDescs[weightsIndex] = TensorDesc::ConstructBroadcastedTensorDesc(
kernelCreationContext.GetInputEdgeDescription(weightsIndex).tensorDataType,
desiredWeightTensorShape,
weightTensorShape);
uint32_t desiredBiasTensorShape[3] = {batchSize, sequenceLength, 3 * hiddenSize};
if (hasBias)
{
auto biasTensorShape = m_inputTensorDescs[biasIndex].GetSizes();
m_inputTensorDescs[biasIndex] = TensorDesc::ConstructBroadcastedTensorDesc(kernelCreationContext.GetInputEdgeDescription(biasIndex).tensorDataType, desiredBiasTensorShape, biasTensorShape);
}
MLOperatorTensorDataType maskTensorDataType = MLOperatorTensorDataType::Undefined;
bool hasMaxSequenceMask = false;
DML_MULTIHEAD_ATTENTION_MASK_TYPE maskType = DML_MULTIHEAD_ATTENTION_MASK_TYPE_NONE;
if (hasMask)
{
if (hasUnpaddedBounds)
{
auto unpaddedKeyBoundsShape = m_inputTensorDescs[maskIndex].GetSizes();
ML_CHECK_VALID_ARGUMENT(unpaddedKeyBoundsShape.size() == 1);
const uint32_t batchGroupCount = unpaddedKeyBoundsShape[0] / batchSize;
ML_CHECK_VALID_ARGUMENT(batchGroupCount == 1 || batchGroupCount == 2);
uint32_t desiredShape[2] = {batchGroupCount, batchSize};
m_inputTensorDescs[maskIndex] = TensorDesc(
m_inputTensorDescs[maskIndex].GetDmlDataType(),
desiredShape);
maskType = batchGroupCount == 1
? DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH
: DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_END_START;
}
else
{
auto maskIndexTensorShape = m_inputTensorDescs[maskIndex].GetSizes();
ML_CHECK_VALID_ARGUMENT(maskIndexTensorShape.size() > 1 && maskIndexTensorShape.size() <= 4);
maskType = DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN;
std::vector<uint32_t> reshapedMaskIndexTensorShape(maskIndexTensorShape.begin(), maskIndexTensorShape.end());
if (maskIndexTensorShape.size() == 4 && maskIndexTensorShape[2] != sequenceLength)
{
hasMaxSequenceMask = true;
ML_CHECK_VALID_ARGUMENT(maskIndexTensorShape[2] == maskIndexTensorShape[3]);
const uint32_t maxSequenceLength = maskIndexTensorShape[2];
uint32_t desiredMaskIndexShape[4] = {batchSize, numHeads, maxSequenceLength, maxSequenceLength};
maskTensorDataType = kernelCreationContext.GetInputEdgeDescription(maskIndex).tensorDataType;
m_inputTensorDescs[maskIndex] = TensorDesc::ConstructBroadcastedTensorDesc(maskTensorDataType, desiredMaskIndexShape, reshapedMaskIndexTensorShape);
}
else
{
uint32_t maskIndexDimensionCount = gsl::narrow_cast<uint32_t>(maskIndexTensorShape.size());
reshapedMaskIndexTensorShape.insert(reshapedMaskIndexTensorShape.begin() + 1, 4 - maskIndexDimensionCount, 1);
uint32_t desiredMaskIndexShape[4] = {batchSize, numHeads, sequenceLength, sequenceLength};
maskTensorDataType = kernelCreationContext.GetInputEdgeDescription(maskIndex).tensorDataType;
m_inputTensorDescs[maskIndex] = TensorDesc::ConstructBroadcastedTensorDesc(maskTensorDataType, desiredMaskIndexShape, reshapedMaskIndexTensorShape);
}
}
}
MLOperatorTensorDataType pastTensorDataType = MLOperatorTensorDataType::Undefined;
MLOperatorTensorDataType presentTensorDataType = MLOperatorTensorDataType::Undefined;
if (hasPast)
{
pastTensorDataType = kernelCreationContext.GetInputEdgeDescription(pastIndex).tensorDataType;
presentTensorDataType = kernelCreationContext.GetOutputEdgeDescription(presentIndex).tensorDataType;
}
TensorDesc matMulIntToFloatOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, desiredBiasTensorShape);
DML_TENSOR_DESC namedMatMulIntToFloatOutputTensorDesc = matMulIntToFloatOutputTensorDesc.GetDmlDesc();
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC matMulIntToFloatOperatorDesc = {};
matMulIntToFloatOperatorDesc.ATensor = &inputDescs[InputIndex::inputIndex];
matMulIntToFloatOperatorDesc.AScaleTensor = &inputDescs[InputIndex::inputScaleIndex];
matMulIntToFloatOperatorDesc.AZeroPointTensor = &inputDescs[InputIndex::inputZeroPointIndex];
matMulIntToFloatOperatorDesc.BTensor = &inputDescs[InputIndex::weightsIndex];
matMulIntToFloatOperatorDesc.BScaleTensor = &inputDescs[InputIndex::weightScaleIndex];
matMulIntToFloatOperatorDesc.BZeroPointTensor = &inputDescs[InputIndex::weightZeroPointIndex];
matMulIntToFloatOperatorDesc.BiasTensor = hasBias ? &inputDescs[InputIndex::biasIndex] : nullptr;
matMulIntToFloatOperatorDesc.OutputTensor = &namedMatMulIntToFloatOutputTensorDesc;
const DML_OPERATOR_DESC matMulIntToFloatDesc = { DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT, &matMulIntToFloatOperatorDesc};
std::array<uint32_t, 3> queryKeySlicedTensorShape = {batchSize, sequenceLength, hiddenSize + hiddenSize};
TensorDesc queryKeySlicedInputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, queryKeySlicedTensorShape);
DML_TENSOR_DESC namedQueryKeySlicedInputTensorDesc = queryKeySlicedInputTensorDesc.GetDmlDesc();
std::array<uint32_t, 3> valueSlicedTensorShape = {batchSize, sequenceLength, hiddenSize};
TensorDesc valueSlicedInputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, valueSlicedTensorShape);
DML_TENSOR_DESC namedValueSlicedInputTensorDesc = valueSlicedInputTensorDesc.GetDmlDesc();
// Transpose slice QK from [batchSize, sequenceLength, 2, numHeads, headSize] to [batchSize, sequenceLength, numHeads, 2, headSize]
std::array<uint32_t, 5> queryKeyTransposedTensorShape = {batchSize, sequenceLength, numHeads, 2, headSize};
std::array<uint32_t, 5> queryKeyTransposedStrides = {
sequenceLength * numHeads * 2 * headSize,
numHeads * 2 * headSize,
headSize,
numHeads * headSize,
1,
};
TensorDesc queryKeyTransposedInputTensorDesc = TensorDesc(
GetDmlDataTypeFromMlDataType(dataType),
queryKeyTransposedTensorShape,
queryKeyTransposedStrides);
DML_TENSOR_DESC namedQueryKeyTransposedInputTensorDesc = queryKeyTransposedInputTensorDesc.GetDmlDesc();
TensorDesc queryKeyTransposedOutputTensorDesc = TensorDesc(
GetDmlDataTypeFromMlDataType(dataType),
queryKeyTransposedTensorShape);
DML_TENSOR_DESC namedQueryKeyTransposedOutputTensorDesc = queryKeyTransposedOutputTensorDesc.GetDmlDesc();
// Transpose QKV from [batchSize, sequenceLength, 3, numHeads, headSize] to [batchSize, sequenceLength, numHeads, 3, headSize]
std::array<uint32_t, 5> queryKeyValueTransposedTensorShape = {batchSize, sequenceLength, numHeads, 3, headSize};
std::array<uint32_t, 5> queryKeyValueTransposedStrides = {
sequenceLength * numHeads * 3 * headSize,
numHeads * 3 * headSize,
headSize,
numHeads * headSize,
1,
};
TensorDesc queryKeyValueTransposedInputTensorDesc = TensorDesc(
GetDmlDataTypeFromMlDataType(dataType),
queryKeyValueTransposedTensorShape,
queryKeyValueTransposedStrides);
DML_TENSOR_DESC namedQueryKeyValueTransposedInputTensorDesc = queryKeyValueTransposedInputTensorDesc.GetDmlDesc();
TensorDesc queryKeyValueTransposedOutputTensorDesc = TensorDesc(
GetDmlDataTypeFromMlDataType(dataType),
queryKeyValueTransposedTensorShape);
DML_TENSOR_DESC namedQueryKeyValueTransposedOutputTensorDesc = queryKeyValueTransposedOutputTensorDesc.GetDmlDesc();
std::array<uint32_t, 3> queryKeySliceOffset = {0, 0, 0};
std::array<uint32_t, 3> queryKeySliceSize = {batchSize, sequenceLength, hiddenSize + hiddenSize};
std::array<int32_t, 3> queryKeySliceStrides = {1, 1, 1};
std::array<uint32_t, 3> valueSliceOffset = {0, 0, 2 * hiddenSize};
std::array<uint32_t, 3> valueSliceSize = {batchSize, sequenceLength, hiddenSize};
std::array<int32_t, 3> valueSliceStrides = {1, 1, 1};
// When Q/K/V all have the same hidden size, we just have to transpose it before sending it to MHA
DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC transposeOperatorDesc = {};
transposeOperatorDesc.InputTensor = &namedQueryKeyValueTransposedInputTensorDesc;
transposeOperatorDesc.OutputTensor = &namedQueryKeyValueTransposedOutputTensorDesc;
const DML_OPERATOR_DESC transposedDesc = { DML_OPERATOR_ELEMENT_WISE_IDENTITY, &transposeOperatorDesc};
std::array<uint32_t, 4> maskSliceOutputShape = {batchSize, numHeads, sequenceLength, sequenceLength};
std::array<int32_t, 4> maskSliceStrides = {1, 1, 1, 1};
std::array<uint32_t, 4> maskSliceOffsets = {0, 0, 0, 0};
TensorDesc maskSliceOutputTensorDesc;
DML_TENSOR_DESC namedMaskSliceOutputTensorDesc;
DML_SLICE1_OPERATOR_DESC maskSlicedOperatorDesc = {};
if (hasMaxSequenceMask)
{
maskSliceOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(maskTensorDataType, maskSliceOutputShape);
namedMaskSliceOutputTensorDesc = maskSliceOutputTensorDesc.GetDmlDesc();
maskSlicedOperatorDesc.InputTensor = &inputDescs[maskIndex];
maskSlicedOperatorDesc.OutputTensor = &namedMaskSliceOutputTensorDesc;
maskSlicedOperatorDesc.DimensionCount = gsl::narrow_cast<uint32_t>(maskSliceOutputShape.size());
maskSlicedOperatorDesc.InputWindowOffsets = maskSliceOffsets.data();
maskSlicedOperatorDesc.InputWindowSizes = maskSliceOutputShape.data();
maskSlicedOperatorDesc.InputWindowStrides = maskSliceStrides.data();
}
const DML_OPERATOR_DESC maskSlicedDesc = { DML_OPERATOR_SLICE1, &maskSlicedOperatorDesc};
// We need to slice Past to get PastValue and PastKey tensors for MHA
std::array<uint32_t, 5> pastKeyOutputShape = {1, batchSize, numHeads, pastSequenceLength, headSize};
std::array<int32_t, 5> pastKeyStrides = {1, 1, 1, 1, 1};
std::array<uint32_t, 5> pastKeyOffsets = {0, 0, 0, 0, 0};
TensorDesc pastKeyOutputTensorDesc;
DML_TENSOR_DESC namedPastKeyOutputTensorDesc;
std::array<uint32_t, 5> pastValueOutputShape = {1, batchSize, numHeads, pastSequenceLength, headSize};
std::array<int32_t, 5> pastValueStrides = {1, 1, 1, 1, 1};
std::array<uint32_t, 5> pastValueOffsets = {1, 0, 0, 0, 0};
TensorDesc pastValueOutputTensorDesc;
DML_TENSOR_DESC namedPastValueOutputTensorDesc;
DML_SLICE1_OPERATOR_DESC pastKeySlicedOperatorDesc = {};
DML_SLICE1_OPERATOR_DESC pastValueSlicedOperatorDesc = {};
if (hasPast)
{
pastKeyOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(pastTensorDataType, pastKeyOutputShape);
namedPastKeyOutputTensorDesc = pastKeyOutputTensorDesc.GetDmlDesc();
pastKeySlicedOperatorDesc.InputTensor = &inputDescs[pastIndex];
pastKeySlicedOperatorDesc.OutputTensor = &namedPastKeyOutputTensorDesc;
pastKeySlicedOperatorDesc.DimensionCount = gsl::narrow_cast<uint32_t>(pastKeyOutputShape.size());
pastKeySlicedOperatorDesc.InputWindowOffsets = pastKeyOffsets.data();
pastKeySlicedOperatorDesc.InputWindowSizes = pastKeyOutputShape.data();
pastKeySlicedOperatorDesc.InputWindowStrides = pastKeyStrides.data();
pastValueOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(pastTensorDataType, pastValueOutputShape);
namedPastValueOutputTensorDesc = pastValueOutputTensorDesc.GetDmlDesc();
pastValueSlicedOperatorDesc.InputTensor = &inputDescs[pastIndex];
pastValueSlicedOperatorDesc.OutputTensor = &namedPastValueOutputTensorDesc;
pastValueSlicedOperatorDesc.DimensionCount = gsl::narrow_cast<uint32_t>(pastValueOutputShape.size());
pastValueSlicedOperatorDesc.InputWindowOffsets = pastValueOffsets.data();
pastValueSlicedOperatorDesc.InputWindowSizes = pastValueOutputShape.data();
pastValueSlicedOperatorDesc.InputWindowStrides = pastValueStrides.data();
}
const DML_OPERATOR_DESC pastKeySlicedDesc = { DML_OPERATOR_SLICE1, &pastKeySlicedOperatorDesc};
const DML_OPERATOR_DESC pastValueSlicedDesc = { DML_OPERATOR_SLICE1, &pastValueSlicedOperatorDesc};
// Causal Mask: Upper Triangular Boolean Matrix
// Example: [[1, 0, 0, 0, 0],
// [1, 1, 0, 0, 0],
// [1, 1, 1, 0, 0],
// [1, 1, 1, 1, 0]]
// DML adds maskFilterValue to the "off" bits in the mask and sets the "on" bits to 0
// passed to MHA as maskIndex Tensor when unidirectional == 1
std::array<uint32_t, 4> causalMaskOutputShape = {1, 1, sequenceLength, pastSequenceLength + sequenceLength};
TensorDesc causalMaskTensorDesc;
DML_DIAGONAL_MATRIX1_OPERATOR_DESC causalMaskOperatorDesc = {};
DML_TENSOR_DESC namedcausalMaskTensorDesc;
if (unidirectional && !hasMask)
{
causalMaskTensorDesc = TensorDesc::ConstructDefaultTensorDesc(MLOperatorTensorDataType::Int32, causalMaskOutputShape);
namedcausalMaskTensorDesc = causalMaskTensorDesc.GetDmlDesc();
causalMaskOperatorDesc.ValueDataType = DML_TENSOR_DATA_TYPE_INT32;
causalMaskOperatorDesc.DiagonalFillBegin = INT32_MIN;
causalMaskOperatorDesc.DiagonalFillEnd = pastSequenceLength + 1;
causalMaskOperatorDesc.Value.Int32 = 1;
causalMaskOperatorDesc.OutputTensor = &namedcausalMaskTensorDesc;
maskType = DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN;
}
DML_OPERATOR_DESC causalMaskDesc = { DML_OPERATOR_DIAGONAL_MATRIX1, &causalMaskOperatorDesc };
DML_MULTIHEAD_ATTENTION_OPERATOR_DESC mhaOperatorDesc = {};
std::array<uint32_t, 5> presentKeyOutputShape = {1, batchSize, numHeads, pastSequenceLength + sequenceLength, headSize};
std::array<uint32_t, 5> presentValueOutputShape = {1, batchSize, numHeads, pastSequenceLength + sequenceLength, headSize};
TensorDesc presentKeyTensorDesc;
TensorDesc presentValueTensorDesc;
DML_TENSOR_DESC namedPresentKeyOutputTensorDesc;
DML_TENSOR_DESC namedPresentValueOutputTensorDesc;
mhaOperatorDesc.StackedQueryKeyValueTensor = &namedQueryKeyValueTransposedOutputTensorDesc;
// Broadcast to MHA MaskTensor Shape
std::array<uint32_t, 4> mhaMaskTensorShape = {batchSize, numHeads, sequenceLength, pastSequenceLength + sequenceLength};
TensorDesc broadcastedcausalMaskTensorDesc;
DML_TENSOR_DESC namedbroadcastedcausalMaskTensorDesc;
if (unidirectional && !hasMask)
{
broadcastedcausalMaskTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(MLOperatorTensorDataType::Int32, mhaMaskTensorShape, causalMaskOutputShape);
namedbroadcastedcausalMaskTensorDesc = broadcastedcausalMaskTensorDesc.GetDmlDesc();
mhaOperatorDesc.MaskTensor = &namedbroadcastedcausalMaskTensorDesc;
}
else if (hasMaxSequenceMask)
{
mhaOperatorDesc.MaskTensor = &namedMaskSliceOutputTensorDesc;
}
else
{
mhaOperatorDesc.MaskTensor = hasMask ? &inputDescs[maskIndex] : nullptr;
}
mhaOperatorDesc.RelativePositionBiasTensor = nullptr;
mhaOperatorDesc.OutputTensor = &outputDescs[outputIndex];
mhaOperatorDesc.Scale = kernelCreationContext.GetOptionalAttribute<float>(AttrName::Scale, gsl::narrow_cast<float>(1.0f / std::sqrt(headSize)));
// Set MaskFilterValue to lowest float for Causal Mask
mhaOperatorDesc.MaskFilterValue = unidirectional ? std::numeric_limits<float>::lowest() :
kernelCreationContext.GetOptionalAttribute<float>(AttrName::MaskFilterValue, -10'000.0f);
mhaOperatorDesc.HeadCount = numHeads;
mhaOperatorDesc.MaskType = maskType;
if (hasPast)
{
presentKeyTensorDesc = TensorDesc::ConstructDefaultTensorDesc(presentTensorDataType, presentKeyOutputShape);
namedPresentKeyOutputTensorDesc = presentKeyTensorDesc.GetDmlDesc();
presentValueTensorDesc = TensorDesc::ConstructDefaultTensorDesc(presentTensorDataType, presentValueOutputShape);
namedPresentValueOutputTensorDesc = presentValueTensorDesc.GetDmlDesc();
mhaOperatorDesc.PastKeyTensor = &namedPastKeyOutputTensorDesc;
mhaOperatorDesc.PastValueTensor = &namedPastValueOutputTensorDesc;
mhaOperatorDesc.OutputPresentKeyTensor = &namedPresentKeyOutputTensorDesc;
mhaOperatorDesc.OutputPresentValueTensor = &namedPresentValueOutputTensorDesc;
}
const DML_OPERATOR_DESC mhaDesc = { DML_OPERATOR_MULTIHEAD_ATTENTION, &mhaOperatorDesc };
DML_JOIN_OPERATOR_DESC presentKeyValueJoinOperatorDesc = {};
std::vector<DML_TENSOR_DESC> joinInputDesc;
if (hasPast)
{
joinInputDesc.push_back(namedPresentKeyOutputTensorDesc);
joinInputDesc.push_back(namedPresentValueOutputTensorDesc);
presentKeyValueJoinOperatorDesc.InputCount = gsl::narrow_cast<uint32_t>(joinInputDesc.size());
presentKeyValueJoinOperatorDesc.InputTensors = joinInputDesc.data();
presentKeyValueJoinOperatorDesc.OutputTensor = &outputDescs[presentIndex];
presentKeyValueJoinOperatorDesc.Axis = gsl::narrow_cast<uint32_t>(0);
}
DML_OPERATOR_DESC presentKeyValueJoinDesc = { DML_OPERATOR_JOIN, &presentKeyValueJoinOperatorDesc };
// Construct the graph
std::vector<DML_INPUT_GRAPH_EDGE_DESC> inputEdges;
std::vector<DML_INTERMEDIATE_GRAPH_EDGE_DESC> intermediateEdges;
std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> outputEdges;
std::vector<const DML_OPERATOR_DESC*> opDescs = {
&matMulIntToFloatDesc,
&mhaDesc,
};
uint32_t currentNodeIndex = 0;
const uint32_t matMulIntToFloatNodeIndex = currentNodeIndex++;
const uint32_t mhaNodeIndex = currentNodeIndex++;
uint32_t queryKeyValueTransposedNodeIndex = 0;
opDescs.push_back(&transposedDesc);
queryKeyValueTransposedNodeIndex = currentNodeIndex++;
uint32_t maskSliceNodeIndex = 0;
if (hasMaxSequenceMask)
{
opDescs.push_back(&maskSlicedDesc);
maskSliceNodeIndex = currentNodeIndex++;
}
uint32_t pastKeySliceNodeIndex = 0;
uint32_t pastValueSliceNodeIndex = 0;
uint32_t concatNodeIndex = 0;
if (hasPast)
{
opDescs.push_back(&pastKeySlicedDesc);
pastKeySliceNodeIndex = currentNodeIndex++;
opDescs.push_back(&pastValueSlicedDesc);
pastValueSliceNodeIndex = currentNodeIndex++;
opDescs.push_back(&presentKeyValueJoinDesc);
concatNodeIndex = currentNodeIndex++;
}
uint32_t causalMaskNodeIndex = 0;
if (unidirectional && !hasMask)
{
opDescs.push_back(&causalMaskDesc);
causalMaskNodeIndex = currentNodeIndex++;
}
DML_INPUT_GRAPH_EDGE_DESC inputToMatMulIntToFloatEdge = {};
inputToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::inputIndex;
inputToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex;
inputToMatMulIntToFloatEdge.ToNodeInputIndex = 0;
inputEdges.push_back(inputToMatMulIntToFloatEdge);
DML_INPUT_GRAPH_EDGE_DESC inputScaleToMatMulIntToFloatEdge = {};
inputScaleToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::inputScaleIndex;
inputScaleToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex;
inputScaleToMatMulIntToFloatEdge.ToNodeInputIndex = 1;
inputEdges.push_back(inputScaleToMatMulIntToFloatEdge);
DML_INPUT_GRAPH_EDGE_DESC inputZeroPointToMatMulIntToFloatEdge = {};
inputZeroPointToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::inputZeroPointIndex;
inputZeroPointToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex;
inputZeroPointToMatMulIntToFloatEdge.ToNodeInputIndex = 2;
inputEdges.push_back(inputZeroPointToMatMulIntToFloatEdge);
DML_INPUT_GRAPH_EDGE_DESC weightToMatMulIntToFloatEdge = {};
weightToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::weightsIndex;
weightToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex;
weightToMatMulIntToFloatEdge.ToNodeInputIndex = 3;
inputEdges.push_back(weightToMatMulIntToFloatEdge);
DML_INPUT_GRAPH_EDGE_DESC weightScaleToMatMulIntToFloatEdge = {};
weightScaleToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::weightScaleIndex;
weightScaleToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex;
weightScaleToMatMulIntToFloatEdge.ToNodeInputIndex = 4;
inputEdges.push_back(weightScaleToMatMulIntToFloatEdge);
DML_INPUT_GRAPH_EDGE_DESC weightZeroPointToMatMulIntToFloatEdge = {};
weightZeroPointToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::weightZeroPointIndex;
weightZeroPointToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex;
weightZeroPointToMatMulIntToFloatEdge.ToNodeInputIndex = 5;
inputEdges.push_back(weightZeroPointToMatMulIntToFloatEdge);
if (hasBias)
{
DML_INPUT_GRAPH_EDGE_DESC biasToMatMulIntToFloatEdge = {};
biasToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::biasIndex;
biasToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex;
biasToMatMulIntToFloatEdge.ToNodeInputIndex = 6;
inputEdges.push_back(biasToMatMulIntToFloatEdge);
}
if (hasMask)
{
if (hasUnpaddedBounds)
{
DML_INPUT_GRAPH_EDGE_DESC maskToMhaEdge = {};
maskToMhaEdge.GraphInputIndex = InputIndex::maskIndex;
maskToMhaEdge.ToNodeIndex = mhaNodeIndex;
maskToMhaEdge.ToNodeInputIndex = mhaMaskIndex;
inputEdges.push_back(maskToMhaEdge);
}
else if (hasMaxSequenceMask)
{
DML_INPUT_GRAPH_EDGE_DESC maskToMaskSliceEdge = {};
maskToMaskSliceEdge.GraphInputIndex = InputIndex::maskIndex;
maskToMaskSliceEdge.ToNodeIndex = maskSliceNodeIndex;
maskToMaskSliceEdge.ToNodeInputIndex = 0;
inputEdges.push_back(maskToMaskSliceEdge);
DML_INTERMEDIATE_GRAPH_EDGE_DESC maskSliceToMhaEdge = {};
maskSliceToMhaEdge.FromNodeIndex = maskSliceNodeIndex;
maskSliceToMhaEdge.FromNodeOutputIndex = 0;
maskSliceToMhaEdge.ToNodeIndex = mhaNodeIndex;
maskSliceToMhaEdge.ToNodeInputIndex = mhaMaskIndex;
intermediateEdges.push_back(maskSliceToMhaEdge);
}
else
{
DML_INPUT_GRAPH_EDGE_DESC maskToMhaEdge = {};
maskToMhaEdge.GraphInputIndex = InputIndex::maskIndex;
maskToMhaEdge.ToNodeIndex = mhaNodeIndex;
maskToMhaEdge.ToNodeInputIndex = mhaMaskIndex;
inputEdges.push_back(maskToMhaEdge);
}
}
else if (unidirectional)
{
DML_INTERMEDIATE_GRAPH_EDGE_DESC causalMaskToMhaEdge = {};
causalMaskToMhaEdge.FromNodeIndex = causalMaskNodeIndex;
causalMaskToMhaEdge.FromNodeOutputIndex = 0;
causalMaskToMhaEdge.ToNodeIndex = mhaNodeIndex;
causalMaskToMhaEdge.ToNodeInputIndex = mhaMaskIndex;
intermediateEdges.push_back(causalMaskToMhaEdge);
}
if (hasPast)
{
DML_INPUT_GRAPH_EDGE_DESC pastToPastKeySliceEdge = {};
pastToPastKeySliceEdge.GraphInputIndex = InputIndex::pastIndex;
pastToPastKeySliceEdge.ToNodeIndex = pastKeySliceNodeIndex;
pastToPastKeySliceEdge.ToNodeInputIndex = 0;
inputEdges.push_back(pastToPastKeySliceEdge);
DML_INPUT_GRAPH_EDGE_DESC pastToPastValueSliceEdge = {};
pastToPastValueSliceEdge.GraphInputIndex = InputIndex::pastIndex;
pastToPastValueSliceEdge.ToNodeIndex = pastValueSliceNodeIndex;
pastToPastValueSliceEdge.ToNodeInputIndex = 0;
inputEdges.push_back(pastToPastValueSliceEdge);
DML_INTERMEDIATE_GRAPH_EDGE_DESC pastKeyToMhaEdge = {};
pastKeyToMhaEdge.FromNodeIndex = pastKeySliceNodeIndex;
pastKeyToMhaEdge.FromNodeOutputIndex = 0;
pastKeyToMhaEdge.ToNodeIndex = mhaNodeIndex;
pastKeyToMhaEdge.ToNodeInputIndex = mhaPastKeyIndex;
intermediateEdges.push_back(pastKeyToMhaEdge);
DML_INTERMEDIATE_GRAPH_EDGE_DESC pastValueToMhaEdge = {};
pastValueToMhaEdge.FromNodeIndex = pastValueSliceNodeIndex;
pastValueToMhaEdge.FromNodeOutputIndex = 0;
pastValueToMhaEdge.ToNodeIndex = mhaNodeIndex;
pastValueToMhaEdge.ToNodeInputIndex = mhaPastValueIndex;
intermediateEdges.push_back(pastValueToMhaEdge);
DML_INTERMEDIATE_GRAPH_EDGE_DESC presentKeyToConcatEdge = {};
presentKeyToConcatEdge.FromNodeIndex = mhaNodeIndex;
presentKeyToConcatEdge.FromNodeOutputIndex = mhaPresentKeyIndex;
presentKeyToConcatEdge.ToNodeIndex = concatNodeIndex;
presentKeyToConcatEdge.ToNodeInputIndex = 0;
intermediateEdges.push_back(presentKeyToConcatEdge);
DML_INTERMEDIATE_GRAPH_EDGE_DESC presentValueToConcatEdge = {};
presentValueToConcatEdge.FromNodeIndex = mhaNodeIndex;
presentValueToConcatEdge.FromNodeOutputIndex = mhaPresentValueIndex;
presentValueToConcatEdge.ToNodeIndex = concatNodeIndex;
presentValueToConcatEdge.ToNodeInputIndex = 1;
intermediateEdges.push_back(presentValueToConcatEdge);
}
DML_INTERMEDIATE_GRAPH_EDGE_DESC matMulIntToFloatToQueryKeyValueTransposeEdge = {};
matMulIntToFloatToQueryKeyValueTransposeEdge.FromNodeIndex = matMulIntToFloatNodeIndex;
matMulIntToFloatToQueryKeyValueTransposeEdge.FromNodeOutputIndex = 0;
matMulIntToFloatToQueryKeyValueTransposeEdge.ToNodeIndex = queryKeyValueTransposedNodeIndex;
matMulIntToFloatToQueryKeyValueTransposeEdge.ToNodeInputIndex = 0;
intermediateEdges.push_back(matMulIntToFloatToQueryKeyValueTransposeEdge);
// All we need to do here is transpose the stacked QKV tensor into something DML supports
DML_INTERMEDIATE_GRAPH_EDGE_DESC queryKeyValueTransposedToMhaEdge = {};
queryKeyValueTransposedToMhaEdge.FromNodeIndex = queryKeyValueTransposedNodeIndex;
queryKeyValueTransposedToMhaEdge.FromNodeOutputIndex = 0;
queryKeyValueTransposedToMhaEdge.ToNodeIndex = mhaNodeIndex;
queryKeyValueTransposedToMhaEdge.ToNodeInputIndex = mhaStackedQueryKeyValueIndex;
intermediateEdges.push_back(queryKeyValueTransposedToMhaEdge);
DML_OUTPUT_GRAPH_EDGE_DESC mhaToOutputEdge = {};
mhaToOutputEdge.FromNodeIndex = mhaNodeIndex;
mhaToOutputEdge.FromNodeOutputIndex = mhaOutputIndex;
mhaToOutputEdge.GraphOutputIndex = OutputIndex::outputIndex;
outputEdges.push_back(mhaToOutputEdge);
if (hasPast)
{
DML_OUTPUT_GRAPH_EDGE_DESC concatToOutputEdge = {};
concatToOutputEdge.FromNodeIndex = concatNodeIndex;
concatToOutputEdge.FromNodeOutputIndex = 0;
concatToOutputEdge.GraphOutputIndex = OutputIndex::presentIndex;
outputEdges.push_back(concatToOutputEdge);
}
MLOperatorGraphDesc operatorGraphDesc = {};
operatorGraphDesc.inputEdgeCount = gsl::narrow_cast<uint32_t>(inputEdges.size());
operatorGraphDesc.inputEdges = inputEdges.data();
operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast<uint32_t>(intermediateEdges.size());
operatorGraphDesc.intermediateEdges = intermediateEdges.data();
operatorGraphDesc.outputEdgeCount = gsl::narrow_cast<uint32_t>(outputEdges.size());
operatorGraphDesc.outputEdges = outputEdges.data();
operatorGraphDesc.nodeCount = gsl::narrow_cast<uint32_t>(opDescs.size());
operatorGraphDesc.nodes = opDescs.data();
SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext);
}
};
// Support query for QAttention: writes to *isSupported whether this kernel can handle the node.
// The node is rejected when any of the following holds:
//   - `unidirectional` is non-zero and the mask input (operator input 5) is present,
//   - `do_rotary` is non-zero,
//   - `past_present_share_buffer` is non-zero.
void CALLBACK QueryQAttention(IMLOperatorSupportQueryContextPrivate* context, /*out*/ bool* isSupported)
{
    // Report "unsupported" until every check below has passed.
    *isSupported = false;

    MLOperatorAttributes attributes(context);

    // True when the named integer attribute is present with a non-zero value.
    const auto isFlagSet = [&attributes](const char* attributeName)
    {
        return attributes.GetOptionalAttribute<int32_t>(attributeName, 0) != 0;
    };

    // The checks run in a fixed order and short-circuit on the first unsupported feature.
    const bool usesUnsupportedFeature =
        // `unidirectional == 1` with Mask Tensor is not supported yet
        (isFlagSet(AttrName::Unidirectional) && context->IsInputValid(5))
        // `do_rotary == 1` is not supported yet
        || isFlagSet(AttrName::DoRotary)
        // `past_present_share_buffer == 1` is not supported yet
        || isFlagSet(AttrName::PastPresentShareBuffer);

    if (usesUnsupportedFeature)
    {
        return;
    }

    *isSupported = true;
}
// Creation-function definition for the QAttention operator, bound to DmlOperatorQAttention.
// NOTE(review): the macro expansion is not visible here; presumably this is the definition that
// DML_OP_EXTERN_CREATION_FUNCTION(QAttention) declares in the operator registration file - confirm.
DML_OP_DEFINE_CREATION_FUNCTION(QAttention, DmlOperatorQAttention);
} // namespace Dml

View file

@ -178,4 +178,4 @@ void CALLBACK QueryQLinearSigmoid(IMLOperatorSupportQueryContextPrivate* context
}
DML_OP_DEFINE_CREATION_FUNCTION(QLinearSigmoid, DmlOperatorQLinearSigmoid);
} // namespace Dml
} // namespace Dml

View file

@ -516,6 +516,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(Resize19);
DML_OP_EXTERN_CREATION_FUNCTION(Shape);
DML_OP_EXTERN_CREATION_FUNCTION(Size);
DML_OP_EXTERN_CREATION_FUNCTION(QAttention);
DML_OP_EXTERN_CREATION_FUNCTION(Attention);
DML_OP_EXTERN_CREATION_FUNCTION(MultiHeadAttention);
DML_OP_EXTERN_CREATION_FUNCTION(NonZero);
@ -537,6 +538,7 @@ DML_OP_EXTERN_QUERY_FUNCTION(Pad);
DML_OP_EXTERN_QUERY_FUNCTION(LayerNormalization);
DML_OP_EXTERN_QUERY_FUNCTION(SkipLayerNormalization);
DML_OP_EXTERN_QUERY_FUNCTION(QLinearSigmoid);
DML_OP_EXTERN_QUERY_FUNCTION(QAttention);
DML_OP_EXTERN_QUERY_FUNCTION(Attention);
constexpr static std::array<const char*, 1> typeNameListDefault = {"T"};
@ -614,15 +616,23 @@ constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListLayerN
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListShape = {SupportedTensorDataTypes::All, SupportedTensorDataTypes::Int64};
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListSize = {SupportedTensorDataTypes::All, SupportedTensorDataTypes::Int64};
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListQLinearSigmoid = {SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8};
constexpr static std::array<SupportedTensorDataTypes, 4> supportedTypeListQAttention = {
SupportedTensorDataTypes::Ints8Bit,
SupportedTensorDataTypes::Ints8Bit,
SupportedTensorDataTypes::Float16to32,
SupportedTensorDataTypes::Int32
};
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListAttention = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int32};
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListRotaryEmbedding = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int64};
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListGroupNorm = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Float16to32};
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListNonZero = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit | SupportedTensorDataTypes::Ints16Bit | SupportedTensorDataTypes::Ints32Bit | SupportedTensorDataTypes::Bool};
constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListQLinearMatMul = {
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8
SupportedTensorDataTypes::Ints8Bit,
SupportedTensorDataTypes::Ints8Bit,
SupportedTensorDataTypes::Ints8Bit
};
constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListMatMulIntegerToFloat = {
@ -632,9 +642,9 @@ constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListMatMul
};
constexpr static std::array<SupportedTensorDataTypes, 4> supportedTypeListQLinearConv = {
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
SupportedTensorDataTypes::Ints8Bit,
SupportedTensorDataTypes::Ints8Bit,
SupportedTensorDataTypes::Ints8Bit,
SupportedTensorDataTypes::Int32
};
@ -1069,6 +1079,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO_MS( 1, DynamicQuantizeMatMul, typeNameListTwo, supportedTypeListDynamicQuantizeLinear, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, FusedMatMulActivation, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, QLinearSigmoid, typeNameListDefault, supportedTypeListQLinearSigmoid, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryQLinearSigmoid)},
{REG_INFO_MS( 1, QAttention, typeNameListFour, supportedTypeListQAttention, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryQAttention)},
{REG_INFO_MS( 1, Attention, typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryAttention)},
{REG_INFO_MS( 1, MultiHeadAttention, typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, RotaryEmbedding, typeNameListRotaryEmbedding, supportedTypeListRotaryEmbedding, DmlGraphSupport::Supported)},

View file

@ -2802,6 +2802,48 @@ namespace OperatorHelper
m_qkvHiddenSizes = kernelInformation.GetAttributes().GetOptionalAttributeVectorInt32(AttrName::QkvHiddenSizes);
}
// Computes the output shapes of QAttention from the input shapes and the num_heads attribute.
//
// Inputs read:
//   input 0          : rank 3; dims 0 and 1 are used as batch size and sequence length
//   input 1 (weight) : rank 2; dim 1 must be a multiple of 3 and hiddenSize = dim1 / 3
//   input 8 (past)   : optional, rank 5; dim 3 is used as the past sequence length
//
// Returns two EdgeShapes:
//   [0] output  : {batchSize, sequenceLength, hiddenSize}
//   [1] present : {2, batchSize, numHeads, totalSequenceLength, headSize}; only filled in when
//                 output 1 is requested, which in turn requires the past input to be present.
//                 Left default-constructed otherwise.
//
// Every ML_CHECK_VALID_ARGUMENT below rejects the node when its condition does not hold.
std::vector<EdgeShapes> QAttentionHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
{
    constexpr uint32_t pastInputIndex = 8;
    constexpr uint32_t presentOutputIndex = 1;

    ML_CHECK_VALID_ARGUMENT(shapeInfo.GetInputCount() >= 5);

    // headSize is computed as hiddenSize / m_numHeads below; a zero num_heads attribute would
    // otherwise be an integer division by zero, so report it as an invalid argument instead.
    ML_CHECK_VALID_ARGUMENT(m_numHeads > 0);

    auto queryShape = shapeInfo.GetInputTensorShape(0);
    ML_CHECK_VALID_ARGUMENT(queryShape.size() == 3);

    auto weightShape = shapeInfo.GetInputTensorShape(1);
    ML_CHECK_VALID_ARGUMENT(weightShape.size() == 2);
    // The weight's second dimension is split three ways; one third of it is the hidden size.
    ML_CHECK_VALID_ARGUMENT(weightShape[1] % 3 == 0);

    const uint32_t batchSize = queryShape[0];
    const uint32_t sequenceLength = queryShape[1];
    const uint32_t hiddenSize = weightShape[1] / 3;
    const uint32_t headSize = hiddenSize / m_numHeads;

    std::vector<EdgeShapes> outputShapes(2);
    outputShapes[0] = EdgeShapes({batchSize, sequenceLength, hiddenSize});

    // With a past tensor, the present tensor covers the past plus the current sequence.
    uint32_t totalSequenceLength = sequenceLength;
    if (shapeInfo.IsInputValid(pastInputIndex))
    {
        ML_CHECK_VALID_ARGUMENT(shapeInfo.GetInputTensorDimensionCount(pastInputIndex) == 5);
        const uint32_t pastSequenceLength = shapeInfo.GetInputTensorShape(pastInputIndex)[3];
        totalSequenceLength += pastSequenceLength;
    }

    if (shapeInfo.IsOutputValid(presentOutputIndex))
    {
        // The present output is only supported when a past input is supplied.
        ML_CHECK_VALID_ARGUMENT(shapeInfo.IsInputValid(pastInputIndex));
        outputShapes[1] = EdgeShapes({2, batchSize, m_numHeads, totalSequenceLength, headSize});
    }

    return outputShapes;
}
void QAttentionHelper::Initialize(const IKernelInformationAdapter& kernelInformation)
{
m_numHeads = gsl::narrow_cast<uint32_t>(kernelInformation.GetAttributes().GetAttribute<int64_t>(AttrName::NumHeads));
}
std::vector<EdgeShapes> SkipLayerNormHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
{
ML_CHECK_VALID_ARGUMENT(shapeInfo.GetInputCount() >= 3);

View file

@ -1554,6 +1554,22 @@ private:
std::vector<int32_t> m_qkvHiddenSizes;
};
// Shape-inference helper for the QAttention operator.
class QAttentionHelper
{
public:
    // Reads the operator attributes from `info`. The shape-info argument is accepted so the
    // constructor has the same two-argument form as the call sites expect, but it is not used.
    template <typename Info_t, typename Shape_t>
    QAttentionHelper(const Info_t& info, const Shape_t& /*shapeInfo*/)
    {
        const auto adapter = KernelInformationAdapter(info);
        Initialize(adapter);
    }

    // Returns the output shapes for the inputs described by `shapeInfo`.
    std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;

private:
    // Loads attribute values into the members below.
    void Initialize(const IKernelInformationAdapter& kernelInformation);

    // Assigned by Initialize(), which the constructor always calls.
    uint32_t m_numHeads;
};
class SkipLayerNormHelper
{
public:
@ -1699,6 +1715,7 @@ using ShapeInferenceHelper_Affine = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_QuantizeLinear = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_DequantizeLinear = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_QLinearSigmoid = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_QAttention = QAttentionHelper;
using ShapeInferenceHelper_Attention = AttentionHelper;
using ShapeInferenceHelper_MultiHeadAttention = MultiHeadAttentionHelper;
using ShapeInferenceHelper_RotaryEmbedding = GetOutputShapeAsInputShapeHelper;

View file

@ -448,6 +448,7 @@ namespace OperatorHelper
static const int sc_sinceVer_FusedMatMul = 1;
static const int sc_sinceVer_FusedMatMulActivation = 1;
static const int sc_sinceVer_QLinearSigmoid = 1;
static const int sc_sinceVer_QAttention = 1;
static const int sc_sinceVer_Attention = 1;
static const int sc_sinceVer_MatMulIntegerToFloat = 1;
static const int sc_sinceVer_MultiHeadAttention = 1;

View file

@ -20,7 +20,8 @@ namespace test {
enum class EP : char {
CPU,
CUDA,
DNNL
DNNL,
DML
};
// input: [batch_size, sequence_length, hidden_size]
@ -111,7 +112,9 @@ void RunQAttention(const std::vector<float>& input_data,
execution_providers.push_back(DefaultCudaExecutionProvider());
} else if constexpr (ep == EP::CPU) {
execution_providers.push_back(DefaultCpuExecutionProvider());
} else { // onednn ep
} else if constexpr (ep == EP::DML) {
execution_providers.push_back(DefaultDmlExecutionProvider());
} else { // onednn ep
execution_providers.push_back(DefaultDnnlExecutionProvider());
}
@ -192,6 +195,52 @@ static void RunQAttentionDNNL(
#endif
}
// Runs one QAttention test case on the DML execution provider with uint8 input and int8 weights.
// Does nothing when USE_DML is not defined or no DML execution provider instance is available.
static void RunQAttentionDML(
    const std::vector<float>& input_data,
    const std::vector<float>& weights_data,
    const std::vector<float>& bias_data,
    const std::vector<int32_t>& mask_index_data,
    const std::vector<float>& output_data,
    int batch_size,
    int sequence_length,
    int hidden_size,
    int number_of_heads,
    bool use_special_quantize_parameter = true,
    bool is_unidirectional = false,
    int input_hidden_size = 0) {
  // Return without running code if USE_DML is not defined
#ifdef USE_DML
  // Nothing to test when the provider factory hands back no instance.
  if (DefaultDmlExecutionProvider().get() == nullptr) {
    return;
  }

  // Zero scale / zero point unless the caller asks for the fixed "special" parameters.
  quantization::Params<uint8_t> input_params(/*scale=*/0.0f, /*zero_point=*/0);
  quantization::Params<int8_t> weight_params(/*scale=*/0.0f, /*zero_point=*/0);
  if (use_special_quantize_parameter) {
    input_params.scale = 0.1f;
    input_params.zero_point = 128;
    weight_params.scale = 0.1f;
    weight_params.zero_point = 1;
  }

  RunQAttention<uint8_t, int8_t, EP::DML>(
      input_data, weights_data, bias_data, mask_index_data, output_data,
      input_params, weight_params,
      batch_size, sequence_length, hidden_size, number_of_heads,
      is_unidirectional, false, input_hidden_size);
#else
  // Without DML every argument is unused; mark them to keep the build warning-free.
  ORT_UNUSED_PARAMETER(input_data);
  ORT_UNUSED_PARAMETER(weights_data);
  ORT_UNUSED_PARAMETER(bias_data);
  ORT_UNUSED_PARAMETER(mask_index_data);
  ORT_UNUSED_PARAMETER(output_data);
  ORT_UNUSED_PARAMETER(batch_size);
  ORT_UNUSED_PARAMETER(sequence_length);
  ORT_UNUSED_PARAMETER(hidden_size);
  ORT_UNUSED_PARAMETER(number_of_heads);
  ORT_UNUSED_PARAMETER(use_special_quantize_parameter);
  ORT_UNUSED_PARAMETER(is_unidirectional);
  ORT_UNUSED_PARAMETER(input_hidden_size);
#endif
}
static void RunQAttentionU8U8(
const std::vector<float>& input_data,
const std::vector<float>& weights_data,
@ -272,6 +321,9 @@ static void RunQAttentionAll(
RunQAttentionDNNL(input_data, weight_data, bias_data, mask_index_data, output_data,
batch_size, sequence_length, hidden_size, number_of_heads,
use_special_quantize_parameter, is_unidirectional, input_hidden_size);
RunQAttentionDML(input_data, weight_data, bias_data, mask_index_data, output_data,
batch_size, sequence_length, hidden_size, number_of_heads,
use_special_quantize_parameter, is_unidirectional, input_hidden_size);
}
// ONEDNN EP only supports 2D raw mask
@ -859,8 +911,8 @@ void TestQuantizedAttentionPastState(int64_t batch,
std::vector<int64_t> input_dims{batch, seq_len, hidden_size};
std::vector<InputT> input_data = random.Gaussian<InputT>(input_dims, input_mean, static_cast<InputT>(input_range / 6), input_min, input_max);
constexpr WeightT weight_min = std::numeric_limits<WeightT>::min();
constexpr WeightT weight_max = std::numeric_limits<WeightT>::max();
constexpr WeightT weight_min = std::is_same_v<WeightT, int8_t> ? std::numeric_limits<int8_t>::min() / 2 : std::numeric_limits<WeightT>::min();
constexpr WeightT weight_max = std::numeric_limits<WeightT>::max() / 2;
constexpr int32_t weight_range = weight_max - weight_min;
std::vector<WeightT> weight_zero_point(weight_scale_zp_size);