mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
[DML EP] Enable more MHA masks (#18120)
Those masks are used for MHA in LLaMA.
This commit is contained in:
parent
c829550180
commit
53cb9424a4
1 changed file with 26 additions and 4 deletions
|
|
@@ -205,12 +205,34 @@ public:
|
|||
else
|
||||
{
|
||||
const auto keyPaddingMaskTensorShape = m_inputTensorDescs[dmlMaskIndex].GetSizes();
|
||||
ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape.size() == 2);
|
||||
size_t maskDimCount = keyPaddingMaskTensorShape.size();
|
||||
ML_CHECK_VALID_ARGUMENT(maskDimCount >= 2 || maskDimCount <= 4);
|
||||
ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[0] == batchSize);
|
||||
ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[1] == kvSequenceLength);
|
||||
|
||||
const uint32_t actualShape[4] = {batchSize, 1, 1, kvSequenceLength};
|
||||
const uint32_t desiredShape[4] = {batchSize, numHeads, sequenceLength, kvSequenceLength};
|
||||
std::array<uint32_t, 4> actualShape{};
|
||||
std::array<uint32_t, 4> desiredShape{};
|
||||
|
||||
if (maskDimCount == 2)
|
||||
{
|
||||
ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[1] == kvSequenceLength);
|
||||
actualShape = {batchSize, 1, 1, kvSequenceLength};
|
||||
desiredShape = {batchSize, numHeads, sequenceLength, kvSequenceLength};
|
||||
}
|
||||
else if (maskDimCount == 3)
|
||||
{
|
||||
ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[1] == sequenceLength);
|
||||
ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[2] == totalSequenceLength);
|
||||
actualShape = {batchSize, 1, sequenceLength, totalSequenceLength};
|
||||
desiredShape = {batchSize, numHeads, sequenceLength, totalSequenceLength};
|
||||
}
|
||||
else if (maskDimCount == 4)
|
||||
{
|
||||
ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[1] == numHeads);
|
||||
ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[2] == sequenceLength);
|
||||
ML_CHECK_VALID_ARGUMENT(keyPaddingMaskTensorShape[3] == totalSequenceLength);
|
||||
actualShape = {batchSize, numHeads, sequenceLength, totalSequenceLength};
|
||||
desiredShape = {batchSize, numHeads, sequenceLength, totalSequenceLength};
|
||||
}
|
||||
|
||||
m_inputTensorDescs[dmlMaskIndex] = TensorDesc::ConstructBroadcastedTensorDesc(
|
||||
m_inputTensorDescs[dmlMaskIndex].GetMlOperatorDataType(),
|
||||
|
|
|
|||
Loading…
Reference in a new issue