mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
[DML EP] Add Outer Product Einsum (#15850)
This Einsum form (the outer product, `i,j->ij`) is used in some LLMs to compute rotary embeddings.
This commit is contained in:
parent
e0c1fa35a8
commit
c7b27f4486
3 changed files with 30 additions and 4 deletions
|
|
@ -26,12 +26,13 @@ public:
|
|||
}
|
||||
inputIndices.resize(bindableInputCount);
|
||||
|
||||
DmlOperator::Initialize(kernelCreationContext, inputIndices, outputIndices);
|
||||
constexpr uint32_t dimCount = 2;
|
||||
DmlOperator::Initialize(kernelCreationContext, inputIndices, outputIndices, std::nullopt, std::nullopt, dimCount);
|
||||
|
||||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
|
||||
static_assert(RecognizedOperatorType::Total == static_cast<RecognizedOperatorType>(11), "Update this switch.");
|
||||
static_assert(RecognizedOperatorType::Total == static_cast<RecognizedOperatorType>(12), "Update this switch.");
|
||||
switch (m_recognizedOperatorType)
|
||||
{
|
||||
case RecognizedOperatorType::Multiply:
|
||||
|
|
@ -45,6 +46,28 @@ public:
|
|||
}
|
||||
break;
|
||||
|
||||
case RecognizedOperatorType::OuterProduct:
|
||||
{
|
||||
std::array<uint32_t, 2> aSizes = {m_inputTensorDescs[0].GetSizes().back(), 1};
|
||||
TensorDesc aTensorDesc = TensorDesc(m_inputTensorDescs[0].GetDmlDataType(), aSizes);
|
||||
auto aDmlTensorDesc = aTensorDesc.GetDmlDesc();
|
||||
|
||||
std::array<uint32_t, 2> bSizes = {1, m_inputTensorDescs[1].GetSizes().back()};
|
||||
TensorDesc bTensorDesc = TensorDesc(m_inputTensorDescs[1].GetDmlDataType(), bSizes);
|
||||
auto bDmlTensorDesc = bTensorDesc.GetDmlDesc();
|
||||
|
||||
DML_GEMM_OPERATOR_DESC operatorDesc = {};
|
||||
operatorDesc.ATensor = &aDmlTensorDesc;
|
||||
operatorDesc.BTensor = &bDmlTensorDesc;
|
||||
operatorDesc.OutputTensor = &outputDescs[0];
|
||||
operatorDesc.Alpha = 1.0;
|
||||
operatorDesc.Beta = 0.0;
|
||||
operatorDesc.FusedActivation = nullptr;
|
||||
|
||||
SetDmlOperatorDesc({ DML_OPERATOR_GEMM, &operatorDesc }, kernelCreationContext);
|
||||
}
|
||||
break;
|
||||
|
||||
case RecognizedOperatorType::MatMul:
|
||||
case RecognizedOperatorType::MatMulTransposeA:
|
||||
case RecognizedOperatorType::MatMulTransposeB:
|
||||
|
|
@ -253,7 +276,7 @@ void CALLBACK QueryEinSum(IMLOperatorSupportQueryContextPrivate* context, bool*
|
|||
EinSumHelper helper(attributes);
|
||||
auto recognizedOperatorType = helper.GetRecognizedOperatorType();
|
||||
|
||||
static_assert(EinSumHelper::RecognizedOperatorType::Total == static_cast<EinSumHelper::RecognizedOperatorType>(11), "Update this function.");
|
||||
static_assert(EinSumHelper::RecognizedOperatorType::Total == static_cast<EinSumHelper::RecognizedOperatorType>(12), "Update this function.");
|
||||
*isSupported = (recognizedOperatorType != EinSumHelper::RecognizedOperatorType::None);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1525,6 +1525,7 @@ namespace OperatorHelper
|
|||
{RecognizedOperatorType::MatMul, {2,2,2},{0,1, 1,2, 0,2}}, // ij,jk->ik
|
||||
{RecognizedOperatorType::MatMul, {3,3,3},{0,1,2, 0,2,3, 0,1,3}}, // bij,bjk->bik
|
||||
{RecognizedOperatorType::MatMul, {4,4,4},{0,1,2,3, 0,1,3,4, 0,1,2,4}}, // abij,abjk->abik
|
||||
{RecognizedOperatorType::OuterProduct, {1,1,2},{0, 1, 0,1}}, // i,j->ij
|
||||
{RecognizedOperatorType::MatMulTransposeA, {2,2,2},{0,1, 0,2, 1,2}}, // ji,jk->ik
|
||||
{RecognizedOperatorType::MatMulTransposeA, {3,3,3},{0,1,2, 0,1,3, 0,2,3}}, // bji,bjk->bik
|
||||
{RecognizedOperatorType::MatMulTransposeA, {4,4,4},{0,1,2,3, 0,1,2,4, 0,1,3,4}}, // abji,abjk->abik
|
||||
|
|
@ -1615,7 +1616,8 @@ namespace OperatorHelper
|
|||
|
||||
bool EinSumHelper::IsMatMulOperatorType() const noexcept
|
||||
{
|
||||
return m_recognizedOperatorType == RecognizedOperatorType::MatMul ||
|
||||
return m_recognizedOperatorType == RecognizedOperatorType::OuterProduct ||
|
||||
m_recognizedOperatorType == RecognizedOperatorType::MatMul ||
|
||||
m_recognizedOperatorType == RecognizedOperatorType::MatMulTransposeA ||
|
||||
m_recognizedOperatorType == RecognizedOperatorType::MatMulTransposeB ||
|
||||
m_recognizedOperatorType == RecognizedOperatorType::MatMulNhcw ||
|
||||
|
|
|
|||
|
|
@ -746,6 +746,7 @@ public:
|
|||
None,
|
||||
Identity,
|
||||
Multiply,
|
||||
OuterProduct,
|
||||
MatMul,
|
||||
MatMulTransposeA,
|
||||
MatMulTransposeB,
|
||||
|
|
|
|||
Loading…
Reference in a new issue