mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-25 02:50:42 +00:00
QAttention calls into MatMulIntToFloat instead of Dequantize+GEMM (#16851)
### Description Update QAttention calling into MatMulIntToFloat instead of Dequantize+GEMM to enable more metacommand path.
This commit is contained in:
parent
c19e4c02e2
commit
cbdd0bb729
2 changed files with 81 additions and 145 deletions
|
|
@ -105,7 +105,7 @@ public:
|
|||
matrixMultiplyIntergerToFloatOperatorDesc.BiasTensor = hasBias? &inputDescs[OnnxInputIndex::Bias] : nullptr;
|
||||
matrixMultiplyIntergerToFloatOperatorDesc.OutputTensor = &outputDescs[0];
|
||||
|
||||
const DML_OPERATOR_DESC opDesc2{ (DML_OPERATOR_TYPE)DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT, &matrixMultiplyIntergerToFloatOperatorDesc};
|
||||
const DML_OPERATOR_DESC opDesc2{ static_cast<DML_OPERATOR_TYPE>(DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT), &matrixMultiplyIntergerToFloatOperatorDesc};
|
||||
|
||||
MLOperatorGraphDesc operatorGraphDesc = {};
|
||||
std::vector<const DML_OPERATOR_DESC*> opDescs{&opDesc1, &opDesc2};
|
||||
|
|
|
|||
|
|
@ -10,9 +10,7 @@ Abbreviations: B is batch_size, S is sequence_length, W is hidden_size
|
|||
|
||||
M A B C // M, A, B, and C are Inputs
|
||||
| | | /
|
||||
| Dequantize /
|
||||
| \ | /
|
||||
| Gemm
|
||||
| MatMulIntToFloat
|
||||
| / | \
|
||||
| / | \
|
||||
| / | \
|
||||
|
|
@ -133,23 +131,9 @@ public:
|
|||
kernelCreationContext.GetInputEdgeDescription(weightsIndex).tensorDataType,
|
||||
desiredWeightTensorShape,
|
||||
weightTensorShape);
|
||||
m_inputTensorDescs[inputScaleIndex] = TensorDesc::ConstructBroadcastedTensorDesc(
|
||||
kernelCreationContext.GetInputEdgeDescription(inputScaleIndex).tensorDataType,
|
||||
inputTensorShape,
|
||||
m_inputTensorDescs[inputScaleIndex].GetSizes());
|
||||
m_inputTensorDescs[inputZeroPointIndex] = TensorDesc::ConstructBroadcastedTensorDesc(
|
||||
kernelCreationContext.GetInputEdgeDescription(inputZeroPointIndex).tensorDataType,
|
||||
inputTensorShape,
|
||||
m_inputTensorDescs[inputZeroPointIndex].GetSizes());
|
||||
m_inputTensorDescs[weightScaleIndex] = TensorDesc::ConstructBroadcastedTensorDesc(
|
||||
kernelCreationContext.GetInputEdgeDescription(weightScaleIndex).tensorDataType,
|
||||
desiredWeightTensorShape,
|
||||
m_inputTensorDescs[weightScaleIndex].GetSizes());
|
||||
m_inputTensorDescs[weightZeroPointIndex] = TensorDesc::ConstructBroadcastedTensorDesc(
|
||||
kernelCreationContext.GetInputEdgeDescription(weightZeroPointIndex).tensorDataType,
|
||||
desiredWeightTensorShape,
|
||||
m_inputTensorDescs[weightZeroPointIndex].GetSizes());
|
||||
|
||||
uint32_t desiredBiasTensorShape[3] = {batchSize, sequenceLength, hiddenSize + hiddenSize + vHiddenSize};
|
||||
|
||||
if (hasBias)
|
||||
{
|
||||
auto biasTensorShape = m_inputTensorDescs[biasIndex].GetSizes();
|
||||
|
|
@ -190,7 +174,7 @@ public:
|
|||
hasMaxSequenceMask = true;
|
||||
ML_CHECK_VALID_ARGUMENT(maskIndexTensorShape[2] == maskIndexTensorShape[3]);
|
||||
const uint32_t maxSequenceLength = maskIndexTensorShape[2];
|
||||
uint32_t desiredMaskIndexShape[4] {batchSize, numHeads, maxSequenceLength, maxSequenceLength};
|
||||
uint32_t desiredMaskIndexShape[4] = {batchSize, numHeads, maxSequenceLength, maxSequenceLength};
|
||||
maskTensorDataType = kernelCreationContext.GetInputEdgeDescription(maskIndex).tensorDataType;
|
||||
m_inputTensorDescs[maskIndex] = TensorDesc::ConstructBroadcastedTensorDesc(maskTensorDataType, desiredMaskIndexShape, reshapedMaskIndexTensorShape);
|
||||
}
|
||||
|
|
@ -198,65 +182,36 @@ public:
|
|||
{
|
||||
uint32_t maskIndexDimensionCount = gsl::narrow_cast<uint32_t>(maskIndexTensorShape.size());
|
||||
reshapedMaskIndexTensorShape.insert(reshapedMaskIndexTensorShape.begin() + 1, 4 - maskIndexDimensionCount, 1);
|
||||
uint32_t desiredMaskIndexShape[4] {batchSize, numHeads, sequenceLength, sequenceLength};
|
||||
uint32_t desiredMaskIndexShape[4] = {batchSize, numHeads, sequenceLength, sequenceLength};
|
||||
maskTensorDataType = kernelCreationContext.GetInputEdgeDescription(maskIndex).tensorDataType;
|
||||
m_inputTensorDescs[maskIndex] = TensorDesc::ConstructBroadcastedTensorDesc(maskTensorDataType, desiredMaskIndexShape, reshapedMaskIndexTensorShape);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TensorDesc firstGemmOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, desiredBiasTensorShape);
|
||||
DML_TENSOR_DESC namedFirstGemmOutputTensorDesc = firstGemmOutputTensorDesc.GetDmlDesc();
|
||||
TensorDesc matMulIntToFloatOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, desiredBiasTensorShape);
|
||||
DML_TENSOR_DESC namedMatMulIntToFloatOutputTensorDesc = matMulIntToFloatOutputTensorDesc.GetDmlDesc();
|
||||
|
||||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
|
||||
// output edge between Dequantize and first GEMM node
|
||||
TensorDesc intermediateInputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputTensorShape);
|
||||
DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC matMulIntToFloatOperatorDesc = {};
|
||||
matMulIntToFloatOperatorDesc.ATensor = &inputDescs[InputIndex::inputIndex];
|
||||
matMulIntToFloatOperatorDesc.AScaleTensor = &inputDescs[InputIndex::inputScaleIndex];
|
||||
matMulIntToFloatOperatorDesc.AZeroPointTensor = &inputDescs[InputIndex::inputZeroPointIndex];
|
||||
matMulIntToFloatOperatorDesc.BTensor = &inputDescs[InputIndex::weightsIndex];
|
||||
matMulIntToFloatOperatorDesc.BScaleTensor = &inputDescs[InputIndex::weightScaleIndex];
|
||||
matMulIntToFloatOperatorDesc.BZeroPointTensor = &inputDescs[InputIndex::weightZeroPointIndex];
|
||||
matMulIntToFloatOperatorDesc.BiasTensor = hasBias ? &inputDescs[InputIndex::biasIndex] : nullptr;
|
||||
matMulIntToFloatOperatorDesc.OutputTensor = &namedMatMulIntToFloatOutputTensorDesc;
|
||||
|
||||
TensorDesc intermediateWeightTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, desiredWeightTensorShape);
|
||||
const DML_OPERATOR_DESC matMulIntToFloatDesc = { static_cast<DML_OPERATOR_TYPE>(DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT), &matMulIntToFloatOperatorDesc};
|
||||
|
||||
DML_TENSOR_DESC namedIntermediateInputTensorDesc = intermediateInputTensorDesc.GetDmlDesc();
|
||||
DML_TENSOR_DESC namedIntermediateWeightTensorDesc = intermediateWeightTensorDesc.GetDmlDesc();
|
||||
|
||||
DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_DESC inputDequantizeOperatorDesc = {};
|
||||
inputDequantizeOperatorDesc.InputTensor = &inputDescs[InputIndex::inputIndex];
|
||||
inputDequantizeOperatorDesc.ScaleTensor = &inputDescs[InputIndex::inputScaleIndex];
|
||||
inputDequantizeOperatorDesc.ZeroPointTensor = &inputDescs[InputIndex::inputZeroPointIndex];
|
||||
inputDequantizeOperatorDesc.OutputTensor = &namedIntermediateInputTensorDesc;
|
||||
|
||||
const DML_OPERATOR_DESC inputDequantizeOpDesc{DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR, &inputDequantizeOperatorDesc};
|
||||
|
||||
DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_DESC weightDequantizeOperatorDesc = {};
|
||||
weightDequantizeOperatorDesc.InputTensor = &inputDescs[InputIndex::weightsIndex];
|
||||
weightDequantizeOperatorDesc.ScaleTensor = &inputDescs[InputIndex::weightScaleIndex];
|
||||
weightDequantizeOperatorDesc.ZeroPointTensor = &inputDescs[InputIndex::weightZeroPointIndex];
|
||||
weightDequantizeOperatorDesc.OutputTensor = &namedIntermediateWeightTensorDesc;
|
||||
|
||||
const DML_OPERATOR_DESC weightDequantizeOpDesc{DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR, &weightDequantizeOperatorDesc};
|
||||
|
||||
DML_GEMM_OPERATOR_DESC gemmOperatorDesc = {};
|
||||
gemmOperatorDesc.ATensor = inputDequantizeOperatorDesc.OutputTensor;
|
||||
gemmOperatorDesc.BTensor = weightDequantizeOperatorDesc.OutputTensor;
|
||||
|
||||
if (hasBias)
|
||||
{
|
||||
gemmOperatorDesc.CTensor = &inputDescs[2];
|
||||
}
|
||||
|
||||
gemmOperatorDesc.OutputTensor = &namedFirstGemmOutputTensorDesc;
|
||||
gemmOperatorDesc.TransA = DML_MATRIX_TRANSFORM_NONE;
|
||||
gemmOperatorDesc.TransB = DML_MATRIX_TRANSFORM_NONE;
|
||||
gemmOperatorDesc.Alpha = 1.0f;
|
||||
gemmOperatorDesc.Beta = 1.0f;
|
||||
gemmOperatorDesc.FusedActivation = nullptr;
|
||||
const DML_OPERATOR_DESC gemmDesc {DML_OPERATOR_GEMM, &gemmOperatorDesc};
|
||||
|
||||
std::array<uint32_t, 3> queryKeySlicedTensorShape {batchSize, sequenceLength, hiddenSize + hiddenSize};
|
||||
std::array<uint32_t, 3> queryKeySlicedTensorShape = {batchSize, sequenceLength, hiddenSize + hiddenSize};
|
||||
TensorDesc queryKeySlicedInputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, queryKeySlicedTensorShape);
|
||||
DML_TENSOR_DESC namedQueryKeySlicedInputTensorDesc = queryKeySlicedInputTensorDesc.GetDmlDesc();
|
||||
|
||||
std::array<uint32_t, 3> valueSlicedTensorShape {batchSize, sequenceLength, vHiddenSize};
|
||||
std::array<uint32_t, 3> valueSlicedTensorShape = {batchSize, sequenceLength, vHiddenSize};
|
||||
TensorDesc valueSlicedInputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, valueSlicedTensorShape);
|
||||
DML_TENSOR_DESC namedValueSlicedInputTensorDesc = valueSlicedInputTensorDesc.GetDmlDesc();
|
||||
|
||||
|
|
@ -282,7 +237,7 @@ public:
|
|||
DML_TENSOR_DESC namedQueryKeyTransposedOutputTensorDesc = queryKeyTransposedOutputTensorDesc.GetDmlDesc();
|
||||
|
||||
// Transpose QKV from [batchSize, sequenceLength, 3, numHeads, headSize] to [batchSize, sequenceLength, numHeads, 3, headSize]
|
||||
std::array<uint32_t, 5> queryKeyValueTransposedTensorShape {batchSize, sequenceLength, numHeads, 3, headSize};
|
||||
std::array<uint32_t, 5> queryKeyValueTransposedTensorShape = {batchSize, sequenceLength, numHeads, 3, headSize};
|
||||
std::array<uint32_t, 5> queryKeyValueTransposedStrides = {
|
||||
sequenceLength * numHeads * 3 * headSize,
|
||||
numHeads * 3 * headSize,
|
||||
|
|
@ -317,14 +272,14 @@ public:
|
|||
DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC transposeOperatorDesc = {};
|
||||
if (hasSlicedValue)
|
||||
{
|
||||
queryKeySlicedOperatorDesc.InputTensor = &namedFirstGemmOutputTensorDesc;
|
||||
queryKeySlicedOperatorDesc.InputTensor = &namedMatMulIntToFloatOutputTensorDesc;
|
||||
queryKeySlicedOperatorDesc.OutputTensor = &namedQueryKeySlicedInputTensorDesc;
|
||||
queryKeySlicedOperatorDesc.DimensionCount = gsl::narrow_cast<uint32_t>(queryKeySlicedTensorShape.size());
|
||||
queryKeySlicedOperatorDesc.InputWindowOffsets = queryKeySliceOffset.data();
|
||||
queryKeySlicedOperatorDesc.InputWindowSizes = queryKeySliceSize.data();
|
||||
queryKeySlicedOperatorDesc.InputWindowStrides = queryKeySliceStrides.data();
|
||||
|
||||
valueSlicedOperatorDesc.InputTensor = &namedFirstGemmOutputTensorDesc;
|
||||
valueSlicedOperatorDesc.InputTensor = &namedMatMulIntToFloatOutputTensorDesc;
|
||||
valueSlicedOperatorDesc.OutputTensor = &namedValueSlicedInputTensorDesc;
|
||||
valueSlicedOperatorDesc.DimensionCount = gsl::narrow_cast<uint32_t>(valueSlicedTensorShape.size());
|
||||
valueSlicedOperatorDesc.InputWindowOffsets = valueSliceOffset.data();
|
||||
|
|
@ -378,7 +333,6 @@ public:
|
|||
mhaOperatorDesc.MaskTensor = hasMask ? &inputDescs[maskIndex] : nullptr;
|
||||
}
|
||||
|
||||
// mhaOperatorDesc.RelativePositionBiasTensor = hasRelativePositionBias ? &inputDescs[dmlRelativePositionBiasIndex] : nullptr;
|
||||
mhaOperatorDesc.RelativePositionBiasTensor = nullptr;
|
||||
mhaOperatorDesc.OutputTensor = &outputDescs[outputIndex];
|
||||
mhaOperatorDesc.Scale = kernelCreationContext.GetOptionalAttribute<float>(AttrName::Scale, gsl::narrow_cast<float>(1.0f / std::sqrt(headSize)));
|
||||
|
|
@ -393,16 +347,12 @@ public:
|
|||
std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> outputEdges;
|
||||
|
||||
std::vector<const DML_OPERATOR_DESC*> opDescs = {
|
||||
&inputDequantizeOpDesc,
|
||||
&weightDequantizeOpDesc,
|
||||
&gemmDesc,
|
||||
&matMulIntToFloatDesc,
|
||||
&mhaDesc,
|
||||
};
|
||||
|
||||
uint32_t currentNodeIndex = 0;
|
||||
const uint32_t inputDequantizeNodeIndex = currentNodeIndex++;
|
||||
const uint32_t weightDequantizeNodeIndex = currentNodeIndex++;
|
||||
const uint32_t gemmNodeIndex = currentNodeIndex++;
|
||||
const uint32_t matMulIntToFloatNodeIndex = currentNodeIndex++;
|
||||
const uint32_t mhaNodeIndex = currentNodeIndex++;
|
||||
|
||||
uint32_t valueSliceNodeIndex = 0;
|
||||
|
|
@ -433,63 +383,49 @@ public:
|
|||
maskSliceNodeIndex = currentNodeIndex++;
|
||||
}
|
||||
|
||||
DML_INPUT_GRAPH_EDGE_DESC inputToDequantizeEdge = {};
|
||||
inputToDequantizeEdge.GraphInputIndex = InputIndex::inputIndex;
|
||||
inputToDequantizeEdge.ToNodeIndex = inputDequantizeNodeIndex;
|
||||
inputToDequantizeEdge.ToNodeInputIndex = 0;
|
||||
inputEdges.push_back(inputToDequantizeEdge);
|
||||
DML_INPUT_GRAPH_EDGE_DESC inputToMatMulIntToFloatEdge = {};
|
||||
inputToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::inputIndex;
|
||||
inputToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex;
|
||||
inputToMatMulIntToFloatEdge.ToNodeInputIndex = 0;
|
||||
inputEdges.push_back(inputToMatMulIntToFloatEdge);
|
||||
|
||||
DML_INPUT_GRAPH_EDGE_DESC inputScaleToDequantizeEdge = {};
|
||||
inputScaleToDequantizeEdge.GraphInputIndex = InputIndex::inputScaleIndex;
|
||||
inputScaleToDequantizeEdge.ToNodeIndex = inputDequantizeNodeIndex;
|
||||
inputScaleToDequantizeEdge.ToNodeInputIndex = 1;
|
||||
inputEdges.push_back(inputScaleToDequantizeEdge);
|
||||
DML_INPUT_GRAPH_EDGE_DESC inputScaleToMatMulIntToFloatEdge = {};
|
||||
inputScaleToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::inputScaleIndex;
|
||||
inputScaleToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex;
|
||||
inputScaleToMatMulIntToFloatEdge.ToNodeInputIndex = 1;
|
||||
inputEdges.push_back(inputScaleToMatMulIntToFloatEdge);
|
||||
|
||||
DML_INPUT_GRAPH_EDGE_DESC inputZeroPointToDequantizeEdge = {};
|
||||
inputZeroPointToDequantizeEdge.GraphInputIndex = InputIndex::inputZeroPointIndex;
|
||||
inputZeroPointToDequantizeEdge.ToNodeIndex = inputDequantizeNodeIndex;
|
||||
inputZeroPointToDequantizeEdge.ToNodeInputIndex = 2;
|
||||
inputEdges.push_back(inputZeroPointToDequantizeEdge);
|
||||
DML_INPUT_GRAPH_EDGE_DESC inputZeroPointToMatMulIntToFloatEdge = {};
|
||||
inputZeroPointToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::inputZeroPointIndex;
|
||||
inputZeroPointToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex;
|
||||
inputZeroPointToMatMulIntToFloatEdge.ToNodeInputIndex = 2;
|
||||
inputEdges.push_back(inputZeroPointToMatMulIntToFloatEdge);
|
||||
|
||||
DML_INPUT_GRAPH_EDGE_DESC weightToDequantizeEdge = {};
|
||||
weightToDequantizeEdge.GraphInputIndex = InputIndex::weightsIndex;
|
||||
weightToDequantizeEdge.ToNodeIndex = weightDequantizeNodeIndex;
|
||||
weightToDequantizeEdge.ToNodeInputIndex = 0;
|
||||
inputEdges.push_back(weightToDequantizeEdge);
|
||||
DML_INPUT_GRAPH_EDGE_DESC weightToMatMulIntToFloatEdge = {};
|
||||
weightToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::weightsIndex;
|
||||
weightToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex;
|
||||
weightToMatMulIntToFloatEdge.ToNodeInputIndex = 3;
|
||||
inputEdges.push_back(weightToMatMulIntToFloatEdge);
|
||||
|
||||
DML_INPUT_GRAPH_EDGE_DESC weightScaleToDequantizeEdge = {};
|
||||
weightScaleToDequantizeEdge.GraphInputIndex = InputIndex::weightScaleIndex;
|
||||
weightScaleToDequantizeEdge.ToNodeIndex = weightDequantizeNodeIndex;
|
||||
weightScaleToDequantizeEdge.ToNodeInputIndex = 1;
|
||||
inputEdges.push_back(weightScaleToDequantizeEdge);
|
||||
DML_INPUT_GRAPH_EDGE_DESC weightScaleToMatMulIntToFloatEdge = {};
|
||||
weightScaleToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::weightScaleIndex;
|
||||
weightScaleToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex;
|
||||
weightScaleToMatMulIntToFloatEdge.ToNodeInputIndex = 4;
|
||||
inputEdges.push_back(weightScaleToMatMulIntToFloatEdge);
|
||||
|
||||
DML_INPUT_GRAPH_EDGE_DESC weightZeroPointToDequantizeEdge = {};
|
||||
weightZeroPointToDequantizeEdge.GraphInputIndex = InputIndex::weightZeroPointIndex;
|
||||
weightZeroPointToDequantizeEdge.ToNodeIndex = weightDequantizeNodeIndex;
|
||||
weightZeroPointToDequantizeEdge.ToNodeInputIndex = 2;
|
||||
inputEdges.push_back(weightZeroPointToDequantizeEdge);
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC inputQuantizeToGemmEdge = {};
|
||||
inputQuantizeToGemmEdge.FromNodeIndex = inputDequantizeNodeIndex;
|
||||
inputQuantizeToGemmEdge.FromNodeOutputIndex = 0;
|
||||
inputQuantizeToGemmEdge.ToNodeIndex = gemmNodeIndex;
|
||||
inputQuantizeToGemmEdge.ToNodeInputIndex = 0;
|
||||
intermediateEdges.push_back(inputQuantizeToGemmEdge);
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC weightQuantizeToGemmEdge = {};
|
||||
weightQuantizeToGemmEdge.FromNodeIndex = weightDequantizeNodeIndex;
|
||||
weightQuantizeToGemmEdge.FromNodeOutputIndex = 0;
|
||||
weightQuantizeToGemmEdge.ToNodeIndex = gemmNodeIndex;
|
||||
weightQuantizeToGemmEdge.ToNodeInputIndex = 1;
|
||||
intermediateEdges.push_back(weightQuantizeToGemmEdge);
|
||||
DML_INPUT_GRAPH_EDGE_DESC weightZeroPointToMatMulIntToFloatEdge = {};
|
||||
weightZeroPointToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::weightZeroPointIndex;
|
||||
weightZeroPointToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex;
|
||||
weightZeroPointToMatMulIntToFloatEdge.ToNodeInputIndex = 5;
|
||||
inputEdges.push_back(weightZeroPointToMatMulIntToFloatEdge);
|
||||
|
||||
if (hasBias)
|
||||
{
|
||||
DML_INPUT_GRAPH_EDGE_DESC biasToGemmEdge = {};
|
||||
biasToGemmEdge.GraphInputIndex = biasIndex;
|
||||
biasToGemmEdge.ToNodeIndex = gemmNodeIndex;
|
||||
biasToGemmEdge.ToNodeInputIndex = 2;
|
||||
inputEdges.push_back(biasToGemmEdge);
|
||||
DML_INPUT_GRAPH_EDGE_DESC biasToMatMulIntToFloatEdge = {};
|
||||
biasToMatMulIntToFloatEdge.GraphInputIndex = InputIndex::biasIndex;
|
||||
biasToMatMulIntToFloatEdge.ToNodeIndex = matMulIntToFloatNodeIndex;
|
||||
biasToMatMulIntToFloatEdge.ToNodeInputIndex = 6;
|
||||
inputEdges.push_back(biasToMatMulIntToFloatEdge);
|
||||
}
|
||||
|
||||
if (hasMask)
|
||||
|
|
@ -497,7 +433,7 @@ public:
|
|||
if (hasUnpaddedBounds)
|
||||
{
|
||||
DML_INPUT_GRAPH_EDGE_DESC maskToMhaEdge = {};
|
||||
maskToMhaEdge.GraphInputIndex = maskIndex;
|
||||
maskToMhaEdge.GraphInputIndex = InputIndex::maskIndex;
|
||||
maskToMhaEdge.ToNodeIndex = mhaNodeIndex;
|
||||
maskToMhaEdge.ToNodeInputIndex = mhaMaskIndex;
|
||||
inputEdges.push_back(maskToMhaEdge);
|
||||
|
|
@ -505,7 +441,7 @@ public:
|
|||
else if (hasMaxSequenceMask)
|
||||
{
|
||||
DML_INPUT_GRAPH_EDGE_DESC maskToMaskSliceEdge = {};
|
||||
maskToMaskSliceEdge.GraphInputIndex = maskIndex;
|
||||
maskToMaskSliceEdge.GraphInputIndex = InputIndex::maskIndex;
|
||||
maskToMaskSliceEdge.ToNodeIndex = maskSliceNodeIndex;
|
||||
maskToMaskSliceEdge.ToNodeInputIndex = 0;
|
||||
inputEdges.push_back(maskToMaskSliceEdge);
|
||||
|
|
@ -520,7 +456,7 @@ public:
|
|||
else
|
||||
{
|
||||
DML_INPUT_GRAPH_EDGE_DESC maskToMhaEdge = {};
|
||||
maskToMhaEdge.GraphInputIndex = maskIndex;
|
||||
maskToMhaEdge.GraphInputIndex = InputIndex::maskIndex;
|
||||
maskToMhaEdge.ToNodeIndex = mhaNodeIndex;
|
||||
maskToMhaEdge.ToNodeInputIndex = mhaMaskIndex;
|
||||
inputEdges.push_back(maskToMhaEdge);
|
||||
|
|
@ -530,12 +466,12 @@ public:
|
|||
if (hasSlicedValue)
|
||||
{
|
||||
// We need to slice QK and V, and transpose QK
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC gemmToQueryKeySliceEdge = {};
|
||||
gemmToQueryKeySliceEdge.FromNodeIndex = gemmNodeIndex;
|
||||
gemmToQueryKeySliceEdge.FromNodeOutputIndex = 0;
|
||||
gemmToQueryKeySliceEdge.ToNodeIndex = queryKeySliceNodeIndex;
|
||||
gemmToQueryKeySliceEdge.ToNodeInputIndex = 0;
|
||||
intermediateEdges.push_back(gemmToQueryKeySliceEdge);
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC matMulIntToFloatToQueryKeySliceEdge = {};
|
||||
matMulIntToFloatToQueryKeySliceEdge.FromNodeIndex = matMulIntToFloatNodeIndex;
|
||||
matMulIntToFloatToQueryKeySliceEdge.FromNodeOutputIndex = 0;
|
||||
matMulIntToFloatToQueryKeySliceEdge.ToNodeIndex = queryKeySliceNodeIndex;
|
||||
matMulIntToFloatToQueryKeySliceEdge.ToNodeInputIndex = 0;
|
||||
intermediateEdges.push_back(matMulIntToFloatToQueryKeySliceEdge);
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC queryKeySliceToTransposeEdge = {};
|
||||
queryKeySliceToTransposeEdge.FromNodeIndex = queryKeySliceNodeIndex;
|
||||
|
|
@ -551,12 +487,12 @@ public:
|
|||
queryKeyTransposedToMhaEdge.ToNodeInputIndex = mhaStackedQueryKeyIndex;
|
||||
intermediateEdges.push_back(queryKeyTransposedToMhaEdge);
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC gemmToValueSliceEdge = {};
|
||||
gemmToValueSliceEdge.FromNodeIndex = gemmNodeIndex;
|
||||
gemmToValueSliceEdge.FromNodeOutputIndex = 0;
|
||||
gemmToValueSliceEdge.ToNodeIndex = valueSliceNodeIndex;
|
||||
gemmToValueSliceEdge.ToNodeInputIndex = 0;
|
||||
intermediateEdges.push_back(gemmToValueSliceEdge);
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC matMulIntToFloatToValueSliceEdge = {};
|
||||
matMulIntToFloatToValueSliceEdge.FromNodeIndex = matMulIntToFloatNodeIndex;
|
||||
matMulIntToFloatToValueSliceEdge.FromNodeOutputIndex = 0;
|
||||
matMulIntToFloatToValueSliceEdge.ToNodeIndex = valueSliceNodeIndex;
|
||||
matMulIntToFloatToValueSliceEdge.ToNodeInputIndex = 0;
|
||||
intermediateEdges.push_back(matMulIntToFloatToValueSliceEdge);
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC valueSliceToMhaEdge = {};
|
||||
valueSliceToMhaEdge.FromNodeIndex = valueSliceNodeIndex;
|
||||
|
|
@ -567,12 +503,12 @@ public:
|
|||
}
|
||||
else
|
||||
{
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC gemmToQueryKeyValueTransposeEdge = {};
|
||||
gemmToQueryKeyValueTransposeEdge.FromNodeIndex = gemmNodeIndex;
|
||||
gemmToQueryKeyValueTransposeEdge.FromNodeOutputIndex = 0;
|
||||
gemmToQueryKeyValueTransposeEdge.ToNodeIndex = queryKeyValueTransposedNodeIndex;
|
||||
gemmToQueryKeyValueTransposeEdge.ToNodeInputIndex = 0;
|
||||
intermediateEdges.push_back(gemmToQueryKeyValueTransposeEdge);
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC matMulIntToFloatToQueryKeyValueTransposeEdge = {};
|
||||
matMulIntToFloatToQueryKeyValueTransposeEdge.FromNodeIndex = matMulIntToFloatNodeIndex;
|
||||
matMulIntToFloatToQueryKeyValueTransposeEdge.FromNodeOutputIndex = 0;
|
||||
matMulIntToFloatToQueryKeyValueTransposeEdge.ToNodeIndex = queryKeyValueTransposedNodeIndex;
|
||||
matMulIntToFloatToQueryKeyValueTransposeEdge.ToNodeInputIndex = 0;
|
||||
intermediateEdges.push_back(matMulIntToFloatToQueryKeyValueTransposeEdge);
|
||||
|
||||
// All we need to do here is transpose the stacked QKV tensor into something DML supports
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC queryKeyValueTransposedToMhaEdge = {};
|
||||
|
|
|
|||
Loading…
Reference in a new issue