[DML EP] Enable more MHA masks (#18120)

These additional mask shapes (3D and 4D, alongside the existing 2D mask) are used for multi-head attention (MHA) in LLaMA.
This commit is contained in:
Patrice Vignola 2023-10-30 14:02:51 -07:00 committed by GitHub
parent c829550180
commit 53cb9424a4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -205,12 +205,34 @@ public:
else
{
const auto keyPaddingMaskTensorShape = m_inputTensorDescs[dmlMaskIndex].GetSizes();
ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape.size() == 2);
// Only 2D, 3D and 4D masks are handled by the branches below. The bounds must be
// combined with && (an || of these two comparisons holds for every rank), so that
// any other rank fails the argument check before the shape is indexed.
const size_t maskDimCount = keyPaddingMaskTensorShape.size();
ML_CHECK_VALID_ARGUMENT(maskDimCount >= 2 && maskDimCount <= 4);
ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[0] == batchSize);
ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[1] == kvSequenceLength);
const uint32_t actualShape[4] = {batchSize, 1, 1, kvSequenceLength};
const uint32_t desiredShape[4] = {batchSize, numHeads, sequenceLength, kvSequenceLength};
std::array<uint32_t, 4> actualShape{};
std::array<uint32_t, 4> desiredShape{};
if (maskDimCount == 2)
{
ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[1] == kvSequenceLength);
actualShape = {batchSize, 1, 1, kvSequenceLength};
desiredShape = {batchSize, numHeads, sequenceLength, kvSequenceLength};
}
else if (maskDimCount == 3)
{
ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[1] == sequenceLength);
ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[2] == totalSequenceLength);
actualShape = {batchSize, 1, sequenceLength, totalSequenceLength};
desiredShape = {batchSize, numHeads, sequenceLength, totalSequenceLength};
}
else if (maskDimCount == 4)
{
ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[1] == numHeads);
ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[2] == sequenceLength);
ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[3] == totalSequenceLength);
actualShape = {batchSize, numHeads, sequenceLength, totalSequenceLength};
desiredShape = {batchSize, numHeads, sequenceLength, totalSequenceLength};
}
m_inputTensorDescs[dmlMaskIndex] = TensorDesc::ConstructBroadcastedTensorDesc(
m_inputTensorDescs[dmlMaskIndex].GetMlOperatorDataType(),