[DML EP] Add Outer Product Einsum (#15850)

This operator is used in some LLMs to compute the rotary embeddings.
This commit is contained in:
Patrice Vignola 2023-05-09 15:51:55 -07:00 committed by GitHub
parent e0c1fa35a8
commit c7b27f4486
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 30 additions and 4 deletions

View file

@ -26,12 +26,13 @@ public:
}
inputIndices.resize(bindableInputCount);
DmlOperator::Initialize(kernelCreationContext, inputIndices, outputIndices);
constexpr uint32_t dimCount = 2;
DmlOperator::Initialize(kernelCreationContext, inputIndices, outputIndices, std::nullopt, std::nullopt, dimCount);
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
static_assert(RecognizedOperatorType::Total == static_cast<RecognizedOperatorType>(11), "Update this switch.");
static_assert(RecognizedOperatorType::Total == static_cast<RecognizedOperatorType>(12), "Update this switch.");
switch (m_recognizedOperatorType)
{
case RecognizedOperatorType::Multiply:
@ -45,6 +46,28 @@ public:
}
break;
// Outer product of two inputs, expressed as a DML GEMM:
// input 0 is described as a column of shape [N, 1] and input 1 as a row of shape [1, M],
// where N and M are the last dimension of each input's existing tensor description.
// NOTE(review): assumes both inputs are effectively 1-D, so that the last dimension
// carries every element - confirm against the pattern that selects this operator type.
case RecognizedOperatorType::OuterProduct:
{
// A: reuse input 0's data type, but describe it with explicit 2-D sizes {N, 1}.
std::array<uint32_t, 2> aSizes = {m_inputTensorDescs[0].GetSizes().back(), 1};
TensorDesc aTensorDesc = TensorDesc(m_inputTensorDescs[0].GetDmlDataType(), aSizes);
auto aDmlTensorDesc = aTensorDesc.GetDmlDesc();
// B: reuse input 1's data type, but describe it with explicit 2-D sizes {1, M}.
std::array<uint32_t, 2> bSizes = {1, m_inputTensorDescs[1].GetSizes().back()};
TensorDesc bTensorDesc = TensorDesc(m_inputTensorDescs[1].GetDmlDataType(), bSizes);
auto bDmlTensorDesc = bTensorDesc.GetDmlDesc();
// Value-initialize so that every field not assigned below is zero/null.
DML_GEMM_OPERATOR_DESC operatorDesc = {};
// The descriptor stores the addresses of the local A/B descs above; they stay in scope
// until the SetDmlOperatorDesc call at the end of this block.
operatorDesc.ATensor = &aDmlTensorDesc;
operatorDesc.BTensor = &bDmlTensorDesc;
// The output keeps the description produced by GetDmlOutputDescs().
operatorDesc.OutputTensor = &outputDescs[0];
// Alpha = 1 leaves the product unscaled. Beta = 0 with no C tensor assigned here
// presumably means no additive term contributes - confirm against the DML GEMM docs.
operatorDesc.Alpha = 1.0;
operatorDesc.Beta = 0.0;
// No activation is fused into the GEMM.
operatorDesc.FusedActivation = nullptr;
SetDmlOperatorDesc({ DML_OPERATOR_GEMM, &operatorDesc }, kernelCreationContext);
}
break;
case RecognizedOperatorType::MatMul:
case RecognizedOperatorType::MatMulTransposeA:
case RecognizedOperatorType::MatMulTransposeB:
@ -253,7 +276,7 @@ void CALLBACK QueryEinSum(IMLOperatorSupportQueryContextPrivate* context, bool*
EinSumHelper helper(attributes);
auto recognizedOperatorType = helper.GetRecognizedOperatorType();
static_assert(EinSumHelper::RecognizedOperatorType::Total == static_cast<EinSumHelper::RecognizedOperatorType>(11), "Update this function.");
static_assert(EinSumHelper::RecognizedOperatorType::Total == static_cast<EinSumHelper::RecognizedOperatorType>(12), "Update this function.");
*isSupported = (recognizedOperatorType != EinSumHelper::RecognizedOperatorType::None);
}

View file

@ -1525,6 +1525,7 @@ namespace OperatorHelper
{RecognizedOperatorType::MatMul, {2,2,2},{0,1, 1,2, 0,2}}, // ij,jk->ik
{RecognizedOperatorType::MatMul, {3,3,3},{0,1,2, 0,2,3, 0,1,3}}, // bij,bjk->bik
{RecognizedOperatorType::MatMul, {4,4,4},{0,1,2,3, 0,1,3,4, 0,1,2,4}}, // abij,abjk->abik
{RecognizedOperatorType::OuterProduct, {1,1,2},{0, 1, 0,1}}, // i,j->ij
{RecognizedOperatorType::MatMulTransposeA, {2,2,2},{0,1, 0,2, 1,2}}, // ji,jk->ik
{RecognizedOperatorType::MatMulTransposeA, {3,3,3},{0,1,2, 0,1,3, 0,2,3}}, // bji,bjk->bik
{RecognizedOperatorType::MatMulTransposeA, {4,4,4},{0,1,2,3, 0,1,2,4, 0,1,3,4}}, // abji,abjk->abik
@ -1615,7 +1616,8 @@ namespace OperatorHelper
bool EinSumHelper::IsMatMulOperatorType() const noexcept
{
return m_recognizedOperatorType == RecognizedOperatorType::MatMul ||
return m_recognizedOperatorType == RecognizedOperatorType::OuterProduct ||
m_recognizedOperatorType == RecognizedOperatorType::MatMul ||
m_recognizedOperatorType == RecognizedOperatorType::MatMulTransposeA ||
m_recognizedOperatorType == RecognizedOperatorType::MatMulTransposeB ||
m_recognizedOperatorType == RecognizedOperatorType::MatMulNhcw ||

View file

@ -746,6 +746,7 @@ public:
None,
Identity,
Multiply,
OuterProduct,
MatMul,
MatMulTransposeA,
MatMulTransposeB,