mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-01 03:45:06 +00:00
Enable Past and Present Tensors for QAttention (#17947)
### Description Enabling support for `Past`, `Present` and `unidirectional` for [QAttention](https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QAttention) Contrib Op ``` Note: Google Test filter = *QAttention* [==========] Running 14 tests from 2 test suites. [----------] Global test environment set-up. [----------] 1 test from CPU_U8S8_Precision_Tests [ RUN ] CPU_U8S8_Precision_Tests.QAttention [ OK ] CPU_U8S8_Precision_Tests.QAttention (104 ms) [----------] 1 test from CPU_U8S8_Precision_Tests (105 ms total) [----------] 13 tests from QAttentionTest [ RUN ] QAttentionTest.QAttentionBatch1 [ OK ] QAttentionTest.QAttentionBatch1 (255 ms) [ RUN ] QAttentionTest.QAttentionBatch1_Float16 [ OK ] QAttentionTest.QAttentionBatch1_Float16 (0 ms) [ RUN ] QAttentionTest.QAttentionBatch2 [ OK ] QAttentionTest.QAttentionBatch2 (201 ms) [ RUN ] QAttentionTest.QAttentionMaskPartialSequence [ OK ] QAttentionTest.QAttentionMaskPartialSequence (197 ms) [ RUN ] QAttentionTest.QAttentionMaskExceedSequence [ OK ] QAttentionTest.QAttentionMaskExceedSequence (192 ms) [ RUN ] QAttentionTest.QAttentionNoMaskIndex [ OK ] QAttentionTest.QAttentionNoMaskIndex (186 ms) [ RUN ] QAttentionTest.QAttentionUnidirectional_U8U8 [ OK ] QAttentionTest.QAttentionUnidirectional_U8U8 (9 ms) [ RUN ] QAttentionTest.QAttentionUnidirectional_U8S8 [ OK ] QAttentionTest.QAttentionUnidirectional_U8S8 (9 ms) [ RUN ] QAttentionTest.QAttentionUnidirectional_CUDA [ OK ] QAttentionTest.QAttentionUnidirectional_CUDA (0 ms) [ RUN ] QAttentionTest.QAttentionPastState_u8u8 C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(268): error: The difference between cur_expected[i] and cur_actual[i] is 0.75743968039751053, which exceeds *(params.absolute_error), where cur_expected[i] evaluates to 0.67312467098236084, cur_actual[i] evaluates to -0.084315009415149689, and *(params.absolute_error) evaluates to 0.00019999999494757503. i:0 Google Test trace: C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(484): provider type: DmlExecutionProvider C:\workspace\ORT\onnxruntime\onnxruntime\test/common/random_generator.h(49): ORT test random seed: 2178993560 C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(268): error: The difference between cur_expected[i] and cur_actual[i] is 0.75743968039751053, which exceeds *(params.absolute_error), where cur_expected[i] evaluates to 0.67312467098236084, cur_actual[i] evaluates to -0.084315009415149689, and *(params.absolute_error) evaluates to 0.00019999999494757503. i:0 Google Test trace: C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(484): provider type: DmlExecutionProvider C:\workspace\ORT\onnxruntime\onnxruntime\test/common/random_generator.h(49): ORT test random seed: 2178993560 C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(268): error: The difference between cur_expected[i] and cur_actual[i] is 0.03001787792891264, which exceeds *(params.absolute_error), where cur_expected[i] evaluates to -0.021467097103595734, cur_actual[i] evaluates to 0.008550780825316906, and *(params.absolute_error) evaluates to 0.00019999999494757503. i:0 Google Test trace: C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(484): provider type: DmlExecutionProvider C:\workspace\ORT\onnxruntime\onnxruntime\test/common/random_generator.h(49): ORT test random seed: 2178993560 C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(268): error: The difference between cur_expected[i] and cur_actual[i] is 0.03001787792891264, which exceeds *(params.absolute_error), where cur_expected[i] evaluates to -0.021467097103595734, cur_actual[i] evaluates to 0.008550780825316906, and *(params.absolute_error) evaluates to 0.00019999999494757503. i:0 Google Test trace: C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(484): provider type: DmlExecutionProvider C:\workspace\ORT\onnxruntime\onnxruntime\test/common/random_generator.h(49): ORT test random seed: 2178993560 [ FAILED ] QAttentionTest.QAttentionPastState_u8u8 (2067 ms) [ RUN ] QAttentionTest.QAttentionPastState_u8s8 C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(268): error: The difference between cur_expected[i] and cur_actual[i] is 0.74043640494346619, which exceeds *(params.absolute_error), where cur_expected[i] evaluates to 0.65650326013565063, cur_actual[i] evaluates to -0.083933144807815552, and *(params.absolute_error) evaluates to 0.00019999999494757503. i:0 Google Test trace: C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(484): provider type: DmlExecutionProvider C:\workspace\ORT\onnxruntime\onnxruntime\test/common/random_generator.h(49): ORT test random seed: 2178993560 C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(268): error: The difference between cur_expected[i] and cur_actual[i] is 0.081788420677185059, which exceeds *(params.absolute_error), where cur_expected[i] evaluates to 1.0076344013214111, cur_actual[i] evaluates to 1.0894228219985962, and *(params.absolute_error) evaluates to 0.00019999999494757503. i:965 Google Test trace: C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(484): provider type: DmlExecutionProvider C:\workspace\ORT\onnxruntime\onnxruntime\test/common/random_generator.h(49): ORT test random seed: 2178993560 C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(268): error: The difference between cur_expected[i] and cur_actual[i] is 0.74043640494346619, which exceeds *(params.absolute_error), where cur_expected[i] evaluates to 0.65650326013565063, cur_actual[i] evaluates to -0.083933144807815552, and *(params.absolute_error) evaluates to 0.00019999999494757503. i:0 Google Test trace: C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(484): provider type: DmlExecutionProvider C:\workspace\ORT\onnxruntime\onnxruntime\test/common/random_generator.h(49): ORT test random seed: 2178993560 C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(268): error: The difference between cur_expected[i] and cur_actual[i] is 0.081788420677185059, which exceeds *(params.absolute_error), where cur_expected[i] evaluates to 1.0076344013214111, cur_actual[i] evaluates to 1.0894228219985962, and *(params.absolute_error) evaluates to 0.00019999999494757503. i:965 Google Test trace: C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(484): provider type: DmlExecutionProvider C:\workspace\ORT\onnxruntime\onnxruntime\test/common/random_generator.h(49): ORT test random seed: 2178993560 C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(268): error: The difference between cur_expected[i] and cur_actual[i] is 0.024714200757443905, which exceeds *(params.absolute_error), where cur_expected[i] evaluates to -0.016048312187194824, cur_actual[i] evaluates to 0.0086658885702490807, and *(params.absolute_error) evaluates to 0.00019999999494757503. i:0 Google Test trace: C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(484): provider type: DmlExecutionProvider C:\workspace\ORT\onnxruntime\onnxruntime\test/common/random_generator.h(49): ORT test random seed: 2178993560 C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(268): error: The difference between cur_expected[i] and cur_actual[i] is 0.0092324763536453247, which exceeds *(params.absolute_error), where cur_expected[i] evaluates to 0.24175386130809784, cur_actual[i] evaluates to 0.25098633766174316, and *(params.absolute_error) evaluates to 0.00019999999494757503. i:979 Google Test trace: C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(484): provider type: DmlExecutionProvider C:\workspace\ORT\onnxruntime\onnxruntime\test/common/random_generator.h(49): ORT test random seed: 2178993560 C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(268): error: The difference between cur_expected[i] and cur_actual[i] is 0.024714200757443905, which exceeds *(params.absolute_error), where cur_expected[i] evaluates to -0.016048312187194824, cur_actual[i] evaluates to 0.0086658885702490807, and *(params.absolute_error) evaluates to 0.00019999999494757503. i:0 Google Test trace: C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(484): provider type: DmlExecutionProvider C:\workspace\ORT\onnxruntime\onnxruntime\test/common/random_generator.h(49): ORT test random seed: 2178993560 C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(268): error: The difference between cur_expected[i] and cur_actual[i] is 0.0092324763536453247, which exceeds *(params.absolute_error), where cur_expected[i] evaluates to 0.24175386130809784, cur_actual[i] evaluates to 0.25098633766174316, and *(params.absolute_error) evaluates to 0.00019999999494757503. i:979 Google Test trace: C:\workspace\ORT\onnxruntime\onnxruntime\test\providers\checkers.cc(484): provider type: DmlExecutionProvider C:\workspace\ORT\onnxruntime\onnxruntime\test/common/random_generator.h(49): ORT test random seed: 2178993560 [ FAILED ] QAttentionTest.QAttentionPastState_u8s8 (2079 ms) [ RUN ] QAttentionTest.QAttentionPrunedModel [ OK ] QAttentionTest.QAttentionPrunedModel (206 ms) [ RUN ] QAttentionTest.SharedPrepackedWeights [ OK ] QAttentionTest.SharedPrepackedWeights (79 ms) [----------] 13 tests from QAttentionTest (5492 ms total) [----------] Global test environment tear-down [==========] 14 tests from 2 test suites ran. (5600 ms total) [ PASSED ] 12 tests. [ FAILED ] 2 tests, listed below: [ FAILED ] QAttentionTest.QAttentionPastState_u8u8 [ FAILED ] QAttentionTest.QAttentionPastState_u8s8 2 FAILED TESTS memleakdbg: ----- No memory leaks detected ----- ``` ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
25d3b0b313
commit
fcb48ae260
7 changed files with 348 additions and 166 deletions
|
|
@ -848,7 +848,7 @@ constexpr DML_SCHEMA_FIELD DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_
|
|||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING",
|
||||
static_cast<DML_OPERATOR_TYPE>(DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING),
|
||||
DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
|
||||
13,
|
||||
DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS,
|
||||
|
|
@ -1883,7 +1883,7 @@ constexpr DML_SCHEMA_FIELD DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_
|
|||
|
||||
constexpr DML_OPERATOR_SCHEMA DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA {
|
||||
"DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT",
|
||||
static_cast<DML_OPERATOR_TYPE>(DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT),
|
||||
DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT,
|
||||
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
|
||||
8,
|
||||
DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_FIELDS,
|
||||
|
|
|
|||
|
|
@ -571,6 +571,12 @@ void CALLBACK QueryAttention(IMLOperatorSupportQueryContextPrivate* context, /*o
|
|||
return;
|
||||
}
|
||||
|
||||
// `past_present_share_buffer == 1` is not supported yet
|
||||
if (attributes.GetOptionalAttribute<int32_t>(AttrName::PastPresentShareBuffer, 0) != 0)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
*isSupported = true;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -105,7 +105,7 @@ public:
|
|||
matrixMultiplyIntergerToFloatOperatorDesc.BiasTensor = hasBias? &inputDescs[OnnxInputIndex::Bias] : nullptr;
|
||||
matrixMultiplyIntergerToFloatOperatorDesc.OutputTensor = &outputDescs[0];
|
||||
|
||||
const DML_OPERATOR_DESC opDesc2{ static_cast<DML_OPERATOR_TYPE>(DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT), &matrixMultiplyIntergerToFloatOperatorDesc};
|
||||
const DML_OPERATOR_DESC opDesc2{ DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT, &matrixMultiplyIntergerToFloatOperatorDesc};
|
||||
|
||||
MLOperatorGraphDesc operatorGraphDesc = {};
|
||||
std::vector<const DML_OPERATOR_DESC*> opDescs{&opDesc1, &opDesc2};
|
||||
|
|
|
|||
|
|
@ -8,7 +8,9 @@ Abbreviations: B is batch_size, S is sequence_length, W is hidden_size
|
|||
N is number of attention heads, H is head size, and W=N*H
|
||||
M is mask_index tensor
|
||||
|
||||
M A B C // M, A, B, and C are Inputs
|
||||
M, A, B, C and P are Inputs
|
||||
|
||||
M A B C
|
||||
| | | /
|
||||
| MatMulIntToFloat
|
||||
| / | \
|
||||
|
|
@ -20,10 +22,27 @@ Abbreviations: B is batch_size, S is sequence_length, W is hidden_size
|
|||
| Identity Identity Identity // The identities are used to transpose NCHW -> NHCW while
|
||||
| | | | // keeping the GEMM strides as NCHW to better target metacommands
|
||||
| | | |
|
||||
----------------- MHA -----
|
||||
|
|
||||
|
|
||||
Output // Final output
|
||||
| | | | P
|
||||
| | | | / \
|
||||
| | | | / \
|
||||
| | | | Slice Slice
|
||||
| | | | | |
|
||||
| | | | | |
|
||||
| | | | | |
|
||||
--------------------------MHA -----------
|
||||
/ | \
|
||||
/ | \
|
||||
/ | \
|
||||
/ | \
|
||||
/ | \
|
||||
/ | \
|
||||
/ presentKey presentValue
|
||||
/ \ /
|
||||
/ \ /
|
||||
/ \ /
|
||||
/ Concat
|
||||
/ |
|
||||
Output1 Output2 (present)
|
||||
|
||||
This kernel creates a DML_GRAPH, as mentioned above.
|
||||
For reference, refer to this Doc:
|
||||
|
|
@ -39,22 +58,6 @@ public:
|
|||
: DmlOperator(kernelCreationContext)
|
||||
{
|
||||
|
||||
enum DmlInputIndex : uint32_t
|
||||
{
|
||||
mhaQueryIndex,
|
||||
mhaKeyIndex,
|
||||
mhaValueIndex,
|
||||
mhaStackedQueryKeyIndex,
|
||||
mhaStackedKeyValueIndex,
|
||||
mhaStackedQueryKeyValueIndex,
|
||||
mhaBiasIndex,
|
||||
mhaMaskIndex,
|
||||
mhaRelativePositionBiasIndex,
|
||||
mhaPastKeyIndex,
|
||||
mhaPastValueIndex,
|
||||
mhaInputCount,
|
||||
};
|
||||
|
||||
enum InputIndex : uint32_t
|
||||
{
|
||||
inputIndex,
|
||||
|
|
@ -72,18 +75,45 @@ public:
|
|||
enum OutputIndex : uint32_t
|
||||
{
|
||||
outputIndex,
|
||||
presentIndex,
|
||||
outputCount,
|
||||
};
|
||||
|
||||
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() >= 2);
|
||||
enum MhaInputIndex : uint32_t
|
||||
{
|
||||
mhaQueryIndex,
|
||||
mhaKeyIndex,
|
||||
mhaValueIndex,
|
||||
mhaStackedQueryKeyIndex,
|
||||
mhaStackedKeyValueIndex,
|
||||
mhaStackedQueryKeyValueIndex,
|
||||
mhaBiasIndex,
|
||||
mhaMaskIndex,
|
||||
mhaRelativePositionBiasIndex,
|
||||
mhaPastKeyIndex,
|
||||
mhaPastValueIndex,
|
||||
mhaInputCount,
|
||||
};
|
||||
|
||||
enum MhaOutputIndex : uint32_t
|
||||
{
|
||||
mhaOutputIndex,
|
||||
mhaPresentKeyIndex,
|
||||
mhaPresentValueIndex,
|
||||
mhaOutputCount,
|
||||
};
|
||||
|
||||
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() >= 5);
|
||||
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() >= 1);
|
||||
|
||||
const bool hasBias = kernelCreationContext.IsInputValid(biasIndex);
|
||||
const bool hasMask = kernelCreationContext.IsInputValid(maskIndex);
|
||||
const bool hasUnpaddedBounds = hasMask && kernelCreationContext.GetInputTensorDimensionCount(maskIndex) == 1;
|
||||
const bool hasPast = kernelCreationContext.IsInputValid(pastIndex);
|
||||
|
||||
DmlOperator::Initialize(kernelCreationContext, std::nullopt, std::nullopt, std::nullopt, std::nullopt, 1);
|
||||
|
||||
const bool unidirectional = gsl::narrow_cast<uint32_t>(kernelCreationContext.GetAttribute<int64_t>(AttrName::Unidirectional));
|
||||
const uint32_t numHeads = gsl::narrow_cast<uint32_t>(kernelCreationContext.GetAttribute<int64_t>(AttrName::NumHeads));
|
||||
ML_CHECK_VALID_ARGUMENT(numHeads > 0); // to avoid process crash because of division by zero.
|
||||
|
||||
|
|
@ -93,38 +123,28 @@ public:
|
|||
auto weightTensorShape = m_inputTensorDescs[weightsIndex].GetSizes();
|
||||
ML_CHECK_VALID_ARGUMENT(weightTensorShape.size() == 2);
|
||||
ML_CHECK_VALID_ARGUMENT(weightTensorShape[0] == inputTensorShape[2]);
|
||||
ML_CHECK_VALID_ARGUMENT(weightTensorShape[1] % 3 == 0);
|
||||
|
||||
const auto qkvHiddenSizes = kernelCreationContext.GetOptionalAttributeVectorInt32(AttrName::QkvHiddenSizes);
|
||||
if (hasBias)
|
||||
{
|
||||
auto biasTensorShape = m_inputTensorDescs[biasIndex].GetSizes();
|
||||
ML_CHECK_VALID_ARGUMENT(biasTensorShape.size() == 1);
|
||||
ML_CHECK_VALID_ARGUMENT(biasTensorShape[0] % 3 == 0);
|
||||
ML_CHECK_VALID_ARGUMENT(weightTensorShape[1] == biasTensorShape[0]);
|
||||
|
||||
if (qkvHiddenSizes.empty())
|
||||
{
|
||||
ML_CHECK_VALID_ARGUMENT(biasTensorShape[0] % 3 == 0);
|
||||
}
|
||||
}
|
||||
|
||||
if (!qkvHiddenSizes.empty())
|
||||
if (hasPast)
|
||||
{
|
||||
ML_CHECK_VALID_ARGUMENT(qkvHiddenSizes.size() == 3);
|
||||
ML_CHECK_VALID_ARGUMENT(qkvHiddenSizes[0] == qkvHiddenSizes[1]);
|
||||
}
|
||||
else
|
||||
{
|
||||
ML_CHECK_VALID_ARGUMENT(weightTensorShape[1] % 3 == 0);
|
||||
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.IsOutputValid(presentIndex));
|
||||
}
|
||||
|
||||
const uint32_t hiddenSize = qkvHiddenSizes.empty() ? weightTensorShape[1] / 3 : qkvHiddenSizes[0];
|
||||
const uint32_t vHiddenSize = qkvHiddenSizes.empty() ? weightTensorShape[1] / 3 : qkvHiddenSizes[2];
|
||||
const uint32_t hiddenSize = weightTensorShape[1] / 3;
|
||||
const uint32_t headSize = hiddenSize / numHeads;
|
||||
const uint32_t vHeadSize = vHiddenSize / numHeads;
|
||||
const uint32_t batchSize = inputTensorShape[0];
|
||||
const uint32_t sequenceLength = inputTensorShape[1];
|
||||
const uint32_t pastSequenceLength = hasPast ? m_inputTensorDescs[pastIndex].GetSizes()[3] : 0;
|
||||
|
||||
uint32_t desiredWeightTensorShape[3] = {batchSize, weightTensorShape[0], hiddenSize + hiddenSize + vHiddenSize};
|
||||
uint32_t desiredWeightTensorShape[3] = {batchSize, weightTensorShape[0], 3 * hiddenSize};
|
||||
MLOperatorTensorDataType dataType = kernelCreationContext.GetOutputEdgeDescription(outputIndex).tensorDataType;
|
||||
|
||||
m_inputTensorDescs[weightsIndex] = TensorDesc::ConstructBroadcastedTensorDesc(
|
||||
|
|
@ -132,7 +152,7 @@ public:
|
|||
desiredWeightTensorShape,
|
||||
weightTensorShape);
|
||||
|
||||
uint32_t desiredBiasTensorShape[3] = {batchSize, sequenceLength, hiddenSize + hiddenSize + vHiddenSize};
|
||||
uint32_t desiredBiasTensorShape[3] = {batchSize, sequenceLength, 3 * hiddenSize};
|
||||
|
||||
if (hasBias)
|
||||
{
|
||||
|
|
@ -189,6 +209,14 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
MLOperatorTensorDataType pastTensorDataType = MLOperatorTensorDataType::Undefined;
|
||||
MLOperatorTensorDataType presentTensorDataType = MLOperatorTensorDataType::Undefined;
|
||||
if (hasPast)
|
||||
{
|
||||
pastTensorDataType = kernelCreationContext.GetInputEdgeDescription(pastIndex).tensorDataType;
|
||||
presentTensorDataType = kernelCreationContext.GetOutputEdgeDescription(presentIndex).tensorDataType;
|
||||
}
|
||||
|
||||
TensorDesc matMulIntToFloatOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, desiredBiasTensorShape);
|
||||
DML_TENSOR_DESC namedMatMulIntToFloatOutputTensorDesc = matMulIntToFloatOutputTensorDesc.GetDmlDesc();
|
||||
|
||||
|
|
@ -205,13 +233,13 @@ public:
|
|||
matMulIntToFloatOperatorDesc.BiasTensor = hasBias ? &inputDescs[InputIndex::biasIndex] : nullptr;
|
||||
matMulIntToFloatOperatorDesc.OutputTensor = &namedMatMulIntToFloatOutputTensorDesc;
|
||||
|
||||
const DML_OPERATOR_DESC matMulIntToFloatDesc = { static_cast<DML_OPERATOR_TYPE>(DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT), &matMulIntToFloatOperatorDesc};
|
||||
const DML_OPERATOR_DESC matMulIntToFloatDesc = { DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT, &matMulIntToFloatOperatorDesc};
|
||||
|
||||
std::array<uint32_t, 3> queryKeySlicedTensorShape = {batchSize, sequenceLength, hiddenSize + hiddenSize};
|
||||
TensorDesc queryKeySlicedInputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, queryKeySlicedTensorShape);
|
||||
DML_TENSOR_DESC namedQueryKeySlicedInputTensorDesc = queryKeySlicedInputTensorDesc.GetDmlDesc();
|
||||
|
||||
std::array<uint32_t, 3> valueSlicedTensorShape = {batchSize, sequenceLength, vHiddenSize};
|
||||
std::array<uint32_t, 3> valueSlicedTensorShape = {batchSize, sequenceLength, hiddenSize};
|
||||
TensorDesc valueSlicedInputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, valueSlicedTensorShape);
|
||||
DML_TENSOR_DESC namedValueSlicedInputTensorDesc = valueSlicedInputTensorDesc.GetDmlDesc();
|
||||
|
||||
|
|
@ -262,44 +290,18 @@ public:
|
|||
std::array<int32_t, 3> queryKeySliceStrides = {1, 1, 1};
|
||||
|
||||
std::array<uint32_t, 3> valueSliceOffset = {0, 0, 2 * hiddenSize};
|
||||
std::array<uint32_t, 3> valueSliceSize = {batchSize, sequenceLength, vHiddenSize};
|
||||
std::array<uint32_t, 3> valueSliceSize = {batchSize, sequenceLength, hiddenSize};
|
||||
std::array<int32_t, 3> valueSliceStrides = {1, 1, 1};
|
||||
const bool hasSlicedValue = hiddenSize != vHiddenSize;
|
||||
|
||||
// We need to slice the value tensor when its hidden size is different from the query and key
|
||||
DML_SLICE1_OPERATOR_DESC queryKeySlicedOperatorDesc = {};
|
||||
DML_SLICE1_OPERATOR_DESC valueSlicedOperatorDesc = {};
|
||||
// When Q/K/V all have the same hidden size, we just have to transpose it before sending it to MHA
|
||||
DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC transposeOperatorDesc = {};
|
||||
if (hasSlicedValue)
|
||||
{
|
||||
queryKeySlicedOperatorDesc.InputTensor = &namedMatMulIntToFloatOutputTensorDesc;
|
||||
queryKeySlicedOperatorDesc.OutputTensor = &namedQueryKeySlicedInputTensorDesc;
|
||||
queryKeySlicedOperatorDesc.DimensionCount = gsl::narrow_cast<uint32_t>(queryKeySlicedTensorShape.size());
|
||||
queryKeySlicedOperatorDesc.InputWindowOffsets = queryKeySliceOffset.data();
|
||||
queryKeySlicedOperatorDesc.InputWindowSizes = queryKeySliceSize.data();
|
||||
queryKeySlicedOperatorDesc.InputWindowStrides = queryKeySliceStrides.data();
|
||||
|
||||
valueSlicedOperatorDesc.InputTensor = &namedMatMulIntToFloatOutputTensorDesc;
|
||||
valueSlicedOperatorDesc.OutputTensor = &namedValueSlicedInputTensorDesc;
|
||||
valueSlicedOperatorDesc.DimensionCount = gsl::narrow_cast<uint32_t>(valueSlicedTensorShape.size());
|
||||
valueSlicedOperatorDesc.InputWindowOffsets = valueSliceOffset.data();
|
||||
valueSlicedOperatorDesc.InputWindowSizes = valueSliceSize.data();
|
||||
valueSlicedOperatorDesc.InputWindowStrides = valueSliceStrides.data();
|
||||
transposeOperatorDesc.InputTensor = &namedQueryKeyValueTransposedInputTensorDesc;
|
||||
transposeOperatorDesc.OutputTensor = &namedQueryKeyValueTransposedOutputTensorDesc;
|
||||
|
||||
transposeOperatorDesc.InputTensor = &namedQueryKeyTransposedInputTensorDesc;
|
||||
transposeOperatorDesc.OutputTensor = &namedQueryKeyTransposedOutputTensorDesc;
|
||||
}
|
||||
else
|
||||
{
|
||||
// When Q/K/V all have the same hidden size, we just have to transpose it before sending it to MHA
|
||||
transposeOperatorDesc.InputTensor = &namedQueryKeyValueTransposedInputTensorDesc;
|
||||
transposeOperatorDesc.OutputTensor = &namedQueryKeyValueTransposedOutputTensorDesc;
|
||||
}
|
||||
const DML_OPERATOR_DESC queryKeySlicedDesc = { DML_OPERATOR_SLICE1, &queryKeySlicedOperatorDesc};
|
||||
const DML_OPERATOR_DESC valueSlicedDesc = { DML_OPERATOR_SLICE1, &valueSlicedOperatorDesc};
|
||||
const DML_OPERATOR_DESC transposedDesc = { DML_OPERATOR_ELEMENT_WISE_IDENTITY, &transposeOperatorDesc};
|
||||
|
||||
std::array<uint32_t, 4> maskSliceOutputShape {batchSize, numHeads, sequenceLength, sequenceLength};
|
||||
std::array<uint32_t, 4> maskSliceOutputShape = {batchSize, numHeads, sequenceLength, sequenceLength};
|
||||
std::array<int32_t, 4> maskSliceStrides = {1, 1, 1, 1};
|
||||
std::array<uint32_t, 4> maskSliceOffsets = {0, 0, 0, 0};
|
||||
TensorDesc maskSliceOutputTensorDesc;
|
||||
|
|
@ -319,12 +321,81 @@ public:
|
|||
}
|
||||
const DML_OPERATOR_DESC maskSlicedDesc = { DML_OPERATOR_SLICE1, &maskSlicedOperatorDesc};
|
||||
|
||||
DML_MULTIHEAD_ATTENTION_OPERATOR_DESC mhaOperatorDesc = {};
|
||||
mhaOperatorDesc.ValueTensor = hasSlicedValue ? &namedValueSlicedInputTensorDesc : nullptr;
|
||||
mhaOperatorDesc.StackedQueryKeyTensor = hasSlicedValue ? &namedQueryKeyTransposedOutputTensorDesc : nullptr;
|
||||
mhaOperatorDesc.StackedQueryKeyValueTensor = hasSlicedValue ? nullptr : &namedQueryKeyValueTransposedOutputTensorDesc;
|
||||
// We need to slice Past to get PastValue and PastKey tensors for MHA
|
||||
std::array<uint32_t, 5> pastKeyOutputShape = {1, batchSize, numHeads, pastSequenceLength, headSize};
|
||||
std::array<int32_t, 5> pastKeyStrides = {1, 1, 1, 1, 1};
|
||||
std::array<uint32_t, 5> pastKeyOffsets = {0, 0, 0, 0, 0};
|
||||
TensorDesc pastKeyOutputTensorDesc;
|
||||
DML_TENSOR_DESC namedPastKeyOutputTensorDesc;
|
||||
|
||||
std::array<uint32_t, 5> pastValueOutputShape = {1, batchSize, numHeads, pastSequenceLength, headSize};
|
||||
std::array<int32_t, 5> pastValueStrides = {1, 1, 1, 1, 1};
|
||||
std::array<uint32_t, 5> pastValueOffsets = {1, 0, 0, 0, 0};
|
||||
TensorDesc pastValueOutputTensorDesc;
|
||||
DML_TENSOR_DESC namedPastValueOutputTensorDesc;
|
||||
|
||||
if (hasMaxSequenceMask)
|
||||
DML_SLICE1_OPERATOR_DESC pastKeySlicedOperatorDesc = {};
|
||||
DML_SLICE1_OPERATOR_DESC pastValueSlicedOperatorDesc = {};
|
||||
|
||||
if (hasPast)
|
||||
{
|
||||
pastKeyOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(pastTensorDataType, pastKeyOutputShape);
|
||||
namedPastKeyOutputTensorDesc = pastKeyOutputTensorDesc.GetDmlDesc();
|
||||
pastKeySlicedOperatorDesc.InputTensor = &inputDescs[pastIndex];
|
||||
pastKeySlicedOperatorDesc.OutputTensor = &namedPastKeyOutputTensorDesc;
|
||||
pastKeySlicedOperatorDesc.DimensionCount = gsl::narrow_cast<uint32_t>(pastKeyOutputShape.size());
|
||||
pastKeySlicedOperatorDesc.InputWindowOffsets = pastKeyOffsets.data();
|
||||
pastKeySlicedOperatorDesc.InputWindowSizes = pastKeyOutputShape.data();
|
||||
pastKeySlicedOperatorDesc.InputWindowStrides = pastKeyStrides.data();
|
||||
|
||||
pastValueOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(pastTensorDataType, pastValueOutputShape);
|
||||
namedPastValueOutputTensorDesc = pastValueOutputTensorDesc.GetDmlDesc();
|
||||
pastValueSlicedOperatorDesc.InputTensor = &inputDescs[pastIndex];
|
||||
pastValueSlicedOperatorDesc.OutputTensor = &namedPastValueOutputTensorDesc;
|
||||
pastValueSlicedOperatorDesc.DimensionCount = gsl::narrow_cast<uint32_t>(pastValueOutputShape.size());
|
||||
pastValueSlicedOperatorDesc.InputWindowOffsets = pastValueOffsets.data();
|
||||
pastValueSlicedOperatorDesc.InputWindowSizes = pastValueOutputShape.data();
|
||||
pastValueSlicedOperatorDesc.InputWindowStrides = pastValueStrides.data();
|
||||
}
|
||||
|
||||
const DML_OPERATOR_DESC pastKeySlicedDesc = { DML_OPERATOR_SLICE1, &pastKeySlicedOperatorDesc};
|
||||
const DML_OPERATOR_DESC pastValueSlicedDesc = { DML_OPERATOR_SLICE1, &pastValueSlicedOperatorDesc};
|
||||
|
||||
// Causal Mask: [pastSequenceLength, pastSequenceLength + 1 ... pastSequenceLength + batchSize -1]
|
||||
// passed to MHA as maskIndex Tensor when unidirectional == 1
|
||||
std::array<uint32_t, 2> causalMaskOutputShape = {1, batchSize};
|
||||
TensorDesc causalMaskTensorDesc;
|
||||
DML_FILL_VALUE_SEQUENCE_OPERATOR_DESC causalMaskOperatorDesc = {};
|
||||
DML_TENSOR_DESC namedcausalMaskTensorDesc;
|
||||
|
||||
if (unidirectional && !hasMask)
|
||||
{
|
||||
causalMaskTensorDesc = TensorDesc::ConstructDefaultTensorDesc(MLOperatorTensorDataType::Int32, causalMaskOutputShape);
|
||||
namedcausalMaskTensorDesc = causalMaskTensorDesc.GetDmlDesc();
|
||||
causalMaskOperatorDesc.ValueDataType = DML_TENSOR_DATA_TYPE_INT32;
|
||||
causalMaskOperatorDesc.ValueStart.Int32 = pastSequenceLength;
|
||||
causalMaskOperatorDesc.ValueDelta.Int32 = 1;
|
||||
causalMaskOperatorDesc.OutputTensor = &namedcausalMaskTensorDesc;
|
||||
|
||||
maskType = DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH;
|
||||
}
|
||||
DML_OPERATOR_DESC causalMaskDesc = { DML_OPERATOR_FILL_VALUE_SEQUENCE, &causalMaskOperatorDesc };
|
||||
|
||||
DML_MULTIHEAD_ATTENTION_OPERATOR_DESC mhaOperatorDesc = {};
|
||||
std::array<uint32_t, 5> presentKeyOutputShape = {1, batchSize, numHeads, pastSequenceLength + sequenceLength, headSize};
|
||||
std::array<uint32_t, 5> presentValueOutputShape = {1, batchSize, numHeads, pastSequenceLength + sequenceLength, headSize};
|
||||
TensorDesc presentKeyTensorDesc;
|
||||
TensorDesc presentValueTensorDesc;
|
||||
DML_TENSOR_DESC namedPresentKeyOutputTensorDesc;
|
||||
DML_TENSOR_DESC namedPresentValueOutputTensorDesc;
|
||||
|
||||
mhaOperatorDesc.StackedQueryKeyValueTensor = &namedQueryKeyValueTransposedOutputTensorDesc;
|
||||
|
||||
if (unidirectional && !hasMask)
|
||||
{
|
||||
mhaOperatorDesc.MaskTensor = &namedcausalMaskTensorDesc;
|
||||
}
|
||||
else if (hasMaxSequenceMask)
|
||||
{
|
||||
mhaOperatorDesc.MaskTensor = &namedMaskSliceOutputTensorDesc;
|
||||
}
|
||||
|
|
@ -339,8 +410,35 @@ public:
|
|||
mhaOperatorDesc.MaskFilterValue = kernelCreationContext.GetOptionalAttribute<float>(AttrName::MaskFilterValue, -10'000.0f);
|
||||
mhaOperatorDesc.HeadCount = numHeads;
|
||||
mhaOperatorDesc.MaskType = maskType;
|
||||
if (hasPast)
|
||||
{
|
||||
presentKeyTensorDesc = TensorDesc::ConstructDefaultTensorDesc(presentTensorDataType, presentKeyOutputShape);
|
||||
namedPresentKeyOutputTensorDesc = presentKeyTensorDesc.GetDmlDesc();
|
||||
presentValueTensorDesc = TensorDesc::ConstructDefaultTensorDesc(presentTensorDataType, presentValueOutputShape);
|
||||
namedPresentValueOutputTensorDesc = presentValueTensorDesc.GetDmlDesc();
|
||||
mhaOperatorDesc.PastKeyTensor = &namedPastKeyOutputTensorDesc;
|
||||
mhaOperatorDesc.PastValueTensor = &namedPastValueOutputTensorDesc;
|
||||
mhaOperatorDesc.OutputPresentKeyTensor = &namedPresentKeyOutputTensorDesc;
|
||||
mhaOperatorDesc.OutputPresentValueTensor = &namedPresentValueOutputTensorDesc;
|
||||
}
|
||||
|
||||
const DML_OPERATOR_DESC mhaDesc = { DML_OPERATOR_MULTIHEAD_ATTENTION, &mhaOperatorDesc };
|
||||
|
||||
DML_JOIN_OPERATOR_DESC presentKeyValueJoinOperatorDesc = {};
|
||||
std::vector<DML_TENSOR_DESC> joinInputDesc;
|
||||
|
||||
if (hasPast)
|
||||
{
|
||||
joinInputDesc.push_back(namedPresentKeyOutputTensorDesc);
|
||||
joinInputDesc.push_back(namedPresentValueOutputTensorDesc);
|
||||
presentKeyValueJoinOperatorDesc.InputCount = gsl::narrow_cast<uint32_t>(joinInputDesc.size());
|
||||
presentKeyValueJoinOperatorDesc.InputTensors = joinInputDesc.data();
|
||||
presentKeyValueJoinOperatorDesc.OutputTensor = &outputDescs[presentIndex];
|
||||
presentKeyValueJoinOperatorDesc.Axis = gsl::narrow_cast<uint32_t>(0);
|
||||
}
|
||||
|
||||
DML_OPERATOR_DESC presentKeyValueJoinDesc = { DML_OPERATOR_JOIN, &presentKeyValueJoinOperatorDesc };
|
||||
|
||||
// Construct the graph
|
||||
std::vector<DML_INPUT_GRAPH_EDGE_DESC> inputEdges;
|
||||
std::vector<DML_INTERMEDIATE_GRAPH_EDGE_DESC> intermediateEdges;
|
||||
|
|
@ -355,26 +453,10 @@ public:
|
|||
const uint32_t matMulIntToFloatNodeIndex = currentNodeIndex++;
|
||||
const uint32_t mhaNodeIndex = currentNodeIndex++;
|
||||
|
||||
uint32_t valueSliceNodeIndex = 0;
|
||||
uint32_t queryKeySliceNodeIndex = 0;
|
||||
uint32_t queryKeyTransposedNodeIndex = 0;
|
||||
uint32_t queryKeyValueTransposedNodeIndex = 0;
|
||||
if (hasSlicedValue)
|
||||
{
|
||||
opDescs.push_back(&queryKeySlicedDesc);
|
||||
queryKeySliceNodeIndex = currentNodeIndex++;
|
||||
|
||||
opDescs.push_back(&valueSlicedDesc);
|
||||
valueSliceNodeIndex = currentNodeIndex++;
|
||||
|
||||
opDescs.push_back(&transposedDesc);
|
||||
queryKeyTransposedNodeIndex = currentNodeIndex++;
|
||||
}
|
||||
else
|
||||
{
|
||||
opDescs.push_back(&transposedDesc);
|
||||
queryKeyValueTransposedNodeIndex = currentNodeIndex++;
|
||||
}
|
||||
opDescs.push_back(&transposedDesc);
|
||||
queryKeyValueTransposedNodeIndex = currentNodeIndex++;
|
||||
|
||||
uint32_t maskSliceNodeIndex = 0;
|
||||
if (hasMaxSequenceMask)
|
||||
|
|
@ -383,6 +465,26 @@ public:
|
|||
maskSliceNodeIndex = currentNodeIndex++;
|
||||
}
|
||||
|
||||
uint32_t pastKeySliceNodeIndex = 0;
|
||||
uint32_t pastValueSliceNodeIndex = 0;
|
||||
uint32_t concatNodeIndex = 0;
|
||||
if (hasPast)
|
||||
{
|
||||
opDescs.push_back(&pastKeySlicedDesc);
|
||||
pastKeySliceNodeIndex = currentNodeIndex++;
|
||||
opDescs.push_back(&pastValueSlicedDesc);
|
||||
pastValueSliceNodeIndex = currentNodeIndex++;
|
||||
opDescs.push_back(&presentKeyValueJoinDesc);
|
||||
concatNodeIndex = currentNodeIndex++;
|
||||
}
|
||||
|
||||
uint32_t causalMaskNodeIndex = 0;
|
||||
if (unidirectional && !hasMask)
|
||||
{
|
||||
opDescs.push_back(&causalMaskDesc);
|
||||
causalMaskNodeIndex = currentNodeIndex++;
|
||||
}
|
||||
|
||||
DML_INPUT_GRAPH_EDGE_DESC inputToMatMulIntToFloatEdge = {};
|
||||
inputToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::inputIndex;
|
||||
inputToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex;
|
||||
|
|
@ -462,69 +564,89 @@ public:
|
|||
inputEdges.push_back(maskToMhaEdge);
|
||||
}
|
||||
}
|
||||
|
||||
if (hasSlicedValue)
|
||||
else if (unidirectional)
|
||||
{
|
||||
// We need to slice QK and V, and transpose QK
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC matMulIntToFloatToQueryKeySliceEdge = {};
|
||||
matMulIntToFloatToQueryKeySliceEdge.FromNodeIndex = matMulIntToFloatNodeIndex;
|
||||
matMulIntToFloatToQueryKeySliceEdge.FromNodeOutputIndex = 0;
|
||||
matMulIntToFloatToQueryKeySliceEdge.ToNodeIndex = queryKeySliceNodeIndex;
|
||||
matMulIntToFloatToQueryKeySliceEdge.ToNodeInputIndex = 0;
|
||||
intermediateEdges.push_back(matMulIntToFloatToQueryKeySliceEdge);
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC queryKeySliceToTransposeEdge = {};
|
||||
queryKeySliceToTransposeEdge.FromNodeIndex = queryKeySliceNodeIndex;
|
||||
queryKeySliceToTransposeEdge.FromNodeOutputIndex = 0;
|
||||
queryKeySliceToTransposeEdge.ToNodeIndex = queryKeyTransposedNodeIndex;
|
||||
queryKeySliceToTransposeEdge.ToNodeInputIndex = 0;
|
||||
intermediateEdges.push_back(queryKeySliceToTransposeEdge);
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC queryKeyTransposedToMhaEdge = {};
|
||||
queryKeyTransposedToMhaEdge.FromNodeIndex = queryKeyTransposedNodeIndex;
|
||||
queryKeyTransposedToMhaEdge.FromNodeOutputIndex = 0;
|
||||
queryKeyTransposedToMhaEdge.ToNodeIndex = mhaNodeIndex;
|
||||
queryKeyTransposedToMhaEdge.ToNodeInputIndex = mhaStackedQueryKeyIndex;
|
||||
intermediateEdges.push_back(queryKeyTransposedToMhaEdge);
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC matMulIntToFloatToValueSliceEdge = {};
|
||||
matMulIntToFloatToValueSliceEdge.FromNodeIndex = matMulIntToFloatNodeIndex;
|
||||
matMulIntToFloatToValueSliceEdge.FromNodeOutputIndex = 0;
|
||||
matMulIntToFloatToValueSliceEdge.ToNodeIndex = valueSliceNodeIndex;
|
||||
matMulIntToFloatToValueSliceEdge.ToNodeInputIndex = 0;
|
||||
intermediateEdges.push_back(matMulIntToFloatToValueSliceEdge);
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC valueSliceToMhaEdge = {};
|
||||
valueSliceToMhaEdge.FromNodeIndex = valueSliceNodeIndex;
|
||||
valueSliceToMhaEdge.FromNodeOutputIndex = 0;
|
||||
valueSliceToMhaEdge.ToNodeIndex = mhaNodeIndex;
|
||||
valueSliceToMhaEdge.ToNodeInputIndex = mhaValueIndex;
|
||||
intermediateEdges.push_back(valueSliceToMhaEdge);
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC causalMaskToMhaEdge = {};
|
||||
causalMaskToMhaEdge.FromNodeIndex = causalMaskNodeIndex;
|
||||
causalMaskToMhaEdge.FromNodeOutputIndex = 0;
|
||||
causalMaskToMhaEdge.ToNodeIndex = mhaNodeIndex;
|
||||
causalMaskToMhaEdge.ToNodeInputIndex = mhaMaskIndex ;
|
||||
intermediateEdges.push_back(causalMaskToMhaEdge);
|
||||
}
|
||||
else
|
||||
|
||||
if (hasPast)
|
||||
{
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC matMulIntToFloatToQueryKeyValueTransposeEdge = {};
|
||||
matMulIntToFloatToQueryKeyValueTransposeEdge.FromNodeIndex = matMulIntToFloatNodeIndex;
|
||||
matMulIntToFloatToQueryKeyValueTransposeEdge.FromNodeOutputIndex = 0;
|
||||
matMulIntToFloatToQueryKeyValueTransposeEdge.ToNodeIndex = queryKeyValueTransposedNodeIndex;
|
||||
matMulIntToFloatToQueryKeyValueTransposeEdge.ToNodeInputIndex = 0;
|
||||
intermediateEdges.push_back(matMulIntToFloatToQueryKeyValueTransposeEdge);
|
||||
DML_INPUT_GRAPH_EDGE_DESC pastToPastKeySliceEdge = {};
|
||||
pastToPastKeySliceEdge.GraphInputIndex = InputIndex::pastIndex;
|
||||
pastToPastKeySliceEdge.ToNodeIndex = pastKeySliceNodeIndex;
|
||||
pastToPastKeySliceEdge.ToNodeInputIndex = 0;
|
||||
inputEdges.push_back(pastToPastKeySliceEdge);
|
||||
|
||||
// All we need to do here is transpose the stacked QKV tensor into something DML supports
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC queryKeyValueTransposedToMhaEdge = {};
|
||||
queryKeyValueTransposedToMhaEdge.FromNodeIndex = queryKeyValueTransposedNodeIndex;
|
||||
queryKeyValueTransposedToMhaEdge.FromNodeOutputIndex = 0;
|
||||
queryKeyValueTransposedToMhaEdge.ToNodeIndex = mhaNodeIndex;
|
||||
queryKeyValueTransposedToMhaEdge.ToNodeInputIndex = mhaStackedQueryKeyValueIndex;
|
||||
intermediateEdges.push_back(queryKeyValueTransposedToMhaEdge);
|
||||
DML_INPUT_GRAPH_EDGE_DESC pastToPastValueSliceEdge = {};
|
||||
pastToPastValueSliceEdge.GraphInputIndex = InputIndex::pastIndex;
|
||||
pastToPastValueSliceEdge.ToNodeIndex = pastValueSliceNodeIndex;
|
||||
pastToPastValueSliceEdge.ToNodeInputIndex = 0;
|
||||
inputEdges.push_back(pastToPastValueSliceEdge);
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC pastKeyToMhaEdge = {};
|
||||
pastKeyToMhaEdge.FromNodeIndex = pastKeySliceNodeIndex;
|
||||
pastKeyToMhaEdge.FromNodeOutputIndex = 0;
|
||||
pastKeyToMhaEdge.ToNodeIndex = mhaNodeIndex;
|
||||
pastKeyToMhaEdge.ToNodeInputIndex = mhaPastKeyIndex;
|
||||
intermediateEdges.push_back(pastKeyToMhaEdge);
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC pastValueToMhaEdge = {};
|
||||
pastValueToMhaEdge.FromNodeIndex = pastValueSliceNodeIndex;
|
||||
pastValueToMhaEdge.FromNodeOutputIndex = 0;
|
||||
pastValueToMhaEdge.ToNodeIndex = mhaNodeIndex;
|
||||
pastValueToMhaEdge.ToNodeInputIndex = mhaPastValueIndex;
|
||||
intermediateEdges.push_back(pastValueToMhaEdge);
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC presentKeyToConcatEdge = {};
|
||||
presentKeyToConcatEdge.FromNodeIndex = mhaNodeIndex;
|
||||
presentKeyToConcatEdge.FromNodeOutputIndex = mhaPresentKeyIndex;
|
||||
presentKeyToConcatEdge.ToNodeIndex = concatNodeIndex;
|
||||
presentKeyToConcatEdge.ToNodeInputIndex = 0;
|
||||
intermediateEdges.push_back(presentKeyToConcatEdge);
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC presentValueToConcatEdge = {};
|
||||
presentValueToConcatEdge.FromNodeIndex = mhaNodeIndex;
|
||||
presentValueToConcatEdge.FromNodeOutputIndex = mhaPresentValueIndex;
|
||||
presentValueToConcatEdge.ToNodeIndex = concatNodeIndex;
|
||||
presentValueToConcatEdge.ToNodeInputIndex = 1;
|
||||
intermediateEdges.push_back(presentValueToConcatEdge);
|
||||
}
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC matMulIntToFloatToQueryKeyValueTransposeEdge = {};
|
||||
matMulIntToFloatToQueryKeyValueTransposeEdge.FromNodeIndex = matMulIntToFloatNodeIndex;
|
||||
matMulIntToFloatToQueryKeyValueTransposeEdge.FromNodeOutputIndex = 0;
|
||||
matMulIntToFloatToQueryKeyValueTransposeEdge.ToNodeIndex = queryKeyValueTransposedNodeIndex;
|
||||
matMulIntToFloatToQueryKeyValueTransposeEdge.ToNodeInputIndex = 0;
|
||||
intermediateEdges.push_back(matMulIntToFloatToQueryKeyValueTransposeEdge);
|
||||
|
||||
// All we need to do here is transpose the stacked QKV tensor into something DML supports
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC queryKeyValueTransposedToMhaEdge = {};
|
||||
queryKeyValueTransposedToMhaEdge.FromNodeIndex = queryKeyValueTransposedNodeIndex;
|
||||
queryKeyValueTransposedToMhaEdge.FromNodeOutputIndex = 0;
|
||||
queryKeyValueTransposedToMhaEdge.ToNodeIndex = mhaNodeIndex;
|
||||
queryKeyValueTransposedToMhaEdge.ToNodeInputIndex = mhaStackedQueryKeyValueIndex;
|
||||
intermediateEdges.push_back(queryKeyValueTransposedToMhaEdge);
|
||||
|
||||
DML_OUTPUT_GRAPH_EDGE_DESC mhaToOutputEdge = {};
|
||||
mhaToOutputEdge.FromNodeIndex = mhaNodeIndex;
|
||||
mhaToOutputEdge.FromNodeOutputIndex = 0;
|
||||
mhaToOutputEdge.GraphOutputIndex = 0;
|
||||
mhaToOutputEdge.FromNodeOutputIndex = mhaOutputIndex;
|
||||
mhaToOutputEdge.GraphOutputIndex = OutputIndex::outputIndex;
|
||||
outputEdges.push_back(mhaToOutputEdge);
|
||||
|
||||
if (hasPast)
|
||||
{
|
||||
DML_OUTPUT_GRAPH_EDGE_DESC concatToOutputEdge = {};
|
||||
concatToOutputEdge.FromNodeIndex = concatNodeIndex;
|
||||
concatToOutputEdge.FromNodeOutputIndex = 0;
|
||||
concatToOutputEdge.GraphOutputIndex = OutputIndex::presentIndex;
|
||||
outputEdges.push_back(concatToOutputEdge);
|
||||
}
|
||||
|
||||
MLOperatorGraphDesc operatorGraphDesc = {};
|
||||
operatorGraphDesc.inputEdgeCount = gsl::narrow_cast<uint32_t>(inputEdges.size());
|
||||
operatorGraphDesc.inputEdges = inputEdges.data();
|
||||
|
|
@ -542,21 +664,10 @@ public:
|
|||
void CALLBACK QueryQAttention(IMLOperatorSupportQueryContextPrivate* context, /*out*/ bool* isSupported)
|
||||
{
|
||||
*isSupported = false;
|
||||
// `past` input tensor is not supported yet
|
||||
if (context->IsInputValid(8))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// `present` output tensor is not supported yet
|
||||
if (context->IsOutputValid(1))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// `unidirectional == 1` is not supported yet
|
||||
// `unidirectional == 1` with Mask Tensor is not supported yet
|
||||
MLOperatorAttributes attributes(context);
|
||||
if (attributes.GetOptionalAttribute<int32_t>(AttrName::Unidirectional, 0) != 0)
|
||||
if (attributes.GetOptionalAttribute<int32_t>(AttrName::Unidirectional, 0) != 0 && context->IsInputValid(5))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
|
@ -567,6 +678,12 @@ void CALLBACK QueryQAttention(IMLOperatorSupportQueryContextPrivate* context, /*
|
|||
return;
|
||||
}
|
||||
|
||||
// `past_present_share_buffer == 1` is not supported yet
|
||||
if (attributes.GetOptionalAttribute<int32_t>(AttrName::PastPresentShareBuffer, 0) != 0)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
*isSupported = true;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -107,6 +107,7 @@ namespace AttrName
|
|||
static constexpr const char* QkvHiddenSizes = "qkv_hidden_sizes";
|
||||
static constexpr const char* Unidirectional = "unidirectional";
|
||||
static constexpr const char* NumHeads = "num_heads";
|
||||
static constexpr const char* PastPresentShareBuffer = "past_present_share_buffer";
|
||||
|
||||
static constexpr const char* FusedActivation = "fused_activation";
|
||||
static constexpr const char* FusedActivationDomain = "fused_activation_domain";
|
||||
|
|
|
|||
|
|
@ -2744,6 +2744,48 @@ namespace OperatorHelper
|
|||
m_qkvHiddenSizes = kernelInformation.GetAttributes().GetOptionalAttributeVectorInt32(AttrName::QkvHiddenSizes);
|
||||
}
|
||||
|
||||
std::vector<EdgeShapes> QAttentionHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
|
||||
{
|
||||
ML_CHECK_VALID_ARGUMENT(shapeInfo.GetInputCount() >= 5);
|
||||
|
||||
auto queryShape = shapeInfo.GetInputTensorShape(0);
|
||||
ML_CHECK_VALID_ARGUMENT(queryShape.size() == 3);
|
||||
|
||||
auto weightShape = shapeInfo.GetInputTensorShape(1);
|
||||
ML_CHECK_VALID_ARGUMENT(weightShape.size() == 2);
|
||||
ML_CHECK_VALID_ARGUMENT(weightShape[1] % 3 == 0);
|
||||
|
||||
const uint32_t batchSize = queryShape[0];
|
||||
const uint32_t sequenceLength = queryShape[1];
|
||||
const uint32_t hiddenSize = weightShape[1] / 3;
|
||||
const uint32_t headSize = hiddenSize / m_numHeads;
|
||||
|
||||
std::vector<EdgeShapes> outputShapes(2);
|
||||
|
||||
outputShapes[0] = EdgeShapes({batchSize, sequenceLength, hiddenSize});
|
||||
|
||||
uint32_t totalSequenceLength = sequenceLength;
|
||||
if (shapeInfo.IsInputValid(8))
|
||||
{
|
||||
ML_CHECK_VALID_ARGUMENT(shapeInfo.GetInputTensorDimensionCount(8) == 5);
|
||||
const uint32_t pastSequenceLength = shapeInfo.GetInputTensorShape(8)[3];
|
||||
totalSequenceLength += pastSequenceLength;
|
||||
}
|
||||
|
||||
if (shapeInfo.IsOutputValid(1))
|
||||
{
|
||||
ML_CHECK_VALID_ARGUMENT(shapeInfo.IsInputValid(8));
|
||||
outputShapes[1] = EdgeShapes({2, batchSize, m_numHeads, totalSequenceLength, headSize});
|
||||
}
|
||||
|
||||
return outputShapes;
|
||||
}
|
||||
|
||||
void QAttentionHelper::Initialize(const IKernelInformationAdapter& kernelInformation)
|
||||
{
|
||||
m_numHeads = gsl::narrow_cast<uint32_t>(kernelInformation.GetAttributes().GetAttribute<int64_t>(AttrName::NumHeads));
|
||||
}
|
||||
|
||||
std::vector<EdgeShapes> SkipLayerNormHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
|
||||
{
|
||||
ML_CHECK_VALID_ARGUMENT(shapeInfo.GetInputCount() >= 3);
|
||||
|
|
|
|||
|
|
@ -1490,6 +1490,22 @@ private:
|
|||
std::vector<int32_t> m_qkvHiddenSizes;
|
||||
};
|
||||
|
||||
class QAttentionHelper
|
||||
{
|
||||
public:
|
||||
template <typename Info_t, typename Shape_t>
|
||||
QAttentionHelper(const Info_t& info, const Shape_t& shapeInfo)
|
||||
{
|
||||
Initialize(KernelInformationAdapter(info));
|
||||
}
|
||||
|
||||
std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
|
||||
|
||||
private:
|
||||
void Initialize(const IKernelInformationAdapter& kernelInformation);
|
||||
uint32_t m_numHeads;
|
||||
};
|
||||
|
||||
class SkipLayerNormHelper
|
||||
{
|
||||
public:
|
||||
|
|
@ -1630,7 +1646,7 @@ using ShapeInferenceHelper_Affine = GetOutputShapeAsInputShapeHelper;
|
|||
using ShapeInferenceHelper_QuantizeLinear = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_DequantizeLinear = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_QLinearSigmoid = GetOutputShapeAsInputShapeHelper;
|
||||
using ShapeInferenceHelper_QAttention = AttentionHelper;
|
||||
using ShapeInferenceHelper_QAttention = QAttentionHelper;
|
||||
using ShapeInferenceHelper_Attention = AttentionHelper;
|
||||
using ShapeInferenceHelper_MultiHeadAttention = MultiHeadAttentionHelper;
|
||||
using ShapeInferenceHelper_Sign = GetBroadcastedOutputShapeHelper;
|
||||
|
|
|
|||
Loading…
Reference in a new issue