[DML EP] Fix Attention regression caused by removing transposes (#14908)

Removing the transposes and using strides instead meant that the metacommands
could no longer be reached, because the GEMM inputs were no longer in NCHW layout.
This commit is contained in:
Patrice Vignola 2023-03-07 11:17:28 -08:00 committed by GitHub
parent 6b604521a6
commit 65f1f840f6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -16,6 +16,10 @@ Abbreviations: B is batch_size, S is sequence_length, W is hidden_size
| / | \
| Slice Slice Slice
Identity | | |
| | | |
| Identity Identity Identity // The identities are used to transpose NCHW -> NHCW while
| | | | // keeping the GEMM strides as NCHW to better target metacommands
| | | |
| ----- |
----------- | |
\ | |
@ -172,31 +176,47 @@ public:
uint32_t reshapedTransposedQueryTensorShape[4] = {batchSize, numHeads, sequenceLength, headSize};
uint32_t reshapedTransposedQueryTensorStrides[4] = {sequenceLength * numHeads * headSize, headSize, numHeads * headSize, 1};
TensorDesc reshapedTransposedQueryTensorDesc = TensorDesc(
GetDmlDataTypeFromMlDataType(dataType),
reshapedTransposedQueryTensorShape,
reshapedTransposedQueryTensorStrides,
0 // guaranteedBaseOffsetAlignment
);
GetDmlDataTypeFromMlDataType(dataType),
reshapedTransposedQueryTensorShape,
reshapedTransposedQueryTensorStrides);
DML_TENSOR_DESC namedReshapedTransposedQueryTensorDesc = reshapedTransposedQueryTensorDesc.GetDmlDesc();
TensorDesc reshapedTransposedQueryOutputTensorDesc = TensorDesc(
GetDmlDataTypeFromMlDataType(dataType),
reshapedTransposedQueryTensorShape);
DML_TENSOR_DESC namedReshapedTransposedQueryOutputTensorDesc = reshapedTransposedQueryOutputTensorDesc.GetDmlDesc();
DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC transposedQueryOperatorDesc{};
transposedQueryOperatorDesc.InputTensor = &namedReshapedTransposedQueryTensorDesc;
transposedQueryOperatorDesc.OutputTensor = &namedReshapedTransposedQueryOutputTensorDesc;
const DML_OPERATOR_DESC transposedQueryDesc {DML_OPERATOR_ELEMENT_WISE_IDENTITY, &transposedQueryOperatorDesc};
uint32_t reshapedTransposedKeyTensorShape[4] = {batchSize, numHeads, headSize, sequenceLength};
uint32_t reshapedTransposedKeyTensorStrides[4] = {sequenceLength * numHeads * headSize, headSize, 1, numHeads * headSize};
TensorDesc reshapedTransposedKeyTensorDesc = TensorDesc(
GetDmlDataTypeFromMlDataType(dataType),
reshapedTransposedKeyTensorShape,
reshapedTransposedKeyTensorStrides,
0 // guaranteedBaseOffsetAlignment
);
GetDmlDataTypeFromMlDataType(dataType),
reshapedTransposedKeyTensorShape,
reshapedTransposedKeyTensorStrides);
DML_TENSOR_DESC namedReshapedTransposedKeyTensorDesc = reshapedTransposedKeyTensorDesc.GetDmlDesc();
TensorDesc reshapedTransposedKeyOutputTensorDesc = TensorDesc(
GetDmlDataTypeFromMlDataType(dataType),
reshapedTransposedKeyTensorShape);
DML_TENSOR_DESC namedReshapedTransposedKeyOutputTensorDesc = reshapedTransposedKeyOutputTensorDesc.GetDmlDesc();
DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC transposedKeyOperatorDesc{};
transposedKeyOperatorDesc.InputTensor = &namedReshapedTransposedKeyTensorDesc;
transposedKeyOperatorDesc.OutputTensor = &namedReshapedTransposedKeyOutputTensorDesc;
const DML_OPERATOR_DESC transposedKeyDesc {DML_OPERATOR_ELEMENT_WISE_IDENTITY, &transposedKeyOperatorDesc};
uint32_t queryKeyTensorShape[4] = {batchSize, numHeads, sequenceLength, sequenceLength};
TensorDesc queryKeyTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, queryKeyTensorShape);
DML_TENSOR_DESC namedQueryKeyTensorDesc = queryKeyTensorDesc.GetDmlDesc();
float alpha = static_cast<float>(1 / sqrt(headSize));
DML_GEMM_OPERATOR_DESC attentionScoreOperatorDesc = {};
attentionScoreOperatorDesc.ATensor = &namedReshapedTransposedQueryTensorDesc;
attentionScoreOperatorDesc.BTensor = &namedReshapedTransposedKeyTensorDesc;
attentionScoreOperatorDesc.ATensor = &namedReshapedTransposedQueryOutputTensorDesc;
attentionScoreOperatorDesc.BTensor = &namedReshapedTransposedKeyOutputTensorDesc;
attentionScoreOperatorDesc.CTensor = &namedCastedMaskIndexTensorDesc;
attentionScoreOperatorDesc.OutputTensor = &namedQueryKeyTensorDesc;
attentionScoreOperatorDesc.TransA = DML_MATRIX_TRANSFORM_NONE;
@ -226,7 +246,7 @@ public:
DML_GEMM_OPERATOR_DESC attentionWeightOperatorDesc = {};
attentionWeightOperatorDesc.ATensor = &namedQueryKeyTensorDesc;
attentionWeightOperatorDesc.BTensor = &namedReshapedTransposedQueryTensorDesc;
attentionWeightOperatorDesc.BTensor = &namedReshapedTransposedQueryOutputTensorDesc;
attentionWeightOperatorDesc.CTensor = nullptr;
attentionWeightOperatorDesc.OutputTensor = &namedReshapedTransposedOutputTensorDesc;
attentionWeightOperatorDesc.TransA = DML_MATRIX_TRANSFORM_NONE;
@ -257,8 +277,11 @@ public:
querySlice,
keySlice,
valueSlice,
queryTranspose,
keyTranspose,
attentionScore,
softmax,
valueTranspose,
attentionWeight,
castMaskIndex,
mask,
@ -272,8 +295,11 @@ public:
&querySlicedDesc,
&keySlicedDesc,
&valueSlicedDesc,
&transposedQueryDesc,
&transposedKeyDesc,
&attentionScoreDesc,
&softmaxDesc,
&transposedQueryDesc,
&attentionWeightDesc,
&castMaskIndexDesc,
&maskDesc,
@ -325,19 +351,33 @@ public:
gemmToValueSliceEdge.ToNodeInputIndex = 0;
intermediateEdges.push_back(gemmToValueSliceEdge);
DML_INTERMEDIATE_GRAPH_EDGE_DESC querySliceToGemm = {};
querySliceToGemm.FromNodeIndex = NodeIndex::querySlice;
querySliceToGemm.FromNodeOutputIndex = 0;
querySliceToGemm.ToNodeIndex = NodeIndex::attentionScore;
querySliceToGemm.ToNodeInputIndex = 0;
intermediateEdges.push_back(querySliceToGemm);
DML_INTERMEDIATE_GRAPH_EDGE_DESC querySliceToQueryTranspose = {};
querySliceToQueryTranspose.FromNodeIndex = NodeIndex::querySlice;
querySliceToQueryTranspose.FromNodeOutputIndex = 0;
querySliceToQueryTranspose.ToNodeIndex = NodeIndex::queryTranspose;
querySliceToQueryTranspose.ToNodeInputIndex = 0;
intermediateEdges.push_back(querySliceToQueryTranspose);
DML_INTERMEDIATE_GRAPH_EDGE_DESC keySliceToGemm = {};
keySliceToGemm.FromNodeIndex = NodeIndex::keySlice;
keySliceToGemm.FromNodeOutputIndex = 0;
keySliceToGemm.ToNodeIndex = NodeIndex::attentionScore;
keySliceToGemm.ToNodeInputIndex = 1;
intermediateEdges.push_back(keySliceToGemm);
DML_INTERMEDIATE_GRAPH_EDGE_DESC keySliceToKeyTranspose = {};
keySliceToKeyTranspose.FromNodeIndex = NodeIndex::keySlice;
keySliceToKeyTranspose.FromNodeOutputIndex = 0;
keySliceToKeyTranspose.ToNodeIndex = NodeIndex::keyTranspose;
keySliceToKeyTranspose.ToNodeInputIndex = 0;
intermediateEdges.push_back(keySliceToKeyTranspose);
DML_INTERMEDIATE_GRAPH_EDGE_DESC queryTransposeToGemm = {};
queryTransposeToGemm.FromNodeIndex = NodeIndex::queryTranspose;
queryTransposeToGemm.FromNodeOutputIndex = 0;
queryTransposeToGemm.ToNodeIndex = NodeIndex::attentionScore;
queryTransposeToGemm.ToNodeInputIndex = 0;
intermediateEdges.push_back(queryTransposeToGemm);
DML_INTERMEDIATE_GRAPH_EDGE_DESC keyTransposeToGemm = {};
keyTransposeToGemm.FromNodeIndex = NodeIndex::keyTranspose;
keyTransposeToGemm.FromNodeOutputIndex = 0;
keyTransposeToGemm.ToNodeIndex = NodeIndex::attentionScore;
keyTransposeToGemm.ToNodeInputIndex = 1;
intermediateEdges.push_back(keyTransposeToGemm);
DML_INTERMEDIATE_GRAPH_EDGE_DESC castedMaskIndexToIdentity = {};
castedMaskIndexToIdentity.FromNodeIndex = NodeIndex::castMaskIndex;
@ -367,12 +407,19 @@ public:
softmaxToGemm.ToNodeInputIndex = 0;
intermediateEdges.push_back(softmaxToGemm);
DML_INTERMEDIATE_GRAPH_EDGE_DESC valueSliceToGemm = {};
valueSliceToGemm.FromNodeIndex = NodeIndex::valueSlice;
valueSliceToGemm.FromNodeOutputIndex = 0;
valueSliceToGemm.ToNodeIndex = NodeIndex::attentionWeight;
valueSliceToGemm.ToNodeInputIndex = 1;
intermediateEdges.push_back(valueSliceToGemm);
DML_INTERMEDIATE_GRAPH_EDGE_DESC valueSliceToValueTranspose = {};
valueSliceToValueTranspose.FromNodeIndex = NodeIndex::valueSlice;
valueSliceToValueTranspose.FromNodeOutputIndex = 0;
valueSliceToValueTranspose.ToNodeIndex = NodeIndex::valueTranspose;
valueSliceToValueTranspose.ToNodeInputIndex = 0;
intermediateEdges.push_back(valueSliceToValueTranspose);
DML_INTERMEDIATE_GRAPH_EDGE_DESC valueTransposeToGemm = {};
valueTransposeToGemm.FromNodeIndex = NodeIndex::valueTranspose;
valueTransposeToGemm.FromNodeOutputIndex = 0;
valueTransposeToGemm.ToNodeIndex = NodeIndex::attentionWeight;
valueTransposeToGemm.ToNodeInputIndex = 1;
intermediateEdges.push_back(valueTransposeToGemm);
DML_INTERMEDIATE_GRAPH_EDGE_DESC gemmToIdentity = {};
gemmToIdentity.FromNodeIndex = NodeIndex::attentionWeight;