diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp
index 9c1a7baeaa..03500d0ee8 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp
@@ -205,12 +205,39 @@ public:
             else
             {
                 const auto keyPaddingMaskTensorShape = m_inputTensorDescs[dmlMaskIndex].GetSizes();
-                ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape.size() == 2);
+                // Accepted mask ranks: 2-D (batchSize, kvSequenceLength), 3-D (batchSize, sequenceLength,
+                // totalSequenceLength) and 4-D (batchSize, numHeads, sequenceLength, totalSequenceLength).
+                // Both bounds must hold (&&): any other rank would leave the shapes below zero-initialized.
+                const size_t maskDimCount = keyPaddingMaskTensorShape.size();
+                ML_CHECK_VALID_ARGUMENT(maskDimCount >= 2 && maskDimCount <= 4);
                 ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[0] == batchSize);
-                ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[1] == kvSequenceLength);
 
-                const uint32_t actualShape[4] = {batchSize, 1, 1, kvSequenceLength};
-                const uint32_t desiredShape[4] = {batchSize, numHeads, sequenceLength, kvSequenceLength};
+                // actualShape: the mask's own extents padded to 4-D with 1s; desiredShape: the 4-D target.
+                // NOTE(review): presumably both feed ConstructBroadcastedTensorDesc below - confirm.
+                std::array<uint32_t, 4> actualShape{};
+                std::array<uint32_t, 4> desiredShape{};
+
+                if (maskDimCount == 2)
+                {
+                    ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[1] == kvSequenceLength);
+                    actualShape = {batchSize, 1, 1, kvSequenceLength};
+                    desiredShape = {batchSize, numHeads, sequenceLength, kvSequenceLength};
+                }
+                else if (maskDimCount == 3)
+                {
+                    ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[1] == sequenceLength);
+                    ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[2] == totalSequenceLength);
+                    actualShape = {batchSize, 1, sequenceLength, totalSequenceLength};
+                    desiredShape = {batchSize, numHeads, sequenceLength, totalSequenceLength};
+                }
+                else if (maskDimCount == 4)
+                {
+                    ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[1] == numHeads);
+                    ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[2] == sequenceLength);
+                    ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[3] == totalSequenceLength);
+                    actualShape = {batchSize, numHeads, sequenceLength, totalSequenceLength};
+                    desiredShape = {batchSize, numHeads, sequenceLength, totalSequenceLength};
+                }
 
                 m_inputTensorDescs[dmlMaskIndex] = TensorDesc::ConstructBroadcastedTensorDesc(
                     m_inputTensorDescs[dmlMaskIndex].GetMlOperatorDataType(),