diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorEinSum.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorEinSum.cpp index c8217e4343..d5bf54de53 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorEinSum.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorEinSum.cpp @@ -26,12 +26,13 @@ public: } inputIndices.resize(bindableInputCount); - DmlOperator::Initialize(kernelCreationContext, inputIndices, outputIndices); + constexpr uint32_t dimCount = 2; + DmlOperator::Initialize(kernelCreationContext, inputIndices, outputIndices, std::nullopt, std::nullopt, dimCount); std::vector inputDescs = GetDmlInputDescs(); std::vector outputDescs = GetDmlOutputDescs(); - static_assert(RecognizedOperatorType::Total == static_cast(11), "Update this switch."); + static_assert(RecognizedOperatorType::Total == static_cast(12), "Update this switch."); switch (m_recognizedOperatorType) { case RecognizedOperatorType::Multiply: @@ -45,6 +46,28 @@ public: } break; + case RecognizedOperatorType::OuterProduct: + { + std::array aSizes = {m_inputTensorDescs[0].GetSizes().back(), 1}; + TensorDesc aTensorDesc = TensorDesc(m_inputTensorDescs[0].GetDmlDataType(), aSizes); + auto aDmlTensorDesc = aTensorDesc.GetDmlDesc(); + + std::array bSizes = {1, m_inputTensorDescs[1].GetSizes().back()}; + TensorDesc bTensorDesc = TensorDesc(m_inputTensorDescs[1].GetDmlDataType(), bSizes); + auto bDmlTensorDesc = bTensorDesc.GetDmlDesc(); + + DML_GEMM_OPERATOR_DESC operatorDesc = {}; + operatorDesc.ATensor = &aDmlTensorDesc; + operatorDesc.BTensor = &bDmlTensorDesc; + operatorDesc.OutputTensor = &outputDescs[0]; + operatorDesc.Alpha = 1.0; + operatorDesc.Beta = 0.0; + operatorDesc.FusedActivation = nullptr; + + SetDmlOperatorDesc({ DML_OPERATOR_GEMM, &operatorDesc }, kernelCreationContext); + } + break; + case RecognizedOperatorType::MatMul: case RecognizedOperatorType::MatMulTransposeA: case RecognizedOperatorType::MatMulTransposeB: @@ -253,7 +276,7 @@ void CALLBACK QueryEinSum(IMLOperatorSupportQueryContextPrivate* context, bool* EinSumHelper helper(attributes); auto recognizedOperatorType = helper.GetRecognizedOperatorType(); - static_assert(EinSumHelper::RecognizedOperatorType::Total == static_cast(11), "Update this function."); + static_assert(EinSumHelper::RecognizedOperatorType::Total == static_cast(12), "Update this function."); *isSupported = (recognizedOperatorType != EinSumHelper::RecognizedOperatorType::None); } diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp index 1ea466b771..403f15ebf4 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp @@ -1525,6 +1525,7 @@ namespace OperatorHelper {RecognizedOperatorType::MatMul, {2,2,2},{0,1, 1,2, 0,2}}, // ij,jk->ik {RecognizedOperatorType::MatMul, {3,3,3},{0,1,2, 0,2,3, 0,1,3}}, // bij,bjk->bik {RecognizedOperatorType::MatMul, {4,4,4},{0,1,2,3, 0,1,3,4, 0,1,2,4}}, // abij,abjk->abik + {RecognizedOperatorType::OuterProduct, {1,1,2},{0, 1, 0,1}}, // i,j->ij {RecognizedOperatorType::MatMulTransposeA, {2,2,2},{0,1, 0,2, 1,2}}, // ji,jk->ik {RecognizedOperatorType::MatMulTransposeA, {3,3,3},{0,1,2, 0,1,3, 0,2,3}}, // bji,bjk->bik {RecognizedOperatorType::MatMulTransposeA, {4,4,4},{0,1,2,3, 0,1,2,4, 0,1,3,4}}, // abji,abjk->abik @@ -1615,7 +1616,8 @@ namespace OperatorHelper bool EinSumHelper::IsMatMulOperatorType() const noexcept { - return m_recognizedOperatorType == RecognizedOperatorType::MatMul || + return m_recognizedOperatorType == RecognizedOperatorType::OuterProduct || + m_recognizedOperatorType == RecognizedOperatorType::MatMul || m_recognizedOperatorType == RecognizedOperatorType::MatMulTransposeA || m_recognizedOperatorType == RecognizedOperatorType::MatMulTransposeB || m_recognizedOperatorType == RecognizedOperatorType::MatMulNhcw || diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index 8372113393..f2be3cf05b 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -746,6 +746,7 @@ public: None, Identity, Multiply, + OuterProduct, MatMul, MatMulTransposeA, MatMulTransposeB,