From ac48bdec891942f7467fe66699e97fd3cdf6f2a4 Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Tue, 25 Oct 2022 23:09:07 -0700 Subject: [PATCH] DML EP add einsum MatMul NHCW ops (#13440) ### Description This adds the "NHCW" format support for einsum MatMul. The logic is basically a merge of the existing Transpose and MatMul Einsum implementations. ### Motivation and Context Some transformer models that I'm tracking use Einsum quite often during a single inference, and about half of those were "NHCW" MatMul Einsums. Supporting them will reduce the number of copies to the CPU. --- .../src/Operators/DmlOperatorEinSum.cpp | 82 ++++++++++++++++++- .../OperatorAuthorHelper/OperatorHelper.cpp | 38 +++++---- .../dml/OperatorAuthorHelper/OperatorHelper.h | 23 +++--- .../test/providers/cpu/math/einsum_test.cc | 27 ++++++ 4 files changed, 141 insertions(+), 29 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorEinSum.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorEinSum.cpp index fc65f07306..f50c451818 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorEinSum.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorEinSum.cpp @@ -10,7 +10,7 @@ class DmlOperatorEinSum : public DmlOperator, public EinSumHelper { public: DmlOperatorEinSum(const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t opsetVersion) - : DmlOperator(kernelCreationContext), + : DmlOperator(kernelCreationContext), EinSumHelper(kernelCreationContext, kernelCreationContext.GetTensorShapeDescription(), opsetVersion) { ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() + 1 == m_components.size(), "EinSum input tensor count is inconsistent with the equation component count."); @@ -30,7 +30,7 @@ public: std::vector inputDescs = GetDmlInputDescs(); std::vector outputDescs = GetDmlOutputDescs(); - static_assert(RecognizedOperatorType::Total == static_cast(8), "Update this switch."); + static_assert(RecognizedOperatorType::Total == static_cast(11), "Update this switch."); switch (m_recognizedOperatorType) { case RecognizedOperatorType::Multiply: @@ -62,6 +62,82 @@ public: SetDmlOperatorDesc({ DML_OPERATOR_GEMM, &operatorDesc }, kernelCreationContext); } break; + case RecognizedOperatorType::MatMulNhcw: + case RecognizedOperatorType::MatMulNhcwTransposeA: + case RecognizedOperatorType::MatMulNhcwTransposeB: + { + // Transpose via input strides. The output tensor is not strided. Support only 4D for now. + assert(m_components.size() == 3); + assert(m_components[0].GetDimensionCount() == m_components[2].GetDimensionCount()); + assert(m_components[1].GetDimensionCount() == m_components[2].GetDimensionCount()); + assert(m_components[2].GetDimensionCount() == 4); + + // Remap transposed strides from NCHW to NHCW + constexpr std::array labelIndices = {0, 2, 1, 3}; + + assert(m_inputTensorDescs.size() >= 2); + for (uint32_t i = 0; i < 2; ++i) + { + TensorDesc& tensorDesc = m_inputTensorDescs[i]; + auto originalStrides = tensorDesc.GetStrides(); + std::vector inputSizes = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(i); + std::vector inputStrides(inputSizes.size()); + + // If there were no strides, compute them based in descending packed order + // based on the input sizes. + if (originalStrides.empty()) + { + Dml::GetDescendingPackedStrides(inputSizes, /*out*/ inputStrides); + } + else // Copy the original strides. + { + assert(originalStrides.size() >= inputStrides.size()); + size_t offset = originalStrides.size() - inputStrides.size(); + inputStrides.assign(originalStrides.begin() + offset, originalStrides.end()); + } + + std::vector newStrides(inputStrides.size()); + std::vector newSizes(inputStrides.size()); + for (size_t i = 0, dimensionCount = inputStrides.size(); i < dimensionCount; ++i) + { + uint32_t labelIndex = labelIndices[i]; + assert(labelIndex < inputStrides.size()); + newSizes[i] = inputSizes[labelIndex]; + newStrides[i] = inputStrides[labelIndex]; + } + + // Override the initial input tensor with the new strides. + tensorDesc = TensorDesc(tensorDesc.GetDmlDataType(), newSizes, newStrides, 0); + tensorDesc.GetDmlDesc(); // Discard value, but keep side effect of refreshing the DML view. + } + + std::vector outputSizes = kernelCreationContext.GetTensorShapeDescription().GetOutputTensorShape(0); + std::vector newOutputSizes(outputSizes.size()); + assert(outputSizes.size() == labelIndices.size()); + + for (size_t i = 0; i < outputSizes.size(); ++i) + { + uint32_t labelIndex = labelIndices[i]; + newOutputSizes[i] = outputSizes[labelIndex]; + } + + m_outputTensorDescs.front() = TensorDesc(m_outputTensorDescs.front().GetDmlDataType(), newOutputSizes, std::nullopt, 0); + m_outputTensorDescs.front().GetDmlDesc(); // Discard value, but keep side effect of refreshing the DML view. + + DML_GEMM_OPERATOR_DESC operatorDesc = {}; + operatorDesc.ATensor = &inputDescs[0]; + operatorDesc.BTensor = &inputDescs[1]; + // No operatorDesc.CTensor + operatorDesc.OutputTensor = &outputDescs[0]; + operatorDesc.TransA = (m_recognizedOperatorType == RecognizedOperatorType::MatMulNhcwTransposeA) ? DML_MATRIX_TRANSFORM_TRANSPOSE : DML_MATRIX_TRANSFORM_NONE; + operatorDesc.TransB = (m_recognizedOperatorType == RecognizedOperatorType::MatMulNhcwTransposeB) ? DML_MATRIX_TRANSFORM_TRANSPOSE : DML_MATRIX_TRANSFORM_NONE; + operatorDesc.Alpha = 1.0; + operatorDesc.Beta = 0.0; + operatorDesc.FusedActivation = nullptr; + + SetDmlOperatorDesc({ DML_OPERATOR_GEMM, &operatorDesc }, kernelCreationContext); + } + break; case RecognizedOperatorType::ReduceSum: { @@ -176,7 +252,7 @@ void CALLBACK QueryEinSum(IMLOperatorSupportQueryContextPrivate* context, bool* EinSumHelper helper(attributes); auto recognizedOperatorType = helper.GetRecognizedOperatorType(); - static_assert(EinSumHelper::RecognizedOperatorType::Total == static_cast(8), "Verify this test still matches the switch above."); + static_assert(EinSumHelper::RecognizedOperatorType::Total == static_cast(11), "Update this function."); *isSupported = (recognizedOperatorType != EinSumHelper::RecognizedOperatorType::None); } diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp index 32a3a1b053..22d39fe685 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp @@ -609,7 +609,7 @@ namespace OperatorHelper // `transBatch` needs to be applied first and then `transpose`. if (transBatch) { - ML_CHECK_VALID_ARGUMENT(dimensionCount > 2, + ML_CHECK_VALID_ARGUMENT(dimensionCount > 2, "FusedMatMul operator: Tensor size should be more than 2, if attribute transBatch is true"); std::rotate(newSizes.begin(), newSizes.end() - 2, newSizes.end() - 1); @@ -702,7 +702,7 @@ namespace OperatorHelper if (inputShape0 != inputShape1) { ML_CHECK_VALID_ARGUMENT( - inputShape0.size() == inputShape1.size() && + inputShape0.size() == inputShape1.size() && inputShape0.size() == inputStride0.size() && inputStride0.size() == inputStride1.size(), "Size of inputShape0, inputStride0, inputShape1 and inputStride1 should be same while broadcasting"); @@ -715,7 +715,7 @@ namespace OperatorHelper auto inStride0Iter = inputStride0.rbegin(); auto inStride1Iter = inputStride1.rbegin(); - + while (rank-- > 0) { DimensionType inDimension0 = *inDim0Iter; @@ -1503,18 +1503,21 @@ namespace OperatorHelper }; const RecognizedOperatorInfo recognizedOperators[] = { - {RecognizedOperatorType::MatMul, {2,2,2},{0,1, 1,2, 0,2}}, // ij,jk->ik - {RecognizedOperatorType::MatMul, {3,3,3},{0,1,2, 0,2,3, 0,1,3}}, // bij,bjk->bik - {RecognizedOperatorType::MatMul, {4,4,4},{0,1,2,3, 0,1,3,4, 0,1,2,4}}, // abij,abjk->abik - {RecognizedOperatorType::MatMulTransposeA, {2,2,2},{0,1, 0,2, 1,2}}, // ji,jk->ik - {RecognizedOperatorType::MatMulTransposeA, {3,3,3},{0,1,2, 0,1,3, 0,2,3}}, // bji,bjk->bik - {RecognizedOperatorType::MatMulTransposeA, {4,4,4},{0,1,2,3, 0,1,2,4, 0,1,3,4}}, // abji,abjk->abik - {RecognizedOperatorType::MatMulTransposeB, {2,2,2},{0,1, 2,1, 0,2}}, // ij,kj->ik - {RecognizedOperatorType::MatMulTransposeB, {3,3,3},{0,1,2, 0,3,2, 0,1,3}}, // bij,bkj->bik - {RecognizedOperatorType::MatMulTransposeB, {4,4,4},{0,1,2,3, 0,1,4,3, 0,1,2,4}}, // abij,abkj->abik - {RecognizedOperatorType::MatMulTransposeB, {1,1,0},{0,0,}}, // i,i-> (1D inner_prod) - {RecognizedOperatorType::ReduceSum, {2,1 },{0,1, 0}}, // ij->i - {RecognizedOperatorType::ReduceSum, {2,1 },{0,1, 1}}, // ij->j + {RecognizedOperatorType::MatMul, {2,2,2},{0,1, 1,2, 0,2}}, // ij,jk->ik + {RecognizedOperatorType::MatMul, {3,3,3},{0,1,2, 0,2,3, 0,1,3}}, // bij,bjk->bik + {RecognizedOperatorType::MatMul, {4,4,4},{0,1,2,3, 0,1,3,4, 0,1,2,4}}, // abij,abjk->abik + {RecognizedOperatorType::MatMulTransposeA, {2,2,2},{0,1, 0,2, 1,2}}, // ji,jk->ik + {RecognizedOperatorType::MatMulTransposeA, {3,3,3},{0,1,2, 0,1,3, 0,2,3}}, // bji,bjk->bik + {RecognizedOperatorType::MatMulTransposeA, {4,4,4},{0,1,2,3, 0,1,2,4, 0,1,3,4}}, // abji,abjk->abik + {RecognizedOperatorType::MatMulTransposeB, {2,2,2},{0,1, 2,1, 0,2}}, // ij,kj->ik + {RecognizedOperatorType::MatMulTransposeB, {3,3,3},{0,1,2, 0,3,2, 0,1,3}}, // bij,bkj->bik + {RecognizedOperatorType::MatMulTransposeB, {4,4,4},{0,1,2,3, 0,1,4,3, 0,1,2,4}}, // abij,abkj->abik + {RecognizedOperatorType::MatMulTransposeB, {1,1,0},{0,0,}}, // i,i-> (1D inner_prod) + {RecognizedOperatorType::MatMulNhcw, {4,4,4},{0,1,2,3, 0,3,2,4, 0,1,2,4}}, // aibj,ajbk->aibk + {RecognizedOperatorType::MatMulNhcwTransposeA, {4,4,4},{0,1,2,3, 0,1,2,4, 0,3,2,4}}, // ajbi,ajbk->aibk + {RecognizedOperatorType::MatMulNhcwTransposeB, {4,4,4},{0,1,2,3, 0,4,2,3, 0,1,2,4}}, // aibj,akbj->aibk + {RecognizedOperatorType::ReduceSum, {2,1 },{0,1, 0}}, // ij->i + {RecognizedOperatorType::ReduceSum, {2,1 },{0,1, 1}}, // ij->j }; // For each recognized operator, compare the labels-per-component and label indices. @@ -1595,7 +1598,10 @@ namespace OperatorHelper { return m_recognizedOperatorType == RecognizedOperatorType::MatMul || m_recognizedOperatorType == RecognizedOperatorType::MatMulTransposeA || - m_recognizedOperatorType == RecognizedOperatorType::MatMulTransposeB; + m_recognizedOperatorType == RecognizedOperatorType::MatMulTransposeB || + m_recognizedOperatorType == RecognizedOperatorType::MatMulNhcw || + m_recognizedOperatorType == RecognizedOperatorType::MatMulNhcwTransposeA || + m_recognizedOperatorType == RecognizedOperatorType::MatMulNhcwTransposeB; } std::vector MatMulHelperBase::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index 12928d5515..94689c21f2 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -234,7 +234,7 @@ void FusedMatMulShapeMapping( std::vector& outputShape); std::pair, std::vector> GetFusedMatMulSizesAndStrides( - gsl::span sizes, + gsl::span sizes, int32_t transBatch = 0, int32_t transpose = 0); @@ -437,7 +437,7 @@ public: enum InputDims { N, C, H, W }; public: - // Info_t is used to obtain attributes which will be used for calculating the output shape later. + // Info_t is used to obtain attributes which will be used for calculating the output shape later. template ConvolutionHelperBase(const Info_t& info, const Shape_t& shape, bool transpose, bool hasDynamicPads, uint32_t inputTensorIndex, uint32_t filterTensorIndex) : m_inputTensorIndex(inputTensorIndex), @@ -445,7 +445,7 @@ public: m_kernel(InitializeKernel(info, shape.GetInputTensorDimensionCount(inputTensorIndex), shape.GetInputTensorShape(filterTensorIndex))) { m_groupCount = info.template GetOptionalAttribute(AttrName::Group, 1); - + if (!transpose) { InitializeKernelAndShapes(ShapeInformationAdapter(shape)); @@ -507,8 +507,8 @@ public: class GemmHelper { public: - // Info_t is used to obtain attributes which will be used for calculating the output shape later. - // Shape_t is used to obtain input shape which will be used for adjusting attribute value. + // Info_t is used to obtain attributes which will be used for calculating the output shape later. + // Shape_t is used to obtain input shape which will be used for adjusting attribute value. template GemmHelper(const Info_t& info, const Shape_t& shape) { @@ -591,8 +591,8 @@ class SliceHelper ); public: - // Info_t is used to obtain attributes which will be used for calculating the output shape later. - // Shape_t is used to obtain input shape which will be used for adjusting attribute value. + // Info_t is used to obtain attributes which will be used for calculating the output shape later. + // Shape_t is used to obtain input shape which will be used for adjusting attribute value. template SliceHelper(const Info_t& info, const Shape_t& shape, uint32_t opsetVersion) { @@ -722,6 +722,9 @@ public: MatMul, MatMulTransposeA, MatMulTransposeB, + MatMulNhcw, + MatMulNhcwTransposeA, + MatMulNhcwTransposeB, ReduceSum, Transpose, Total, @@ -740,7 +743,7 @@ protected: { uint32_t labelIndexBegin; uint32_t labelIndexEnd; - + uint32_t GetDimensionCount() const noexcept { return labelIndexEnd - labelIndexBegin; @@ -1037,8 +1040,8 @@ protected: class UnpoolingHelper { public: - // Info_t is used to obtain attributes which will be used for calculating the output shape later. - // Shape_t is used to obtain input shape which will be used for adjusting attribute value. + // Info_t is used to obtain attributes which will be used for calculating the output shape later. + // Shape_t is used to obtain input shape which will be used for adjusting attribute value. template UnpoolingHelper( const Info_t& info, diff --git a/onnxruntime/test/providers/cpu/math/einsum_test.cc b/onnxruntime/test/providers/cpu/math/einsum_test.cc index ef5c9bb8a9..6c9e429fed 100644 --- a/onnxruntime/test/providers/cpu/math/einsum_test.cc +++ b/onnxruntime/test/providers/cpu/math/einsum_test.cc @@ -179,6 +179,33 @@ TEST(Einsum, ExplicitEinsumAsMatmul) { test.Run(); } +TEST(Einsum, ExplicitEinsumAsMatmulNhcw) { + OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); + test.AddAttribute("equation", "aibj,ajbk->aibk"); + test.AddInput("x", {1, 3, 1, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + test.AddInput("y", {1, 2, 1, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + test.AddOutput("o", {1, 3, 1, 3}, {9.f, 12.f, 15.f, 19.f, 26.f, 33.f, 29.f, 40.f, 51.f}); + test.Run(); +} + +TEST(Einsum, ExplicitEinsumAsMatmulNhcwTransposeA) { + OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); + test.AddAttribute("equation", "ajbi,ajbk->aibk"); + test.AddInput("x", {1, 2, 1, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + test.AddInput("y", {1, 2, 1, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + test.AddOutput("o", {1, 3, 1, 3}, {17.f, 22.f, 27.f, 22.f, 29.f, 36.f, 27.f, 36.f, 45.f}); + test.Run(); +} + +TEST(Einsum, ExplicitEinsumAsMatmulNhcwTransposeB) { + OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); + test.AddAttribute("equation", "aibj,akbj->aibk"); + test.AddInput("x", {1, 3, 1, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + test.AddInput("y", {1, 3, 1, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + test.AddOutput("o", {1, 3, 1, 3}, {5.f, 11.f, 17.f, 11.f, 25.f, 39.f, 17.f, 39.f, 61.f}); + test.Run(); +} + TEST(Einsum, ExplicitEinsumAsMatmulWithUpperCasedLabel) { OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); // 'K' != 'k' (and dim values differ too) and Einsum should handle be able to handle that