mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
[DML] QAttention (#19766)
### Description DML Implementation for [com.microsoft.QAttention](https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QAttention) ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> --------- Co-authored-by: Xiang Zhang <xianz@microsoft.com>
This commit is contained in:
parent
5479124834
commit
89aa4697b1
8 changed files with 839 additions and 11 deletions
|
|
@ -1277,6 +1277,7 @@ Do not modify directly.*
|
|||
|MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float), tensor(float16)|
|
||||
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
|
||||
|NhwcConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|QAttention|*in* input:**T1**<br> *in* weight:**T2**<br> *in* bias:**T3**<br> *in* input_scale:**T3**<br> *in* weight_scale:**T3**<br> *in* mask_index:**T4**<br> *in* input_zero_point:**T1**<br> *in* weight_zero_point:**T2**<br> *in* past:**T3**<br> *out* output:**T3**<br> *out* present:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float), tensor(float16)<br/> **T4** = tensor(int32)|
|
||||
|QLinearAdd|*in* A:**T**<br> *in* A_scale:**tensor(float)**<br> *in* A_zero_point:**T**<br> *in* B:**T**<br> *in* B_scale:**tensor(float)**<br> *in* B_zero_point:**T**<br> *in* C_scale:**tensor(float)**<br> *in* C_zero_point:**T**<br> *out* C:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|
||||
|QLinearAveragePool|*in* X:**T**<br> *in* x_scale:**tensor(float)**<br> *in* x_zero_point:**T**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T**<br> *out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|
||||
|QLinearConcat|*in* Y_scale:**TF**<br> *in* Y_zero_point:**T8**<br> *in* inputs:**TV**<br> *out* Y:**T8**|1+|**T8** = tensor(int8), tensor(uint8)<br/> **TF** = tensor(float)<br/> **TV** = tensor(float), tensor(int8), tensor(uint8)|
|
||||
|
|
|
|||
|
|
@ -0,0 +1,704 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "precomp.h"
|
||||
|
||||
/*
Abbreviations: B is batch_size, S is sequence_length, W is hidden_size
               N is number of attention heads, H is head size, and W=N*H

Input, Weight, Bias, Mask Index and Past are Inputs

Mask Index/Causal  Input   Weight   Bias
      |               \      |      /
      |                \     |     /
      |                 \    |    /
      |              MatMulIntToFloat
      |                 /    |    \
      |                /     |     \
      |               /      |      \
      |           Slice    Slice    Slice
      |             |        |        |
      |             |        |        |
      |         Identity Identity Identity // The identities are used to transpose NCHW -> NHCW while
      |             |        |        |    // keeping the GEMM strides as NCHW to better target metacommands
      |             |        |        |
      |             |        |        |          Past
      |             |        |        |         /    \
      |             |        |        |        /      \
      |             |        |        |     Slice    Slice
      |             |        |        |       |        |
      |             |        |        |       |        |
      |             |        |        |       |        |
      --------------------------MHA -----------
                             / |  \
                            /  |   \
                           /   |    \
                          /    |     \
                         /     |      \
                        /      |       \
                       /  presentKey  presentValue
                      /         \       /
                     /           \     /
                    /             \   /
                   /              Concat
                  /                  |
              Output1         Output2 (present)

This kernel creates a DML_GRAPH, as mentioned above.
For reference, refer to this Doc:
https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftqattention
*/
|
||||
|
||||
namespace Dml
|
||||
{
|
||||
class DmlOperatorQAttention : public DmlOperator
|
||||
{
|
||||
public:
|
||||
DmlOperatorQAttention(const MLOperatorKernelCreationContext& kernelCreationContext)
|
||||
: DmlOperator(kernelCreationContext)
|
||||
{
|
||||
enum InputIndex : uint32_t
|
||||
{
|
||||
inputIndex,
|
||||
weightsIndex,
|
||||
biasIndex,
|
||||
inputScaleIndex,
|
||||
weightScaleIndex,
|
||||
maskIndex,
|
||||
inputZeroPointIndex,
|
||||
weightZeroPointIndex,
|
||||
pastIndex,
|
||||
inputCount,
|
||||
};
|
||||
|
||||
enum OutputIndex : uint32_t
|
||||
{
|
||||
outputIndex,
|
||||
presentIndex,
|
||||
outputCount,
|
||||
};
|
||||
|
||||
enum MhaInputIndex : uint32_t
|
||||
{
|
||||
mhaQueryIndex,
|
||||
mhaKeyIndex,
|
||||
mhaValueIndex,
|
||||
mhaStackedQueryKeyIndex,
|
||||
mhaStackedKeyValueIndex,
|
||||
mhaStackedQueryKeyValueIndex,
|
||||
mhaBiasIndex,
|
||||
mhaMaskIndex,
|
||||
mhaRelativePositionBiasIndex,
|
||||
mhaPastKeyIndex,
|
||||
mhaPastValueIndex,
|
||||
mhaInputCount,
|
||||
};
|
||||
|
||||
enum MhaOutputIndex : uint32_t
|
||||
{
|
||||
mhaOutputIndex,
|
||||
mhaPresentKeyIndex,
|
||||
mhaPresentValueIndex,
|
||||
mhaOutputCount,
|
||||
};
|
||||
|
||||
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() >= 5);
|
||||
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() >= 1);
|
||||
|
||||
const bool hasBias = kernelCreationContext.IsInputValid(biasIndex);
|
||||
const bool hasMask = kernelCreationContext.IsInputValid(maskIndex);
|
||||
const bool hasUnpaddedBounds = hasMask && kernelCreationContext.GetInputTensorDimensionCount(maskIndex) == 1;
|
||||
const bool hasPast = kernelCreationContext.IsInputValid(pastIndex);
|
||||
|
||||
DmlOperator::Initialize(kernelCreationContext, std::nullopt, std::nullopt, std::nullopt, std::nullopt, 1);
|
||||
|
||||
const bool unidirectional = gsl::narrow_cast<uint32_t>(kernelCreationContext.GetAttribute<int64_t>(AttrName::Unidirectional));
|
||||
const uint32_t numHeads = gsl::narrow_cast<uint32_t>(kernelCreationContext.GetAttribute<int64_t>(AttrName::NumHeads));
|
||||
ML_CHECK_VALID_ARGUMENT(numHeads > 0); // to avoid process crash because of division by zero.
|
||||
|
||||
auto inputTensorShape = m_inputTensorDescs[inputIndex].GetSizes();
|
||||
ML_CHECK_VALID_ARGUMENT(inputTensorShape.size() == 3);
|
||||
|
||||
auto weightTensorShape = m_inputTensorDescs[weightsIndex].GetSizes();
|
||||
ML_CHECK_VALID_ARGUMENT(weightTensorShape.size() == 2);
|
||||
ML_CHECK_VALID_ARGUMENT(weightTensorShape[0] == inputTensorShape[2]);
|
||||
ML_CHECK_VALID_ARGUMENT(weightTensorShape[1] % 3 == 0);
|
||||
|
||||
if (hasBias)
|
||||
{
|
||||
auto biasTensorShape = m_inputTensorDescs[biasIndex].GetSizes();
|
||||
ML_CHECK_VALID_ARGUMENT(biasTensorShape.size() == 1);
|
||||
ML_CHECK_VALID_ARGUMENT(biasTensorShape[0] % 3 == 0);
|
||||
ML_CHECK_VALID_ARGUMENT(weightTensorShape[1] == biasTensorShape[0]);
|
||||
}
|
||||
|
||||
if (hasPast)
|
||||
{
|
||||
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.IsOutputValid(presentIndex));
|
||||
}
|
||||
|
||||
const uint32_t hiddenSize = weightTensorShape[1] / 3;
|
||||
const uint32_t headSize = hiddenSize / numHeads;
|
||||
const uint32_t batchSize = inputTensorShape[0];
|
||||
const uint32_t sequenceLength = inputTensorShape[1];
|
||||
const uint32_t pastSequenceLength = hasPast ? m_inputTensorDescs[pastIndex].GetSizes()[3] : 0;
|
||||
|
||||
uint32_t desiredWeightTensorShape[3] = {batchSize, weightTensorShape[0], 3 * hiddenSize};
|
||||
MLOperatorTensorDataType dataType = kernelCreationContext.GetOutputEdgeDescription(outputIndex).tensorDataType;
|
||||
|
||||
m_inputTensorDescs[weightsIndex] = TensorDesc::ConstructBroadcastedTensorDesc(
|
||||
kernelCreationContext.GetInputEdgeDescription(weightsIndex).tensorDataType,
|
||||
desiredWeightTensorShape,
|
||||
weightTensorShape);
|
||||
|
||||
uint32_t desiredBiasTensorShape[3] = {batchSize, sequenceLength, 3 * hiddenSize};
|
||||
|
||||
if (hasBias)
|
||||
{
|
||||
auto biasTensorShape = m_inputTensorDescs[biasIndex].GetSizes();
|
||||
m_inputTensorDescs[biasIndex] = TensorDesc::ConstructBroadcastedTensorDesc(kernelCreationContext.GetInputEdgeDescription(biasIndex).tensorDataType, desiredBiasTensorShape, biasTensorShape);
|
||||
}
|
||||
|
||||
MLOperatorTensorDataType maskTensorDataType = MLOperatorTensorDataType::Undefined;
|
||||
bool hasMaxSequenceMask = false;
|
||||
DML_MULTIHEAD_ATTENTION_MASK_TYPE maskType = DML_MULTIHEAD_ATTENTION_MASK_TYPE_NONE;
|
||||
if (hasMask)
|
||||
{
|
||||
if (hasUnpaddedBounds)
|
||||
{
|
||||
auto unpaddedKeyBoundsShape = m_inputTensorDescs[maskIndex].GetSizes();
|
||||
ML_CHECK_VALID_ARGUMENT(unpaddedKeyBoundsShape.size() == 1);
|
||||
|
||||
const uint32_t batchGroupCount = unpaddedKeyBoundsShape[0] / batchSize;
|
||||
ML_CHECK_VALID_ARGUMENT(batchGroupCount == 1 || batchGroupCount == 2);
|
||||
|
||||
uint32_t desiredShape[2] = {batchGroupCount, batchSize};
|
||||
m_inputTensorDescs[maskIndex] = TensorDesc(
|
||||
m_inputTensorDescs[maskIndex].GetDmlDataType(),
|
||||
desiredShape);
|
||||
|
||||
maskType = batchGroupCount == 1
|
||||
? DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH
|
||||
: DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_END_START;
|
||||
}
|
||||
else
|
||||
{
|
||||
auto maskIndexTensorShape = m_inputTensorDescs[maskIndex].GetSizes();
|
||||
ML_CHECK_VALID_ARGUMENT(maskIndexTensorShape.size() > 1 && maskIndexTensorShape.size() <= 4);
|
||||
|
||||
maskType = DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN;
|
||||
std::vector<uint32_t> reshapedMaskIndexTensorShape(maskIndexTensorShape.begin(), maskIndexTensorShape.end());
|
||||
if (maskIndexTensorShape.size() == 4 && maskIndexTensorShape[2] != sequenceLength)
|
||||
{
|
||||
hasMaxSequenceMask = true;
|
||||
ML_CHECK_VALID_ARGUMENT(maskIndexTensorShape[2] == maskIndexTensorShape[3]);
|
||||
const uint32_t maxSequenceLength = maskIndexTensorShape[2];
|
||||
uint32_t desiredMaskIndexShape[4] = {batchSize, numHeads, maxSequenceLength, maxSequenceLength};
|
||||
maskTensorDataType = kernelCreationContext.GetInputEdgeDescription(maskIndex).tensorDataType;
|
||||
m_inputTensorDescs[maskIndex] = TensorDesc::ConstructBroadcastedTensorDesc(maskTensorDataType, desiredMaskIndexShape, reshapedMaskIndexTensorShape);
|
||||
}
|
||||
else
|
||||
{
|
||||
uint32_t maskIndexDimensionCount = gsl::narrow_cast<uint32_t>(maskIndexTensorShape.size());
|
||||
reshapedMaskIndexTensorShape.insert(reshapedMaskIndexTensorShape.begin() + 1, 4 - maskIndexDimensionCount, 1);
|
||||
uint32_t desiredMaskIndexShape[4] = {batchSize, numHeads, sequenceLength, sequenceLength};
|
||||
maskTensorDataType = kernelCreationContext.GetInputEdgeDescription(maskIndex).tensorDataType;
|
||||
m_inputTensorDescs[maskIndex] = TensorDesc::ConstructBroadcastedTensorDesc(maskTensorDataType, desiredMaskIndexShape, reshapedMaskIndexTensorShape);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
MLOperatorTensorDataType pastTensorDataType = MLOperatorTensorDataType::Undefined;
|
||||
MLOperatorTensorDataType presentTensorDataType = MLOperatorTensorDataType::Undefined;
|
||||
if (hasPast)
|
||||
{
|
||||
pastTensorDataType = kernelCreationContext.GetInputEdgeDescription(pastIndex).tensorDataType;
|
||||
presentTensorDataType = kernelCreationContext.GetOutputEdgeDescription(presentIndex).tensorDataType;
|
||||
}
|
||||
|
||||
TensorDesc matMulIntToFloatOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, desiredBiasTensorShape);
|
||||
DML_TENSOR_DESC namedMatMulIntToFloatOutputTensorDesc = matMulIntToFloatOutputTensorDesc.GetDmlDesc();
|
||||
|
||||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
|
||||
DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC matMulIntToFloatOperatorDesc = {};
|
||||
matMulIntToFloatOperatorDesc.ATensor = &inputDescs[InputIndex::inputIndex];
|
||||
matMulIntToFloatOperatorDesc.AScaleTensor = &inputDescs[InputIndex::inputScaleIndex];
|
||||
matMulIntToFloatOperatorDesc.AZeroPointTensor = &inputDescs[InputIndex::inputZeroPointIndex];
|
||||
matMulIntToFloatOperatorDesc.BTensor = &inputDescs[InputIndex::weightsIndex];
|
||||
matMulIntToFloatOperatorDesc.BScaleTensor = &inputDescs[InputIndex::weightScaleIndex];
|
||||
matMulIntToFloatOperatorDesc.BZeroPointTensor = &inputDescs[InputIndex::weightZeroPointIndex];
|
||||
matMulIntToFloatOperatorDesc.BiasTensor = hasBias ? &inputDescs[InputIndex::biasIndex] : nullptr;
|
||||
matMulIntToFloatOperatorDesc.OutputTensor = &namedMatMulIntToFloatOutputTensorDesc;
|
||||
|
||||
const DML_OPERATOR_DESC matMulIntToFloatDesc = { DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT, &matMulIntToFloatOperatorDesc};
|
||||
|
||||
std::array<uint32_t, 3> queryKeySlicedTensorShape = {batchSize, sequenceLength, hiddenSize + hiddenSize};
|
||||
TensorDesc queryKeySlicedInputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, queryKeySlicedTensorShape);
|
||||
DML_TENSOR_DESC namedQueryKeySlicedInputTensorDesc = queryKeySlicedInputTensorDesc.GetDmlDesc();
|
||||
|
||||
std::array<uint32_t, 3> valueSlicedTensorShape = {batchSize, sequenceLength, hiddenSize};
|
||||
TensorDesc valueSlicedInputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, valueSlicedTensorShape);
|
||||
DML_TENSOR_DESC namedValueSlicedInputTensorDesc = valueSlicedInputTensorDesc.GetDmlDesc();
|
||||
|
||||
// Transpose slice QK from [batchSize, sequenceLength, 2, numHeads, headSize] to [batchSize, sequenceLength, numHeads, 2, headSize]
|
||||
std::array<uint32_t, 5> queryKeyTransposedTensorShape = {batchSize, sequenceLength, numHeads, 2, headSize};
|
||||
std::array<uint32_t, 5> queryKeyTransposedStrides = {
|
||||
sequenceLength * numHeads * 2 * headSize,
|
||||
numHeads * 2 * headSize,
|
||||
headSize,
|
||||
numHeads * headSize,
|
||||
1,
|
||||
};
|
||||
|
||||
TensorDesc queryKeyTransposedInputTensorDesc = TensorDesc(
|
||||
GetDmlDataTypeFromMlDataType(dataType),
|
||||
queryKeyTransposedTensorShape,
|
||||
queryKeyTransposedStrides);
|
||||
DML_TENSOR_DESC namedQueryKeyTransposedInputTensorDesc = queryKeyTransposedInputTensorDesc.GetDmlDesc();
|
||||
|
||||
TensorDesc queryKeyTransposedOutputTensorDesc = TensorDesc(
|
||||
GetDmlDataTypeFromMlDataType(dataType),
|
||||
queryKeyTransposedTensorShape);
|
||||
DML_TENSOR_DESC namedQueryKeyTransposedOutputTensorDesc = queryKeyTransposedOutputTensorDesc.GetDmlDesc();
|
||||
|
||||
// Transpose QKV from [batchSize, sequenceLength, 3, numHeads, headSize] to [batchSize, sequenceLength, numHeads, 3, headSize]
|
||||
std::array<uint32_t, 5> queryKeyValueTransposedTensorShape = {batchSize, sequenceLength, numHeads, 3, headSize};
|
||||
std::array<uint32_t, 5> queryKeyValueTransposedStrides = {
|
||||
sequenceLength * numHeads * 3 * headSize,
|
||||
numHeads * 3 * headSize,
|
||||
headSize,
|
||||
numHeads * headSize,
|
||||
1,
|
||||
};
|
||||
|
||||
TensorDesc queryKeyValueTransposedInputTensorDesc = TensorDesc(
|
||||
GetDmlDataTypeFromMlDataType(dataType),
|
||||
queryKeyValueTransposedTensorShape,
|
||||
queryKeyValueTransposedStrides);
|
||||
DML_TENSOR_DESC namedQueryKeyValueTransposedInputTensorDesc = queryKeyValueTransposedInputTensorDesc.GetDmlDesc();
|
||||
|
||||
TensorDesc queryKeyValueTransposedOutputTensorDesc = TensorDesc(
|
||||
GetDmlDataTypeFromMlDataType(dataType),
|
||||
queryKeyValueTransposedTensorShape);
|
||||
DML_TENSOR_DESC namedQueryKeyValueTransposedOutputTensorDesc = queryKeyValueTransposedOutputTensorDesc.GetDmlDesc();
|
||||
|
||||
std::array<uint32_t, 3> queryKeySliceOffset = {0, 0, 0};
|
||||
std::array<uint32_t, 3> queryKeySliceSize = {batchSize, sequenceLength, hiddenSize + hiddenSize};
|
||||
std::array<int32_t, 3> queryKeySliceStrides = {1, 1, 1};
|
||||
|
||||
std::array<uint32_t, 3> valueSliceOffset = {0, 0, 2 * hiddenSize};
|
||||
std::array<uint32_t, 3> valueSliceSize = {batchSize, sequenceLength, hiddenSize};
|
||||
std::array<int32_t, 3> valueSliceStrides = {1, 1, 1};
|
||||
|
||||
// When Q/K/V all have the same hidden size, we just have to transpose it before sending it to MHA
|
||||
DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC transposeOperatorDesc = {};
|
||||
|
||||
transposeOperatorDesc.InputTensor = &namedQueryKeyValueTransposedInputTensorDesc;
|
||||
transposeOperatorDesc.OutputTensor = &namedQueryKeyValueTransposedOutputTensorDesc;
|
||||
|
||||
const DML_OPERATOR_DESC transposedDesc = { DML_OPERATOR_ELEMENT_WISE_IDENTITY, &transposeOperatorDesc};
|
||||
|
||||
std::array<uint32_t, 4> maskSliceOutputShape = {batchSize, numHeads, sequenceLength, sequenceLength};
|
||||
std::array<int32_t, 4> maskSliceStrides = {1, 1, 1, 1};
|
||||
std::array<uint32_t, 4> maskSliceOffsets = {0, 0, 0, 0};
|
||||
TensorDesc maskSliceOutputTensorDesc;
|
||||
DML_TENSOR_DESC namedMaskSliceOutputTensorDesc;
|
||||
|
||||
DML_SLICE1_OPERATOR_DESC maskSlicedOperatorDesc = {};
|
||||
if (hasMaxSequenceMask)
|
||||
{
|
||||
maskSliceOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(maskTensorDataType, maskSliceOutputShape);
|
||||
namedMaskSliceOutputTensorDesc = maskSliceOutputTensorDesc.GetDmlDesc();
|
||||
maskSlicedOperatorDesc.InputTensor = &inputDescs[maskIndex];
|
||||
maskSlicedOperatorDesc.OutputTensor = &namedMaskSliceOutputTensorDesc;
|
||||
maskSlicedOperatorDesc.DimensionCount = gsl::narrow_cast<uint32_t>(maskSliceOutputShape.size());
|
||||
maskSlicedOperatorDesc.InputWindowOffsets = maskSliceOffsets.data();
|
||||
maskSlicedOperatorDesc.InputWindowSizes = maskSliceOutputShape.data();
|
||||
maskSlicedOperatorDesc.InputWindowStrides = maskSliceStrides.data();
|
||||
}
|
||||
const DML_OPERATOR_DESC maskSlicedDesc = { DML_OPERATOR_SLICE1, &maskSlicedOperatorDesc};
|
||||
|
||||
// We need to slice Past to get PastValue and PastKey tensors for MHA
|
||||
std::array<uint32_t, 5> pastKeyOutputShape = {1, batchSize, numHeads, pastSequenceLength, headSize};
|
||||
std::array<int32_t, 5> pastKeyStrides = {1, 1, 1, 1, 1};
|
||||
std::array<uint32_t, 5> pastKeyOffsets = {0, 0, 0, 0, 0};
|
||||
TensorDesc pastKeyOutputTensorDesc;
|
||||
DML_TENSOR_DESC namedPastKeyOutputTensorDesc;
|
||||
|
||||
std::array<uint32_t, 5> pastValueOutputShape = {1, batchSize, numHeads, pastSequenceLength, headSize};
|
||||
std::array<int32_t, 5> pastValueStrides = {1, 1, 1, 1, 1};
|
||||
std::array<uint32_t, 5> pastValueOffsets = {1, 0, 0, 0, 0};
|
||||
TensorDesc pastValueOutputTensorDesc;
|
||||
DML_TENSOR_DESC namedPastValueOutputTensorDesc;
|
||||
|
||||
DML_SLICE1_OPERATOR_DESC pastKeySlicedOperatorDesc = {};
|
||||
DML_SLICE1_OPERATOR_DESC pastValueSlicedOperatorDesc = {};
|
||||
|
||||
if (hasPast)
|
||||
{
|
||||
pastKeyOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(pastTensorDataType, pastKeyOutputShape);
|
||||
namedPastKeyOutputTensorDesc = pastKeyOutputTensorDesc.GetDmlDesc();
|
||||
pastKeySlicedOperatorDesc.InputTensor = &inputDescs[pastIndex];
|
||||
pastKeySlicedOperatorDesc.OutputTensor = &namedPastKeyOutputTensorDesc;
|
||||
pastKeySlicedOperatorDesc.DimensionCount = gsl::narrow_cast<uint32_t>(pastKeyOutputShape.size());
|
||||
pastKeySlicedOperatorDesc.InputWindowOffsets = pastKeyOffsets.data();
|
||||
pastKeySlicedOperatorDesc.InputWindowSizes = pastKeyOutputShape.data();
|
||||
pastKeySlicedOperatorDesc.InputWindowStrides = pastKeyStrides.data();
|
||||
|
||||
pastValueOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(pastTensorDataType, pastValueOutputShape);
|
||||
namedPastValueOutputTensorDesc = pastValueOutputTensorDesc.GetDmlDesc();
|
||||
pastValueSlicedOperatorDesc.InputTensor = &inputDescs[pastIndex];
|
||||
pastValueSlicedOperatorDesc.OutputTensor = &namedPastValueOutputTensorDesc;
|
||||
pastValueSlicedOperatorDesc.DimensionCount = gsl::narrow_cast<uint32_t>(pastValueOutputShape.size());
|
||||
pastValueSlicedOperatorDesc.InputWindowOffsets = pastValueOffsets.data();
|
||||
pastValueSlicedOperatorDesc.InputWindowSizes = pastValueOutputShape.data();
|
||||
pastValueSlicedOperatorDesc.InputWindowStrides = pastValueStrides.data();
|
||||
}
|
||||
|
||||
const DML_OPERATOR_DESC pastKeySlicedDesc = { DML_OPERATOR_SLICE1, &pastKeySlicedOperatorDesc};
|
||||
const DML_OPERATOR_DESC pastValueSlicedDesc = { DML_OPERATOR_SLICE1, &pastValueSlicedOperatorDesc};
|
||||
|
||||
// Causal Mask: Upper Triangular Boolean Matrix
|
||||
// Example: [[1, 0, 0, 0, 0],
|
||||
// [1, 1, 0, 0, 0],
|
||||
// [1, 1, 1, 0, 0],
|
||||
// [1, 1, 1, 1, 0]]
|
||||
// DML adds maskFilterValue to the "off" bits in the mask and sets the "on" bits to 0
|
||||
// passed to MHA as maskIndex Tensor when unidirectional == 1
|
||||
std::array<uint32_t, 4> causalMaskOutputShape = {1, 1, sequenceLength, pastSequenceLength + sequenceLength};
|
||||
TensorDesc causalMaskTensorDesc;
|
||||
DML_DIAGONAL_MATRIX1_OPERATOR_DESC causalMaskOperatorDesc = {};
|
||||
DML_TENSOR_DESC namedcausalMaskTensorDesc;
|
||||
|
||||
if (unidirectional && !hasMask)
|
||||
{
|
||||
causalMaskTensorDesc = TensorDesc::ConstructDefaultTensorDesc(MLOperatorTensorDataType::Int32, causalMaskOutputShape);
|
||||
namedcausalMaskTensorDesc = causalMaskTensorDesc.GetDmlDesc();
|
||||
causalMaskOperatorDesc.ValueDataType = DML_TENSOR_DATA_TYPE_INT32;
|
||||
causalMaskOperatorDesc.DiagonalFillBegin = INT32_MIN;
|
||||
causalMaskOperatorDesc.DiagonalFillEnd = pastSequenceLength + 1;
|
||||
causalMaskOperatorDesc.Value.Int32 = 1;
|
||||
causalMaskOperatorDesc.OutputTensor = &namedcausalMaskTensorDesc;
|
||||
maskType = DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN;
|
||||
}
|
||||
DML_OPERATOR_DESC causalMaskDesc = { DML_OPERATOR_DIAGONAL_MATRIX1, &causalMaskOperatorDesc };
|
||||
|
||||
DML_MULTIHEAD_ATTENTION_OPERATOR_DESC mhaOperatorDesc = {};
|
||||
std::array<uint32_t, 5> presentKeyOutputShape = {1, batchSize, numHeads, pastSequenceLength + sequenceLength, headSize};
|
||||
std::array<uint32_t, 5> presentValueOutputShape = {1, batchSize, numHeads, pastSequenceLength + sequenceLength, headSize};
|
||||
TensorDesc presentKeyTensorDesc;
|
||||
TensorDesc presentValueTensorDesc;
|
||||
DML_TENSOR_DESC namedPresentKeyOutputTensorDesc;
|
||||
DML_TENSOR_DESC namedPresentValueOutputTensorDesc;
|
||||
|
||||
mhaOperatorDesc.StackedQueryKeyValueTensor = &namedQueryKeyValueTransposedOutputTensorDesc;
|
||||
|
||||
// Broadcast to MHA MaskTensor Shape
|
||||
std::array<uint32_t, 4> mhaMaskTensorShape = {batchSize, numHeads, sequenceLength, pastSequenceLength + sequenceLength};
|
||||
TensorDesc broadcastedcausalMaskTensorDesc;
|
||||
DML_TENSOR_DESC namedbroadcastedcausalMaskTensorDesc;
|
||||
if (unidirectional && !hasMask)
|
||||
{
|
||||
broadcastedcausalMaskTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(MLOperatorTensorDataType::Int32, mhaMaskTensorShape, causalMaskOutputShape);
|
||||
namedbroadcastedcausalMaskTensorDesc = broadcastedcausalMaskTensorDesc.GetDmlDesc();
|
||||
mhaOperatorDesc.MaskTensor = &namedbroadcastedcausalMaskTensorDesc;
|
||||
}
|
||||
else if (hasMaxSequenceMask)
|
||||
{
|
||||
mhaOperatorDesc.MaskTensor = &namedMaskSliceOutputTensorDesc;
|
||||
}
|
||||
else
|
||||
{
|
||||
mhaOperatorDesc.MaskTensor = hasMask ? &inputDescs[maskIndex] : nullptr;
|
||||
}
|
||||
|
||||
mhaOperatorDesc.RelativePositionBiasTensor = nullptr;
|
||||
mhaOperatorDesc.OutputTensor = &outputDescs[outputIndex];
|
||||
mhaOperatorDesc.Scale = kernelCreationContext.GetOptionalAttribute<float>(AttrName::Scale, gsl::narrow_cast<float>(1.0f / std::sqrt(headSize)));
|
||||
// Set MaskFilterValue to lowest float for Causal Mask
|
||||
mhaOperatorDesc.MaskFilterValue = unidirectional ? std::numeric_limits<float>::lowest() :
|
||||
kernelCreationContext.GetOptionalAttribute<float>(AttrName::MaskFilterValue, -10'000.0f);
|
||||
mhaOperatorDesc.HeadCount = numHeads;
|
||||
mhaOperatorDesc.MaskType = maskType;
|
||||
if (hasPast)
|
||||
{
|
||||
presentKeyTensorDesc = TensorDesc::ConstructDefaultTensorDesc(presentTensorDataType, presentKeyOutputShape);
|
||||
namedPresentKeyOutputTensorDesc = presentKeyTensorDesc.GetDmlDesc();
|
||||
presentValueTensorDesc = TensorDesc::ConstructDefaultTensorDesc(presentTensorDataType, presentValueOutputShape);
|
||||
namedPresentValueOutputTensorDesc = presentValueTensorDesc.GetDmlDesc();
|
||||
mhaOperatorDesc.PastKeyTensor = &namedPastKeyOutputTensorDesc;
|
||||
mhaOperatorDesc.PastValueTensor = &namedPastValueOutputTensorDesc;
|
||||
mhaOperatorDesc.OutputPresentKeyTensor = &namedPresentKeyOutputTensorDesc;
|
||||
mhaOperatorDesc.OutputPresentValueTensor = &namedPresentValueOutputTensorDesc;
|
||||
}
|
||||
|
||||
const DML_OPERATOR_DESC mhaDesc = { DML_OPERATOR_MULTIHEAD_ATTENTION, &mhaOperatorDesc };
|
||||
|
||||
DML_JOIN_OPERATOR_DESC presentKeyValueJoinOperatorDesc = {};
|
||||
std::vector<DML_TENSOR_DESC> joinInputDesc;
|
||||
|
||||
if (hasPast)
|
||||
{
|
||||
joinInputDesc.push_back(namedPresentKeyOutputTensorDesc);
|
||||
joinInputDesc.push_back(namedPresentValueOutputTensorDesc);
|
||||
presentKeyValueJoinOperatorDesc.InputCount = gsl::narrow_cast<uint32_t>(joinInputDesc.size());
|
||||
presentKeyValueJoinOperatorDesc.InputTensors = joinInputDesc.data();
|
||||
presentKeyValueJoinOperatorDesc.OutputTensor = &outputDescs[presentIndex];
|
||||
presentKeyValueJoinOperatorDesc.Axis = gsl::narrow_cast<uint32_t>(0);
|
||||
}
|
||||
|
||||
DML_OPERATOR_DESC presentKeyValueJoinDesc = { DML_OPERATOR_JOIN, &presentKeyValueJoinOperatorDesc };
|
||||
|
||||
// Construct the graph
|
||||
std::vector<DML_INPUT_GRAPH_EDGE_DESC> inputEdges;
|
||||
std::vector<DML_INTERMEDIATE_GRAPH_EDGE_DESC> intermediateEdges;
|
||||
std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> outputEdges;
|
||||
|
||||
std::vector<const DML_OPERATOR_DESC*> opDescs = {
|
||||
&matMulIntToFloatDesc,
|
||||
&mhaDesc,
|
||||
};
|
||||
|
||||
uint32_t currentNodeIndex = 0;
|
||||
const uint32_t matMulIntToFloatNodeIndex = currentNodeIndex++;
|
||||
const uint32_t mhaNodeIndex = currentNodeIndex++;
|
||||
|
||||
uint32_t queryKeyValueTransposedNodeIndex = 0;
|
||||
|
||||
opDescs.push_back(&transposedDesc);
|
||||
queryKeyValueTransposedNodeIndex = currentNodeIndex++;
|
||||
|
||||
uint32_t maskSliceNodeIndex = 0;
|
||||
if (hasMaxSequenceMask)
|
||||
{
|
||||
opDescs.push_back(&maskSlicedDesc);
|
||||
maskSliceNodeIndex = currentNodeIndex++;
|
||||
}
|
||||
|
||||
uint32_t pastKeySliceNodeIndex = 0;
|
||||
uint32_t pastValueSliceNodeIndex = 0;
|
||||
uint32_t concatNodeIndex = 0;
|
||||
if (hasPast)
|
||||
{
|
||||
opDescs.push_back(&pastKeySlicedDesc);
|
||||
pastKeySliceNodeIndex = currentNodeIndex++;
|
||||
opDescs.push_back(&pastValueSlicedDesc);
|
||||
pastValueSliceNodeIndex = currentNodeIndex++;
|
||||
opDescs.push_back(&presentKeyValueJoinDesc);
|
||||
concatNodeIndex = currentNodeIndex++;
|
||||
}
|
||||
|
||||
uint32_t causalMaskNodeIndex = 0;
|
||||
if (unidirectional && !hasMask)
|
||||
{
|
||||
opDescs.push_back(&causalMaskDesc);
|
||||
causalMaskNodeIndex = currentNodeIndex++;
|
||||
}
|
||||
|
||||
DML_INPUT_GRAPH_EDGE_DESC inputToMatMulIntToFloatEdge = {};
|
||||
inputToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::inputIndex;
|
||||
inputToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex;
|
||||
inputToMatMulIntToFloatEdge.ToNodeInputIndex = 0;
|
||||
inputEdges.push_back(inputToMatMulIntToFloatEdge);
|
||||
|
||||
DML_INPUT_GRAPH_EDGE_DESC inputScaleToMatMulIntToFloatEdge = {};
|
||||
inputScaleToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::inputScaleIndex;
|
||||
inputScaleToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex;
|
||||
inputScaleToMatMulIntToFloatEdge.ToNodeInputIndex = 1;
|
||||
inputEdges.push_back(inputScaleToMatMulIntToFloatEdge);
|
||||
|
||||
DML_INPUT_GRAPH_EDGE_DESC inputZeroPointToMatMulIntToFloatEdge = {};
|
||||
inputZeroPointToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::inputZeroPointIndex;
|
||||
inputZeroPointToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex;
|
||||
inputZeroPointToMatMulIntToFloatEdge.ToNodeInputIndex = 2;
|
||||
inputEdges.push_back(inputZeroPointToMatMulIntToFloatEdge);
|
||||
|
||||
DML_INPUT_GRAPH_EDGE_DESC weightToMatMulIntToFloatEdge = {};
|
||||
weightToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::weightsIndex;
|
||||
weightToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex;
|
||||
weightToMatMulIntToFloatEdge.ToNodeInputIndex = 3;
|
||||
inputEdges.push_back(weightToMatMulIntToFloatEdge);
|
||||
|
||||
DML_INPUT_GRAPH_EDGE_DESC weightScaleToMatMulIntToFloatEdge = {};
|
||||
weightScaleToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::weightScaleIndex;
|
||||
weightScaleToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex;
|
||||
weightScaleToMatMulIntToFloatEdge.ToNodeInputIndex = 4;
|
||||
inputEdges.push_back(weightScaleToMatMulIntToFloatEdge);
|
||||
|
||||
DML_INPUT_GRAPH_EDGE_DESC weightZeroPointToMatMulIntToFloatEdge = {};
|
||||
weightZeroPointToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::weightZeroPointIndex;
|
||||
weightZeroPointToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex;
|
||||
weightZeroPointToMatMulIntToFloatEdge.ToNodeInputIndex = 5;
|
||||
inputEdges.push_back(weightZeroPointToMatMulIntToFloatEdge);
|
||||
|
||||
if (hasBias)
|
||||
{
|
||||
DML_INPUT_GRAPH_EDGE_DESC biasToMatMulIntToFloatEdge = {};
|
||||
biasToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::biasIndex;
|
||||
biasToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex;
|
||||
biasToMatMulIntToFloatEdge.ToNodeInputIndex = 6;
|
||||
inputEdges.push_back(biasToMatMulIntToFloatEdge);
|
||||
}
|
||||
|
||||
if (hasMask)
|
||||
{
|
||||
if (hasUnpaddedBounds)
|
||||
{
|
||||
DML_INPUT_GRAPH_EDGE_DESC maskToMhaEdge = {};
|
||||
maskToMhaEdge.GraphInputIndex = InputIndex::maskIndex;
|
||||
maskToMhaEdge.ToNodeIndex = mhaNodeIndex;
|
||||
maskToMhaEdge.ToNodeInputIndex = mhaMaskIndex;
|
||||
inputEdges.push_back(maskToMhaEdge);
|
||||
}
|
||||
else if (hasMaxSequenceMask)
|
||||
{
|
||||
DML_INPUT_GRAPH_EDGE_DESC maskToMaskSliceEdge = {};
|
||||
maskToMaskSliceEdge.GraphInputIndex = InputIndex::maskIndex;
|
||||
maskToMaskSliceEdge.ToNodeIndex = maskSliceNodeIndex;
|
||||
maskToMaskSliceEdge.ToNodeInputIndex = 0;
|
||||
inputEdges.push_back(maskToMaskSliceEdge);
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC maskSliceToMhaEdge = {};
|
||||
maskSliceToMhaEdge.FromNodeIndex = maskSliceNodeIndex;
|
||||
maskSliceToMhaEdge.FromNodeOutputIndex = 0;
|
||||
maskSliceToMhaEdge.ToNodeIndex = mhaNodeIndex;
|
||||
maskSliceToMhaEdge.ToNodeInputIndex = mhaMaskIndex;
|
||||
intermediateEdges.push_back(maskSliceToMhaEdge);
|
||||
}
|
||||
else
|
||||
{
|
||||
DML_INPUT_GRAPH_EDGE_DESC maskToMhaEdge = {};
|
||||
maskToMhaEdge.GraphInputIndex = InputIndex::maskIndex;
|
||||
maskToMhaEdge.ToNodeIndex = mhaNodeIndex;
|
||||
maskToMhaEdge.ToNodeInputIndex = mhaMaskIndex;
|
||||
inputEdges.push_back(maskToMhaEdge);
|
||||
}
|
||||
}
|
||||
else if (unidirectional)
|
||||
{
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC causalMaskToMhaEdge = {};
|
||||
causalMaskToMhaEdge.FromNodeIndex = causalMaskNodeIndex;
|
||||
causalMaskToMhaEdge.FromNodeOutputIndex = 0;
|
||||
causalMaskToMhaEdge.ToNodeIndex = mhaNodeIndex;
|
||||
causalMaskToMhaEdge.ToNodeInputIndex = mhaMaskIndex;
|
||||
intermediateEdges.push_back(causalMaskToMhaEdge);
|
||||
}
|
||||
|
||||
if (hasPast)
|
||||
{
|
||||
DML_INPUT_GRAPH_EDGE_DESC pastToPastKeySliceEdge = {};
|
||||
pastToPastKeySliceEdge.GraphInputIndex = InputIndex::pastIndex;
|
||||
pastToPastKeySliceEdge.ToNodeIndex = pastKeySliceNodeIndex;
|
||||
pastToPastKeySliceEdge.ToNodeInputIndex = 0;
|
||||
inputEdges.push_back(pastToPastKeySliceEdge);
|
||||
|
||||
DML_INPUT_GRAPH_EDGE_DESC pastToPastValueSliceEdge = {};
|
||||
pastToPastValueSliceEdge.GraphInputIndex = InputIndex::pastIndex;
|
||||
pastToPastValueSliceEdge.ToNodeIndex = pastValueSliceNodeIndex;
|
||||
pastToPastValueSliceEdge.ToNodeInputIndex = 0;
|
||||
inputEdges.push_back(pastToPastValueSliceEdge);
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC pastKeyToMhaEdge = {};
|
||||
pastKeyToMhaEdge.FromNodeIndex = pastKeySliceNodeIndex;
|
||||
pastKeyToMhaEdge.FromNodeOutputIndex = 0;
|
||||
pastKeyToMhaEdge.ToNodeIndex = mhaNodeIndex;
|
||||
pastKeyToMhaEdge.ToNodeInputIndex = mhaPastKeyIndex;
|
||||
intermediateEdges.push_back(pastKeyToMhaEdge);
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC pastValueToMhaEdge = {};
|
||||
pastValueToMhaEdge.FromNodeIndex = pastValueSliceNodeIndex;
|
||||
pastValueToMhaEdge.FromNodeOutputIndex = 0;
|
||||
pastValueToMhaEdge.ToNodeIndex = mhaNodeIndex;
|
||||
pastValueToMhaEdge.ToNodeInputIndex = mhaPastValueIndex;
|
||||
intermediateEdges.push_back(pastValueToMhaEdge);
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC presentKeyToConcatEdge = {};
|
||||
presentKeyToConcatEdge.FromNodeIndex = mhaNodeIndex;
|
||||
presentKeyToConcatEdge.FromNodeOutputIndex = mhaPresentKeyIndex;
|
||||
presentKeyToConcatEdge.ToNodeIndex = concatNodeIndex;
|
||||
presentKeyToConcatEdge.ToNodeInputIndex = 0;
|
||||
intermediateEdges.push_back(presentKeyToConcatEdge);
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC presentValueToConcatEdge = {};
|
||||
presentValueToConcatEdge.FromNodeIndex = mhaNodeIndex;
|
||||
presentValueToConcatEdge.FromNodeOutputIndex = mhaPresentValueIndex;
|
||||
presentValueToConcatEdge.ToNodeIndex = concatNodeIndex;
|
||||
presentValueToConcatEdge.ToNodeInputIndex = 1;
|
||||
intermediateEdges.push_back(presentValueToConcatEdge);
|
||||
}
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC matMulIntToFloatToQueryKeyValueTransposeEdge = {};
|
||||
matMulIntToFloatToQueryKeyValueTransposeEdge.FromNodeIndex = matMulIntToFloatNodeIndex;
|
||||
matMulIntToFloatToQueryKeyValueTransposeEdge.FromNodeOutputIndex = 0;
|
||||
matMulIntToFloatToQueryKeyValueTransposeEdge.ToNodeIndex = queryKeyValueTransposedNodeIndex;
|
||||
matMulIntToFloatToQueryKeyValueTransposeEdge.ToNodeInputIndex = 0;
|
||||
intermediateEdges.push_back(matMulIntToFloatToQueryKeyValueTransposeEdge);
|
||||
|
||||
// All we need to do here is transpose the stacked QKV tensor into something DML supports
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC queryKeyValueTransposedToMhaEdge = {};
|
||||
queryKeyValueTransposedToMhaEdge.FromNodeIndex = queryKeyValueTransposedNodeIndex;
|
||||
queryKeyValueTransposedToMhaEdge.FromNodeOutputIndex = 0;
|
||||
queryKeyValueTransposedToMhaEdge.ToNodeIndex = mhaNodeIndex;
|
||||
queryKeyValueTransposedToMhaEdge.ToNodeInputIndex = mhaStackedQueryKeyValueIndex;
|
||||
intermediateEdges.push_back(queryKeyValueTransposedToMhaEdge);
|
||||
|
||||
DML_OUTPUT_GRAPH_EDGE_DESC mhaToOutputEdge = {};
|
||||
mhaToOutputEdge.FromNodeIndex = mhaNodeIndex;
|
||||
mhaToOutputEdge.FromNodeOutputIndex = mhaOutputIndex;
|
||||
mhaToOutputEdge.GraphOutputIndex = OutputIndex::outputIndex;
|
||||
outputEdges.push_back(mhaToOutputEdge);
|
||||
|
||||
if (hasPast)
|
||||
{
|
||||
DML_OUTPUT_GRAPH_EDGE_DESC concatToOutputEdge = {};
|
||||
concatToOutputEdge.FromNodeIndex = concatNodeIndex;
|
||||
concatToOutputEdge.FromNodeOutputIndex = 0;
|
||||
concatToOutputEdge.GraphOutputIndex = OutputIndex::presentIndex;
|
||||
outputEdges.push_back(concatToOutputEdge);
|
||||
}
|
||||
|
||||
MLOperatorGraphDesc operatorGraphDesc = {};
|
||||
operatorGraphDesc.inputEdgeCount = gsl::narrow_cast<uint32_t>(inputEdges.size());
|
||||
operatorGraphDesc.inputEdges = inputEdges.data();
|
||||
operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast<uint32_t>(intermediateEdges.size());
|
||||
operatorGraphDesc.intermediateEdges = intermediateEdges.data();
|
||||
operatorGraphDesc.outputEdgeCount = gsl::narrow_cast<uint32_t>(outputEdges.size());
|
||||
operatorGraphDesc.outputEdges = outputEdges.data();
|
||||
operatorGraphDesc.nodeCount = gsl::narrow_cast<uint32_t>(opDescs.size());
|
||||
operatorGraphDesc.nodes = opDescs.data();
|
||||
|
||||
SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext);
|
||||
}
|
||||
};
|
||||
|
||||
// Support query for QAttention: reports through `isSupported` whether the DML
// kernel can handle the node described by `context`.
//
// The node is rejected when any of the following holds:
//   * `unidirectional` is non-zero while input 5 (mask_index) is present,
//   * `do_rotary` is non-zero,
//   * `past_present_share_buffer` is non-zero.
void CALLBACK QueryQAttention(IMLOperatorSupportQueryContextPrivate* context, /*out*/ bool* isSupported)
{
    // Start pessimistic so the out-parameter is well-defined even if an
    // attribute lookup below fails.
    *isSupported = false;

    MLOperatorAttributes attributes(context);

    // The checks are evaluated left to right with short-circuiting, i.e. in
    // the same order as three separate early returns would be.
    const bool usesUnsupportedFeature =
        // `unidirectional == 1` with Mask Tensor is not supported yet
        (attributes.GetOptionalAttribute<int32_t>(AttrName::Unidirectional, 0) != 0 && context->IsInputValid(5))
        // `do_rotary == 1` is not supported yet
        || attributes.GetOptionalAttribute<int32_t>(AttrName::DoRotary, 0) != 0
        // `past_present_share_buffer == 1` is not supported yet
        || attributes.GetOptionalAttribute<int32_t>(AttrName::PastPresentShareBuffer, 0) != 0;

    *isSupported = !usesUnsupportedFeature;
}
|
||||
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(QAttention, DmlOperatorQAttention);
|
||||
} // namespace Dml
|
||||
|
|
@ -178,4 +178,4 @@ void CALLBACK QueryQLinearSigmoid(IMLOperatorSupportQueryContextPrivate* context
|
|||
}
|
||||
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(QLinearSigmoid, DmlOperatorQLinearSigmoid);
|
||||
} // namespace Dml
|
||||
} // namespace Dml
|
||||
|
|
@ -516,6 +516,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(Resize19);
|
|||
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Shape);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Size);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(QAttention);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Attention);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(MultiHeadAttention);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(NonZero);
|
||||
|
|
@ -537,6 +538,7 @@ DML_OP_EXTERN_QUERY_FUNCTION(Pad);
|
|||
DML_OP_EXTERN_QUERY_FUNCTION(LayerNormalization);
|
||||
DML_OP_EXTERN_QUERY_FUNCTION(SkipLayerNormalization);
|
||||
DML_OP_EXTERN_QUERY_FUNCTION(QLinearSigmoid);
|
||||
DML_OP_EXTERN_QUERY_FUNCTION(QAttention);
|
||||
DML_OP_EXTERN_QUERY_FUNCTION(Attention);
|
||||
|
||||
constexpr static std::array<const char*, 1> typeNameListDefault = {"T"};
|
||||
|
|
@ -614,15 +616,23 @@ constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListLayerN
|
|||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListShape = {SupportedTensorDataTypes::All, SupportedTensorDataTypes::Int64};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListSize = {SupportedTensorDataTypes::All, SupportedTensorDataTypes::Int64};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListQLinearSigmoid = {SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8};
|
||||
|
||||
// Data types accepted by QAttention, one entry per type constraint in
// declaration order (T1, T2, T3, T4).
constexpr static std::array<SupportedTensorDataTypes, 4> supportedTypeListQAttention = {
    SupportedTensorDataTypes::Ints8Bit,     // T1: input / input_zero_point
    SupportedTensorDataTypes::Ints8Bit,     // T2: weight / weight_zero_point
    SupportedTensorDataTypes::Float16to32,  // T3: bias, scales, past, output, present
    SupportedTensorDataTypes::Int32         // T4: mask_index
};
|
||||
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListAttention = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int32};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListRotaryEmbedding = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int64};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListGroupNorm = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Float16to32};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListNonZero = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit | SupportedTensorDataTypes::Ints16Bit | SupportedTensorDataTypes::Ints32Bit | SupportedTensorDataTypes::Bool};
|
||||
|
||||
// Data types accepted by QLinearMatMul, one entry per type constraint.
// Each constraint admits the 8-bit integer types (int8 / uint8).
constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListQLinearMatMul = {
    SupportedTensorDataTypes::Ints8Bit,
    SupportedTensorDataTypes::Ints8Bit,
    SupportedTensorDataTypes::Ints8Bit
};
|
||||
|
||||
constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListMatMulIntegerToFloat = {
|
||||
|
|
@ -632,9 +642,9 @@ constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListMatMul
|
|||
};
|
||||
|
||||
// Data types accepted by QLinearConv, one entry per type constraint.
// The first three constraints admit the 8-bit integer types (int8 / uint8);
// the fourth admits int32.
constexpr static std::array<SupportedTensorDataTypes, 4> supportedTypeListQLinearConv = {
    SupportedTensorDataTypes::Ints8Bit,
    SupportedTensorDataTypes::Ints8Bit,
    SupportedTensorDataTypes::Ints8Bit,
    SupportedTensorDataTypes::Int32
};
|
||||
|
||||
|
|
@ -1069,6 +1079,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
|
|||
{REG_INFO_MS( 1, DynamicQuantizeMatMul, typeNameListTwo, supportedTypeListDynamicQuantizeLinear, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MS( 1, FusedMatMulActivation, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MS( 1, QLinearSigmoid, typeNameListDefault, supportedTypeListQLinearSigmoid, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryQLinearSigmoid)},
|
||||
{REG_INFO_MS( 1, QAttention, typeNameListFour, supportedTypeListQAttention, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryQAttention)},
|
||||
{REG_INFO_MS( 1, Attention, typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryAttention)},
|
||||
{REG_INFO_MS( 1, MultiHeadAttention, typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MS( 1, RotaryEmbedding, typeNameListRotaryEmbedding, supportedTypeListRotaryEmbedding, DmlGraphSupport::Supported)},
|
||||
|
|
|
|||
|
|
@ -2802,6 +2802,48 @@ namespace OperatorHelper
|
|||
m_qkvHiddenSizes = kernelInformation.GetAttributes().GetOptionalAttributeVectorInt32(AttrName::QkvHiddenSizes);
|
||||
}
|
||||
|
||||
std::vector<EdgeShapes> QAttentionHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
|
||||
{
|
||||
ML_CHECK_VALID_ARGUMENT(shapeInfo.GetInputCount() >= 5);
|
||||
|
||||
auto queryShape = shapeInfo.GetInputTensorShape(0);
|
||||
ML_CHECK_VALID_ARGUMENT(queryShape.size() == 3);
|
||||
|
||||
auto weightShape = shapeInfo.GetInputTensorShape(1);
|
||||
ML_CHECK_VALID_ARGUMENT(weightShape.size() == 2);
|
||||
ML_CHECK_VALID_ARGUMENT(weightShape[1] % 3 == 0);
|
||||
|
||||
const uint32_t batchSize = queryShape[0];
|
||||
const uint32_t sequenceLength = queryShape[1];
|
||||
const uint32_t hiddenSize = weightShape[1] / 3;
|
||||
const uint32_t headSize = hiddenSize / m_numHeads;
|
||||
|
||||
std::vector<EdgeShapes> outputShapes(2);
|
||||
|
||||
outputShapes[0] = EdgeShapes({batchSize, sequenceLength, hiddenSize});
|
||||
|
||||
uint32_t totalSequenceLength = sequenceLength;
|
||||
if (shapeInfo.IsInputValid(8))
|
||||
{
|
||||
ML_CHECK_VALID_ARGUMENT(shapeInfo.GetInputTensorDimensionCount(8) == 5);
|
||||
const uint32_t pastSequenceLength = shapeInfo.GetInputTensorShape(8)[3];
|
||||
totalSequenceLength += pastSequenceLength;
|
||||
}
|
||||
|
||||
if (shapeInfo.IsOutputValid(1))
|
||||
{
|
||||
ML_CHECK_VALID_ARGUMENT(shapeInfo.IsInputValid(8));
|
||||
outputShapes[1] = EdgeShapes({2, batchSize, m_numHeads, totalSequenceLength, headSize});
|
||||
}
|
||||
|
||||
return outputShapes;
|
||||
}
|
||||
|
||||
void QAttentionHelper::Initialize(const IKernelInformationAdapter& kernelInformation)
|
||||
{
|
||||
m_numHeads = gsl::narrow_cast<uint32_t>(kernelInformation.GetAttributes().GetAttribute<int64_t>(AttrName::NumHeads));
|
||||
}
|
||||
|
||||
std::vector<EdgeShapes> SkipLayerNormHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
|
||||
{
|
||||
ML_CHECK_VALID_ARGUMENT(shapeInfo.GetInputCount() >= 3);
|
||||
|
|
|
|||
|
|
@ -1554,6 +1554,22 @@ private:
|
|||
std::vector<int32_t> m_qkvHiddenSizes;
|
||||
};
|
||||
|
||||
class QAttentionHelper
|
||||
{
|
||||
public:
|
||||
template <typename Info_t, typename Shape_t>
|
||||
QAttentionHelper(const Info_t& info, const Shape_t& shapeInfo)
|
||||
{
|
||||
Initialize(KernelInformationAdapter(info));
|
||||
}
|
||||
|
||||
std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
|
||||
|
||||
private:
|
||||
void Initialize(const IKernelInformationAdapter& kernelInformation);
|
||||
uint32_t m_numHeads;
|
||||
};
|
||||
|
||||
class SkipLayerNormHelper
|
||||
{
|
||||
public:
|
||||
|
|
@ -1699,6 +1715,7 @@ using ShapeInferenceHelper_Affine = GetOutputShapeAsInputShapeHelper;
|
|||
using ShapeInferenceHelper_QuantizeLinear = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_DequantizeLinear = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_QLinearSigmoid = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_QAttention = QAttentionHelper;
|
||||
using ShapeInferenceHelper_Attention = AttentionHelper;
|
||||
using ShapeInferenceHelper_MultiHeadAttention = MultiHeadAttentionHelper;
|
||||
using ShapeInferenceHelper_RotaryEmbedding = GetOutputShapeAsInputShapeHelper;
|
||||
|
|
|
|||
|
|
@ -448,6 +448,7 @@ namespace OperatorHelper
|
|||
static const int sc_sinceVer_FusedMatMul = 1;
|
||||
static const int sc_sinceVer_FusedMatMulActivation = 1;
|
||||
static const int sc_sinceVer_QLinearSigmoid = 1;
|
||||
static const int sc_sinceVer_QAttention = 1;
|
||||
static const int sc_sinceVer_Attention = 1;
|
||||
static const int sc_sinceVer_MatMulIntegerToFloat = 1;
|
||||
static const int sc_sinceVer_MultiHeadAttention = 1;
|
||||
|
|
|
|||
|
|
@ -20,7 +20,8 @@ namespace test {
|
|||
// Execution provider a QAttention test case is run against.
enum class EP : char {
  CPU,
  CUDA,
  DNNL,
  DML
};
|
||||
|
||||
// input: [batch_size, sequence_length, hidden_size]
|
||||
|
|
@ -111,7 +112,9 @@ void RunQAttention(const std::vector<float>& input_data,
|
|||
execution_providers.push_back(DefaultCudaExecutionProvider());
|
||||
} else if constexpr (ep == EP::CPU) {
|
||||
execution_providers.push_back(DefaultCpuExecutionProvider());
|
||||
} else { // onednn ep
|
||||
} else if constexpr (ep == EP::DML) {
|
||||
execution_providers.push_back(DefaultDmlExecutionProvider());
|
||||
} else { // onednn ep
|
||||
execution_providers.push_back(DefaultDnnlExecutionProvider());
|
||||
}
|
||||
|
||||
|
|
@ -192,6 +195,52 @@ static void RunQAttentionDNNL(
|
|||
#endif
|
||||
}
|
||||
|
||||
static void RunQAttentionDML(
|
||||
const std::vector<float>& input_data,
|
||||
const std::vector<float>& weights_data,
|
||||
const std::vector<float>& bias_data,
|
||||
const std::vector<int32_t>& mask_index_data,
|
||||
const std::vector<float>& output_data,
|
||||
int batch_size,
|
||||
int sequence_length,
|
||||
int hidden_size,
|
||||
int number_of_heads,
|
||||
bool use_special_quantize_parameter = true,
|
||||
bool is_unidirectional = false,
|
||||
int input_hidden_size = 0) {
|
||||
// Return without running code if USE_DML is not defined
|
||||
#ifdef USE_DML
|
||||
bool enable_dml = (nullptr != DefaultDmlExecutionProvider().get());
|
||||
if (enable_dml) {
|
||||
quantization::Params<uint8_t> input_quant_params(/*scale=*/0.0f, /*zero_point=*/0);
|
||||
quantization::Params<int8_t> weights_quant_params(/*scale=*/0.0f, /*zero_point=*/0);
|
||||
if (use_special_quantize_parameter) {
|
||||
input_quant_params.scale = 0.1f;
|
||||
weights_quant_params.scale = 0.1f;
|
||||
input_quant_params.zero_point = 128;
|
||||
weights_quant_params.zero_point = 1;
|
||||
}
|
||||
|
||||
RunQAttention<uint8_t, int8_t, EP::DML>(
|
||||
input_data, weights_data, bias_data, mask_index_data, output_data, input_quant_params, weights_quant_params,
|
||||
batch_size, sequence_length, hidden_size, number_of_heads, is_unidirectional, false, input_hidden_size);
|
||||
}
|
||||
#else
|
||||
ORT_UNUSED_PARAMETER(input_data);
|
||||
ORT_UNUSED_PARAMETER(weights_data);
|
||||
ORT_UNUSED_PARAMETER(bias_data);
|
||||
ORT_UNUSED_PARAMETER(mask_index_data);
|
||||
ORT_UNUSED_PARAMETER(output_data);
|
||||
ORT_UNUSED_PARAMETER(batch_size);
|
||||
ORT_UNUSED_PARAMETER(sequence_length);
|
||||
ORT_UNUSED_PARAMETER(hidden_size);
|
||||
ORT_UNUSED_PARAMETER(number_of_heads);
|
||||
ORT_UNUSED_PARAMETER(use_special_quantize_parameter);
|
||||
ORT_UNUSED_PARAMETER(is_unidirectional);
|
||||
ORT_UNUSED_PARAMETER(input_hidden_size);
|
||||
#endif
|
||||
}
|
||||
|
||||
static void RunQAttentionU8U8(
|
||||
const std::vector<float>& input_data,
|
||||
const std::vector<float>& weights_data,
|
||||
|
|
@ -272,6 +321,9 @@ static void RunQAttentionAll(
|
|||
RunQAttentionDNNL(input_data, weight_data, bias_data, mask_index_data, output_data,
|
||||
batch_size, sequence_length, hidden_size, number_of_heads,
|
||||
use_special_quantize_parameter, is_unidirectional, input_hidden_size);
|
||||
RunQAttentionDML(input_data, weight_data, bias_data, mask_index_data, output_data,
|
||||
batch_size, sequence_length, hidden_size, number_of_heads,
|
||||
use_special_quantize_parameter, is_unidirectional, input_hidden_size);
|
||||
}
|
||||
|
||||
// ONEDNN EP only supports 2D raw mask
|
||||
|
|
@ -859,8 +911,8 @@ void TestQuantizedAttentionPastState(int64_t batch,
|
|||
std::vector<int64_t> input_dims{batch, seq_len, hidden_size};
|
||||
std::vector<InputT> input_data = random.Gaussian<InputT>(input_dims, input_mean, static_cast<InputT>(input_range / 6), input_min, input_max);
|
||||
|
||||
constexpr WeightT weight_min = std::numeric_limits<WeightT>::min();
|
||||
constexpr WeightT weight_max = std::numeric_limits<WeightT>::max();
|
||||
constexpr WeightT weight_min = std::is_same_v<WeightT, int8_t> ? std::numeric_limits<int8_t>::min() / 2 : std::numeric_limits<WeightT>::min();
|
||||
constexpr WeightT weight_max = std::numeric_limits<WeightT>::max() / 2;
|
||||
constexpr int32_t weight_range = weight_max - weight_min;
|
||||
|
||||
std::vector<WeightT> weight_zero_point(weight_scale_zp_size);
|
||||
|
|
|
|||
Loading…
Reference in a new issue