mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-23 22:13:38 +00:00
[DML EP] Fix Attention regression caused by removing transposes (#14908)
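Removing the transposes and relying on strides instead meant that the GEMM inputs were no longer in NCHW layout, so the metacommands could no longer be reached. This change reintroduces the transposes as identity operators, so the GEMMs receive NCHW-strided inputs again.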
This commit is contained in:
parent
6b604521a6
commit
65f1f840f6
1 changed files with 78 additions and 31 deletions
|
|
@ -16,6 +16,10 @@ Abbreviations: B is batch_size, S is sequence_length, W is hidden_size
|
|||
| / | \
|
||||
| Slice Slice Slice
|
||||
Identity | | |
|
||||
| | | |
|
||||
| Identity Identity Identity // The identities are used to transpose NCHW -> NHCW while
|
||||
| | | | // keeping the GEMM strides as NCHW to better target metacommands
|
||||
| | | |
|
||||
| ----- |
|
||||
----------- | |
|
||||
\ | |
|
||||
|
|
@ -172,31 +176,47 @@ public:
|
|||
uint32_t reshapedTransposedQueryTensorShape[4] = {batchSize, numHeads, sequenceLength, headSize};
|
||||
uint32_t reshapedTransposedQueryTensorStrides[4] = {sequenceLength * numHeads * headSize, headSize, numHeads * headSize, 1};
|
||||
TensorDesc reshapedTransposedQueryTensorDesc = TensorDesc(
|
||||
GetDmlDataTypeFromMlDataType(dataType),
|
||||
reshapedTransposedQueryTensorShape,
|
||||
reshapedTransposedQueryTensorStrides,
|
||||
0 // guaranteedBaseOffsetAlignment
|
||||
);
|
||||
GetDmlDataTypeFromMlDataType(dataType),
|
||||
reshapedTransposedQueryTensorShape,
|
||||
reshapedTransposedQueryTensorStrides);
|
||||
DML_TENSOR_DESC namedReshapedTransposedQueryTensorDesc = reshapedTransposedQueryTensorDesc.GetDmlDesc();
|
||||
|
||||
TensorDesc reshapedTransposedQueryOutputTensorDesc = TensorDesc(
|
||||
GetDmlDataTypeFromMlDataType(dataType),
|
||||
reshapedTransposedQueryTensorShape);
|
||||
DML_TENSOR_DESC namedReshapedTransposedQueryOutputTensorDesc = reshapedTransposedQueryOutputTensorDesc.GetDmlDesc();
|
||||
|
||||
DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC transposedQueryOperatorDesc{};
|
||||
transposedQueryOperatorDesc.InputTensor = &namedReshapedTransposedQueryTensorDesc;
|
||||
transposedQueryOperatorDesc.OutputTensor = &namedReshapedTransposedQueryOutputTensorDesc;
|
||||
const DML_OPERATOR_DESC transposedQueryDesc {DML_OPERATOR_ELEMENT_WISE_IDENTITY, &transposedQueryOperatorDesc};
|
||||
|
||||
uint32_t reshapedTransposedKeyTensorShape[4] = {batchSize, numHeads, headSize, sequenceLength};
|
||||
uint32_t reshapedTransposedKeyTensorStrides[4] = {sequenceLength * numHeads * headSize, headSize, 1, numHeads * headSize};
|
||||
TensorDesc reshapedTransposedKeyTensorDesc = TensorDesc(
|
||||
GetDmlDataTypeFromMlDataType(dataType),
|
||||
reshapedTransposedKeyTensorShape,
|
||||
reshapedTransposedKeyTensorStrides,
|
||||
0 // guaranteedBaseOffsetAlignment
|
||||
);
|
||||
GetDmlDataTypeFromMlDataType(dataType),
|
||||
reshapedTransposedKeyTensorShape,
|
||||
reshapedTransposedKeyTensorStrides);
|
||||
DML_TENSOR_DESC namedReshapedTransposedKeyTensorDesc = reshapedTransposedKeyTensorDesc.GetDmlDesc();
|
||||
|
||||
TensorDesc reshapedTransposedKeyOutputTensorDesc = TensorDesc(
|
||||
GetDmlDataTypeFromMlDataType(dataType),
|
||||
reshapedTransposedKeyTensorShape);
|
||||
DML_TENSOR_DESC namedReshapedTransposedKeyOutputTensorDesc = reshapedTransposedKeyOutputTensorDesc.GetDmlDesc();
|
||||
|
||||
DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC transposedKeyOperatorDesc{};
|
||||
transposedKeyOperatorDesc.InputTensor = &namedReshapedTransposedKeyTensorDesc;
|
||||
transposedKeyOperatorDesc.OutputTensor = &namedReshapedTransposedKeyOutputTensorDesc;
|
||||
const DML_OPERATOR_DESC transposedKeyDesc {DML_OPERATOR_ELEMENT_WISE_IDENTITY, &transposedKeyOperatorDesc};
|
||||
|
||||
uint32_t queryKeyTensorShape[4] = {batchSize, numHeads, sequenceLength, sequenceLength};
|
||||
TensorDesc queryKeyTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, queryKeyTensorShape);
|
||||
DML_TENSOR_DESC namedQueryKeyTensorDesc = queryKeyTensorDesc.GetDmlDesc();
|
||||
|
||||
float alpha = static_cast<float>(1 / sqrt(headSize));
|
||||
DML_GEMM_OPERATOR_DESC attentionScoreOperatorDesc = {};
|
||||
attentionScoreOperatorDesc.ATensor = &namedReshapedTransposedQueryTensorDesc;
|
||||
attentionScoreOperatorDesc.BTensor = &namedReshapedTransposedKeyTensorDesc;
|
||||
attentionScoreOperatorDesc.ATensor = &namedReshapedTransposedQueryOutputTensorDesc;
|
||||
attentionScoreOperatorDesc.BTensor = &namedReshapedTransposedKeyOutputTensorDesc;
|
||||
attentionScoreOperatorDesc.CTensor = &namedCastedMaskIndexTensorDesc;
|
||||
attentionScoreOperatorDesc.OutputTensor = &namedQueryKeyTensorDesc;
|
||||
attentionScoreOperatorDesc.TransA = DML_MATRIX_TRANSFORM_NONE;
|
||||
|
|
@ -226,7 +246,7 @@ public:
|
|||
|
||||
DML_GEMM_OPERATOR_DESC attentionWeightOperatorDesc = {};
|
||||
attentionWeightOperatorDesc.ATensor = &namedQueryKeyTensorDesc;
|
||||
attentionWeightOperatorDesc.BTensor = &namedReshapedTransposedQueryTensorDesc;
|
||||
attentionWeightOperatorDesc.BTensor = &namedReshapedTransposedQueryOutputTensorDesc;
|
||||
attentionWeightOperatorDesc.CTensor = nullptr;
|
||||
attentionWeightOperatorDesc.OutputTensor = &namedReshapedTransposedOutputTensorDesc;
|
||||
attentionWeightOperatorDesc.TransA = DML_MATRIX_TRANSFORM_NONE;
|
||||
|
|
@ -257,8 +277,11 @@ public:
|
|||
querySlice,
|
||||
keySlice,
|
||||
valueSlice,
|
||||
queryTranspose,
|
||||
keyTranspose,
|
||||
attentionScore,
|
||||
softmax,
|
||||
valueTranspose,
|
||||
attentionWeight,
|
||||
castMaskIndex,
|
||||
mask,
|
||||
|
|
@ -272,8 +295,11 @@ public:
|
|||
&querySlicedDesc,
|
||||
&keySlicedDesc,
|
||||
&valueSlicedDesc,
|
||||
&transposedQueryDesc,
|
||||
&transposedKeyDesc,
|
||||
&attentionScoreDesc,
|
||||
&softmaxDesc,
|
||||
&transposedQueryDesc,
|
||||
&attentionWeightDesc,
|
||||
&castMaskIndexDesc,
|
||||
&maskDesc,
|
||||
|
|
@ -325,19 +351,33 @@ public:
|
|||
gemmToValueSliceEdge.ToNodeInputIndex = 0;
|
||||
intermediateEdges.push_back(gemmToValueSliceEdge);
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC querySliceToGemm = {};
|
||||
querySliceToGemm.FromNodeIndex = NodeIndex::querySlice;
|
||||
querySliceToGemm.FromNodeOutputIndex = 0;
|
||||
querySliceToGemm.ToNodeIndex = NodeIndex::attentionScore;
|
||||
querySliceToGemm.ToNodeInputIndex = 0;
|
||||
intermediateEdges.push_back(querySliceToGemm);
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC querySliceToQueryTranspose = {};
|
||||
querySliceToQueryTranspose.FromNodeIndex = NodeIndex::querySlice;
|
||||
querySliceToQueryTranspose.FromNodeOutputIndex = 0;
|
||||
querySliceToQueryTranspose.ToNodeIndex = NodeIndex::queryTranspose;
|
||||
querySliceToQueryTranspose.ToNodeInputIndex = 0;
|
||||
intermediateEdges.push_back(querySliceToQueryTranspose);
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC keySliceToGemm = {};
|
||||
keySliceToGemm.FromNodeIndex = NodeIndex::keySlice;
|
||||
keySliceToGemm.FromNodeOutputIndex = 0;
|
||||
keySliceToGemm.ToNodeIndex = NodeIndex::attentionScore;
|
||||
keySliceToGemm.ToNodeInputIndex = 1;
|
||||
intermediateEdges.push_back(keySliceToGemm);
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC keySliceToKeyTranspose = {};
|
||||
keySliceToKeyTranspose.FromNodeIndex = NodeIndex::keySlice;
|
||||
keySliceToKeyTranspose.FromNodeOutputIndex = 0;
|
||||
keySliceToKeyTranspose.ToNodeIndex = NodeIndex::keyTranspose;
|
||||
keySliceToKeyTranspose.ToNodeInputIndex = 0;
|
||||
intermediateEdges.push_back(keySliceToKeyTranspose);
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC queryTransposeToGemm = {};
|
||||
queryTransposeToGemm.FromNodeIndex = NodeIndex::queryTranspose;
|
||||
queryTransposeToGemm.FromNodeOutputIndex = 0;
|
||||
queryTransposeToGemm.ToNodeIndex = NodeIndex::attentionScore;
|
||||
queryTransposeToGemm.ToNodeInputIndex = 0;
|
||||
intermediateEdges.push_back(queryTransposeToGemm);
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC keyTransposeToGemm = {};
|
||||
keyTransposeToGemm.FromNodeIndex = NodeIndex::keyTranspose;
|
||||
keyTransposeToGemm.FromNodeOutputIndex = 0;
|
||||
keyTransposeToGemm.ToNodeIndex = NodeIndex::attentionScore;
|
||||
keyTransposeToGemm.ToNodeInputIndex = 1;
|
||||
intermediateEdges.push_back(keyTransposeToGemm);
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC castedMaskIndexToIdentity = {};
|
||||
castedMaskIndexToIdentity.FromNodeIndex = NodeIndex::castMaskIndex;
|
||||
|
|
@ -367,12 +407,19 @@ public:
|
|||
softmaxToGemm.ToNodeInputIndex = 0;
|
||||
intermediateEdges.push_back(softmaxToGemm);
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC valueSliceToGemm = {};
|
||||
valueSliceToGemm.FromNodeIndex = NodeIndex::valueSlice;
|
||||
valueSliceToGemm.FromNodeOutputIndex = 0;
|
||||
valueSliceToGemm.ToNodeIndex = NodeIndex::attentionWeight;
|
||||
valueSliceToGemm.ToNodeInputIndex = 1;
|
||||
intermediateEdges.push_back(valueSliceToGemm);
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC valueSliceToValueTranspose = {};
|
||||
valueSliceToValueTranspose.FromNodeIndex = NodeIndex::valueSlice;
|
||||
valueSliceToValueTranspose.FromNodeOutputIndex = 0;
|
||||
valueSliceToValueTranspose.ToNodeIndex = NodeIndex::valueTranspose;
|
||||
valueSliceToValueTranspose.ToNodeInputIndex = 0;
|
||||
intermediateEdges.push_back(valueSliceToValueTranspose);
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC valueTransposeToGemm = {};
|
||||
valueTransposeToGemm.FromNodeIndex = NodeIndex::valueTranspose;
|
||||
valueTransposeToGemm.FromNodeOutputIndex = 0;
|
||||
valueTransposeToGemm.ToNodeIndex = NodeIndex::attentionWeight;
|
||||
valueTransposeToGemm.ToNodeInputIndex = 1;
|
||||
intermediateEdges.push_back(valueTransposeToGemm);
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC gemmToIdentity = {};
|
||||
gemmToIdentity.FromNodeIndex = NodeIndex::attentionWeight;
|
||||
|
|
|
|||
Loading…
Reference in a new issue