From 53cb9424a471fdbfd0f408effdd7bd0f7f95bbb0 Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Mon, 30 Oct 2023 14:02:51 -0700 Subject: [PATCH] [DML EP] Enable more MHA masks (#18120) Those masks are used for MHA in LLaMA. --- .../DmlOperatorMultiHeadAttention.cpp | 30 ++++++++++++++++--- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp index 9c1a7baeaa..03500d0ee8 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp @@ -205,12 +205,34 @@ public: else { const auto keyPaddingMaskTensorShape = m_inputTensorDescs[dmlMaskIndex].GetSizes(); - ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape.size() == 2); + size_t maskDimCount = keyPaddingMaskTensorShape.size(); + ML_CHECK_VALID_ARGUMENT(maskDimCount >= 2 || maskDimCount <= 4); ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[0] == batchSize); - ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[1] == kvSequenceLength); - const uint32_t actualShape[4] = {batchSize, 1, 1, kvSequenceLength}; - const uint32_t desiredShape[4] = {batchSize, numHeads, sequenceLength, kvSequenceLength}; + std::array actualShape{}; + std::array desiredShape{}; + + if (maskDimCount == 2) + { + ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[1] == kvSequenceLength); + actualShape = {batchSize, 1, 1, kvSequenceLength}; + desiredShape = {batchSize, numHeads, sequenceLength, kvSequenceLength}; + } + else if (maskDimCount == 3) + { + ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[1] == sequenceLength); + ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[2] == totalSequenceLength); + actualShape = {batchSize, 1, sequenceLength, totalSequenceLength}; + desiredShape = {batchSize, numHeads, sequenceLength, totalSequenceLength}; + } + else if (maskDimCount == 4) + { + ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[1] == numHeads); + ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[2] == sequenceLength); + ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[3] == totalSequenceLength); + actualShape = {batchSize, numHeads, sequenceLength, totalSequenceLength}; + desiredShape = {batchSize, numHeads, sequenceLength, totalSequenceLength}; + } m_inputTensorDescs[dmlMaskIndex] = TensorDesc::ConstructBroadcastedTensorDesc( m_inputTensorDescs[dmlMaskIndex].GetMlOperatorDataType(),