diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp index af93808248..132b099aab 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp @@ -16,6 +16,10 @@ Abbreviations: B is batch_size, S is sequence_length, W is hidden_size | / | \ | Slice Slice Slice Identity | | | + | | | | + | Identity Identity Identity // The identities are used to transpose NCHW -> NHCW while + | | | | // keeping the GEMM strides as NCHW to better target metacommands + | | | | | ----- | ----------- | | \ | | @@ -172,31 +176,47 @@ public: uint32_t reshapedTransposedQueryTensorShape[4] = {batchSize, numHeads, sequenceLength, headSize}; uint32_t reshapedTransposedQueryTensorStrides[4] = {sequenceLength * numHeads * headSize, headSize, numHeads * headSize, 1}; TensorDesc reshapedTransposedQueryTensorDesc = TensorDesc( - GetDmlDataTypeFromMlDataType(dataType), - reshapedTransposedQueryTensorShape, - reshapedTransposedQueryTensorStrides, - 0 // guaranteedBaseOffsetAlignment - ); + GetDmlDataTypeFromMlDataType(dataType), + reshapedTransposedQueryTensorShape, + reshapedTransposedQueryTensorStrides); DML_TENSOR_DESC namedReshapedTransposedQueryTensorDesc = reshapedTransposedQueryTensorDesc.GetDmlDesc(); + TensorDesc reshapedTransposedQueryOutputTensorDesc = TensorDesc( + GetDmlDataTypeFromMlDataType(dataType), + reshapedTransposedQueryTensorShape); + DML_TENSOR_DESC namedReshapedTransposedQueryOutputTensorDesc = reshapedTransposedQueryOutputTensorDesc.GetDmlDesc(); + + DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC transposedQueryOperatorDesc{}; + transposedQueryOperatorDesc.InputTensor = &namedReshapedTransposedQueryTensorDesc; + transposedQueryOperatorDesc.OutputTensor = &namedReshapedTransposedQueryOutputTensorDesc; + const DML_OPERATOR_DESC transposedQueryDesc {DML_OPERATOR_ELEMENT_WISE_IDENTITY, &transposedQueryOperatorDesc}; + uint32_t reshapedTransposedKeyTensorShape[4] = {batchSize, numHeads, headSize, sequenceLength}; uint32_t reshapedTransposedKeyTensorStrides[4] = {sequenceLength * numHeads * headSize, headSize, 1, numHeads * headSize}; TensorDesc reshapedTransposedKeyTensorDesc = TensorDesc( - GetDmlDataTypeFromMlDataType(dataType), - reshapedTransposedKeyTensorShape, - reshapedTransposedKeyTensorStrides, - 0 // guaranteedBaseOffsetAlignment - ); + GetDmlDataTypeFromMlDataType(dataType), + reshapedTransposedKeyTensorShape, + reshapedTransposedKeyTensorStrides); DML_TENSOR_DESC namedReshapedTransposedKeyTensorDesc = reshapedTransposedKeyTensorDesc.GetDmlDesc(); + TensorDesc reshapedTransposedKeyOutputTensorDesc = TensorDesc( + GetDmlDataTypeFromMlDataType(dataType), + reshapedTransposedKeyTensorShape); + DML_TENSOR_DESC namedReshapedTransposedKeyOutputTensorDesc = reshapedTransposedKeyOutputTensorDesc.GetDmlDesc(); + + DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC transposedKeyOperatorDesc{}; + transposedKeyOperatorDesc.InputTensor = &namedReshapedTransposedKeyTensorDesc; + transposedKeyOperatorDesc.OutputTensor = &namedReshapedTransposedKeyOutputTensorDesc; + const DML_OPERATOR_DESC transposedKeyDesc {DML_OPERATOR_ELEMENT_WISE_IDENTITY, &transposedKeyOperatorDesc}; + uint32_t queryKeyTensorShape[4] = {batchSize, numHeads, sequenceLength, sequenceLength}; TensorDesc queryKeyTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, queryKeyTensorShape); DML_TENSOR_DESC namedQueryKeyTensorDesc = queryKeyTensorDesc.GetDmlDesc(); float alpha = static_cast(1 / sqrt(headSize)); DML_GEMM_OPERATOR_DESC attentionScoreOperatorDesc = {}; - attentionScoreOperatorDesc.ATensor = &namedReshapedTransposedQueryTensorDesc; - attentionScoreOperatorDesc.BTensor = &namedReshapedTransposedKeyTensorDesc; + attentionScoreOperatorDesc.ATensor = &namedReshapedTransposedQueryOutputTensorDesc; + attentionScoreOperatorDesc.BTensor = &namedReshapedTransposedKeyOutputTensorDesc; attentionScoreOperatorDesc.CTensor = &namedCastedMaskIndexTensorDesc; attentionScoreOperatorDesc.OutputTensor = &namedQueryKeyTensorDesc; attentionScoreOperatorDesc.TransA = DML_MATRIX_TRANSFORM_NONE; @@ -226,7 +246,7 @@ public: DML_GEMM_OPERATOR_DESC attentionWeightOperatorDesc = {}; attentionWeightOperatorDesc.ATensor = &namedQueryKeyTensorDesc; - attentionWeightOperatorDesc.BTensor = &namedReshapedTransposedQueryTensorDesc; + attentionWeightOperatorDesc.BTensor = &namedReshapedTransposedQueryOutputTensorDesc; attentionWeightOperatorDesc.CTensor = nullptr; attentionWeightOperatorDesc.OutputTensor = &namedReshapedTransposedOutputTensorDesc; attentionWeightOperatorDesc.TransA = DML_MATRIX_TRANSFORM_NONE; @@ -257,8 +277,11 @@ public: querySlice, keySlice, valueSlice, + queryTranspose, + keyTranspose, attentionScore, softmax, + valueTranspose, attentionWeight, castMaskIndex, mask, @@ -272,8 +295,11 @@ public: &querySlicedDesc, &keySlicedDesc, &valueSlicedDesc, + &transposedQueryDesc, + &transposedKeyDesc, &attentionScoreDesc, &softmaxDesc, + &transposedQueryDesc, &attentionWeightDesc, &castMaskIndexDesc, &maskDesc, @@ -325,19 +351,33 @@ public: gemmToValueSliceEdge.ToNodeInputIndex = 0; intermediateEdges.push_back(gemmToValueSliceEdge); - DML_INTERMEDIATE_GRAPH_EDGE_DESC querySliceToGemm = {}; - querySliceToGemm.FromNodeIndex = NodeIndex::querySlice; - querySliceToGemm.FromNodeOutputIndex = 0; - querySliceToGemm.ToNodeIndex = NodeIndex::attentionScore; - querySliceToGemm.ToNodeInputIndex = 0; - intermediateEdges.push_back(querySliceToGemm); + DML_INTERMEDIATE_GRAPH_EDGE_DESC querySliceToQueryTranspose = {}; + querySliceToQueryTranspose.FromNodeIndex = NodeIndex::querySlice; + querySliceToQueryTranspose.FromNodeOutputIndex = 0; + querySliceToQueryTranspose.ToNodeIndex = NodeIndex::queryTranspose; + querySliceToQueryTranspose.ToNodeInputIndex = 0; + intermediateEdges.push_back(querySliceToQueryTranspose); - DML_INTERMEDIATE_GRAPH_EDGE_DESC keySliceToGemm = {}; - keySliceToGemm.FromNodeIndex = NodeIndex::keySlice; - keySliceToGemm.FromNodeOutputIndex = 0; - keySliceToGemm.ToNodeIndex = NodeIndex::attentionScore; - keySliceToGemm.ToNodeInputIndex = 1; - intermediateEdges.push_back(keySliceToGemm); + DML_INTERMEDIATE_GRAPH_EDGE_DESC keySliceToKeyTranspose = {}; + keySliceToKeyTranspose.FromNodeIndex = NodeIndex::keySlice; + keySliceToKeyTranspose.FromNodeOutputIndex = 0; + keySliceToKeyTranspose.ToNodeIndex = NodeIndex::keyTranspose; + keySliceToKeyTranspose.ToNodeInputIndex = 0; + intermediateEdges.push_back(keySliceToKeyTranspose); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC queryTransposeToGemm = {}; + queryTransposeToGemm.FromNodeIndex = NodeIndex::queryTranspose; + queryTransposeToGemm.FromNodeOutputIndex = 0; + queryTransposeToGemm.ToNodeIndex = NodeIndex::attentionScore; + queryTransposeToGemm.ToNodeInputIndex = 0; + intermediateEdges.push_back(queryTransposeToGemm); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC keyTransposeToGemm = {}; + keyTransposeToGemm.FromNodeIndex = NodeIndex::keyTranspose; + keyTransposeToGemm.FromNodeOutputIndex = 0; + keyTransposeToGemm.ToNodeIndex = NodeIndex::attentionScore; + keyTransposeToGemm.ToNodeInputIndex = 1; + intermediateEdges.push_back(keyTransposeToGemm); DML_INTERMEDIATE_GRAPH_EDGE_DESC castedMaskIndexToIdentity = {}; castedMaskIndexToIdentity.FromNodeIndex = NodeIndex::castMaskIndex; @@ -367,12 +407,19 @@ public: softmaxToGemm.ToNodeInputIndex = 0; intermediateEdges.push_back(softmaxToGemm); - DML_INTERMEDIATE_GRAPH_EDGE_DESC valueSliceToGemm = {}; - valueSliceToGemm.FromNodeIndex = NodeIndex::valueSlice; - valueSliceToGemm.FromNodeOutputIndex = 0; - valueSliceToGemm.ToNodeIndex = NodeIndex::attentionWeight; - valueSliceToGemm.ToNodeInputIndex = 1; - intermediateEdges.push_back(valueSliceToGemm); + DML_INTERMEDIATE_GRAPH_EDGE_DESC valueSliceToValueTranspose = {}; + valueSliceToValueTranspose.FromNodeIndex = NodeIndex::valueSlice; + valueSliceToValueTranspose.FromNodeOutputIndex = 0; + valueSliceToValueTranspose.ToNodeIndex = NodeIndex::valueTranspose; + valueSliceToValueTranspose.ToNodeInputIndex = 0; + intermediateEdges.push_back(valueSliceToValueTranspose); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC valueTransposeToGemm = {}; + valueTransposeToGemm.FromNodeIndex = NodeIndex::valueTranspose; + valueTransposeToGemm.FromNodeOutputIndex = 0; + valueTransposeToGemm.ToNodeIndex = NodeIndex::attentionWeight; + valueTransposeToGemm.ToNodeInputIndex = 1; + intermediateEdges.push_back(valueTransposeToGemm); DML_INTERMEDIATE_GRAPH_EDGE_DESC gemmToIdentity = {}; gemmToIdentity.FromNodeIndex = NodeIndex::attentionWeight;