mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
DML EP add einsum MatMul NHCW ops (#13440)
### Description This adds the "NHCW" format support for einsum MatMul. The logic is basically a merge of the existing Transpose and MatMul Einsum implementations. ### Motivation and Context Some transformer models that I'm tracking use Einsum quite often during a single inference, and about half of those were "NHCW" MatMul Einsums. Supporting them will reduce the number of copies to the CPU.
This commit is contained in:
parent
d5e8d59243
commit
ac48bdec89
4 changed files with 141 additions and 29 deletions
|
|
@ -10,7 +10,7 @@ class DmlOperatorEinSum : public DmlOperator, public EinSumHelper
|
|||
{
|
||||
public:
|
||||
DmlOperatorEinSum(const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t opsetVersion)
|
||||
: DmlOperator(kernelCreationContext),
|
||||
: DmlOperator(kernelCreationContext),
|
||||
EinSumHelper(kernelCreationContext, kernelCreationContext.GetTensorShapeDescription(), opsetVersion)
|
||||
{
|
||||
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() + 1 == m_components.size(), "EinSum input tensor count is inconsistent with the equation component count.");
|
||||
|
|
@ -30,7 +30,7 @@ public:
|
|||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
|
||||
static_assert(RecognizedOperatorType::Total == static_cast<RecognizedOperatorType>(8), "Update this switch.");
|
||||
static_assert(RecognizedOperatorType::Total == static_cast<RecognizedOperatorType>(11), "Update this switch.");
|
||||
switch (m_recognizedOperatorType)
|
||||
{
|
||||
case RecognizedOperatorType::Multiply:
|
||||
|
|
@ -62,6 +62,82 @@ public:
|
|||
SetDmlOperatorDesc({ DML_OPERATOR_GEMM, &operatorDesc }, kernelCreationContext);
|
||||
}
|
||||
break;
|
||||
case RecognizedOperatorType::MatMulNhcw:
|
||||
case RecognizedOperatorType::MatMulNhcwTransposeA:
|
||||
case RecognizedOperatorType::MatMulNhcwTransposeB:
|
||||
{
|
||||
// Transpose via input strides. The output tensor is not strided. Support only 4D for now.
|
||||
assert(m_components.size() == 3);
|
||||
assert(m_components[0].GetDimensionCount() == m_components[2].GetDimensionCount());
|
||||
assert(m_components[1].GetDimensionCount() == m_components[2].GetDimensionCount());
|
||||
assert(m_components[2].GetDimensionCount() == 4);
|
||||
|
||||
// Remap transposed strides from NCHW to NHCW
|
||||
constexpr std::array<uint32_t, 4> labelIndices = {0, 2, 1, 3};
|
||||
|
||||
assert(m_inputTensorDescs.size() >= 2);
|
||||
for (uint32_t i = 0; i < 2; ++i)
|
||||
{
|
||||
TensorDesc& tensorDesc = m_inputTensorDescs[i];
|
||||
auto originalStrides = tensorDesc.GetStrides();
|
||||
std::vector<uint32_t> inputSizes = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(i);
|
||||
std::vector<uint32_t> inputStrides(inputSizes.size());
|
||||
|
||||
// If there were no strides, compute them based in descending packed order
|
||||
// based on the input sizes.
|
||||
if (originalStrides.empty())
|
||||
{
|
||||
Dml::GetDescendingPackedStrides(inputSizes, /*out*/ inputStrides);
|
||||
}
|
||||
else // Copy the original strides.
|
||||
{
|
||||
assert(originalStrides.size() >= inputStrides.size());
|
||||
size_t offset = originalStrides.size() - inputStrides.size();
|
||||
inputStrides.assign(originalStrides.begin() + offset, originalStrides.end());
|
||||
}
|
||||
|
||||
std::vector<uint32_t> newStrides(inputStrides.size());
|
||||
std::vector<uint32_t> newSizes(inputStrides.size());
|
||||
for (size_t i = 0, dimensionCount = inputStrides.size(); i < dimensionCount; ++i)
|
||||
{
|
||||
uint32_t labelIndex = labelIndices[i];
|
||||
assert(labelIndex < inputStrides.size());
|
||||
newSizes[i] = inputSizes[labelIndex];
|
||||
newStrides[i] = inputStrides[labelIndex];
|
||||
}
|
||||
|
||||
// Override the initial input tensor with the new strides.
|
||||
tensorDesc = TensorDesc(tensorDesc.GetDmlDataType(), newSizes, newStrides, 0);
|
||||
tensorDesc.GetDmlDesc(); // Discard value, but keep side effect of refreshing the DML view.
|
||||
}
|
||||
|
||||
std::vector<uint32_t> outputSizes = kernelCreationContext.GetTensorShapeDescription().GetOutputTensorShape(0);
|
||||
std::vector<uint32_t> newOutputSizes(outputSizes.size());
|
||||
assert(outputSizes.size() == labelIndices.size());
|
||||
|
||||
for (size_t i = 0; i < outputSizes.size(); ++i)
|
||||
{
|
||||
uint32_t labelIndex = labelIndices[i];
|
||||
newOutputSizes[i] = outputSizes[labelIndex];
|
||||
}
|
||||
|
||||
m_outputTensorDescs.front() = TensorDesc(m_outputTensorDescs.front().GetDmlDataType(), newOutputSizes, std::nullopt, 0);
|
||||
m_outputTensorDescs.front().GetDmlDesc(); // Discard value, but keep side effect of refreshing the DML view.
|
||||
|
||||
DML_GEMM_OPERATOR_DESC operatorDesc = {};
|
||||
operatorDesc.ATensor = &inputDescs[0];
|
||||
operatorDesc.BTensor = &inputDescs[1];
|
||||
// No operatorDesc.CTensor
|
||||
operatorDesc.OutputTensor = &outputDescs[0];
|
||||
operatorDesc.TransA = (m_recognizedOperatorType == RecognizedOperatorType::MatMulNhcwTransposeA) ? DML_MATRIX_TRANSFORM_TRANSPOSE : DML_MATRIX_TRANSFORM_NONE;
|
||||
operatorDesc.TransB = (m_recognizedOperatorType == RecognizedOperatorType::MatMulNhcwTransposeB) ? DML_MATRIX_TRANSFORM_TRANSPOSE : DML_MATRIX_TRANSFORM_NONE;
|
||||
operatorDesc.Alpha = 1.0;
|
||||
operatorDesc.Beta = 0.0;
|
||||
operatorDesc.FusedActivation = nullptr;
|
||||
|
||||
SetDmlOperatorDesc({ DML_OPERATOR_GEMM, &operatorDesc }, kernelCreationContext);
|
||||
}
|
||||
break;
|
||||
|
||||
case RecognizedOperatorType::ReduceSum:
|
||||
{
|
||||
|
|
@ -176,7 +252,7 @@ void CALLBACK QueryEinSum(IMLOperatorSupportQueryContextPrivate* context, bool*
|
|||
EinSumHelper helper(attributes);
|
||||
auto recognizedOperatorType = helper.GetRecognizedOperatorType();
|
||||
|
||||
static_assert(EinSumHelper::RecognizedOperatorType::Total == static_cast<EinSumHelper::RecognizedOperatorType>(8), "Verify this test still matches the switch above.");
|
||||
static_assert(EinSumHelper::RecognizedOperatorType::Total == static_cast<EinSumHelper::RecognizedOperatorType>(11), "Update this function.");
|
||||
*isSupported = (recognizedOperatorType != EinSumHelper::RecognizedOperatorType::None);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -609,7 +609,7 @@ namespace OperatorHelper
|
|||
// `transBatch` needs to be applied first and then `transpose`.
|
||||
if (transBatch)
|
||||
{
|
||||
ML_CHECK_VALID_ARGUMENT(dimensionCount > 2,
|
||||
ML_CHECK_VALID_ARGUMENT(dimensionCount > 2,
|
||||
"FusedMatMul operator: Tensor size should be more than 2, if attribute transBatch is true");
|
||||
|
||||
std::rotate(newSizes.begin(), newSizes.end() - 2, newSizes.end() - 1);
|
||||
|
|
@ -702,7 +702,7 @@ namespace OperatorHelper
|
|||
if (inputShape0 != inputShape1)
|
||||
{
|
||||
ML_CHECK_VALID_ARGUMENT(
|
||||
inputShape0.size() == inputShape1.size() &&
|
||||
inputShape0.size() == inputShape1.size() &&
|
||||
inputShape0.size() == inputStride0.size() &&
|
||||
inputStride0.size() == inputStride1.size(),
|
||||
"Size of inputShape0, inputStride0, inputShape1 and inputStride1 should be same while broadcasting");
|
||||
|
|
@ -715,7 +715,7 @@ namespace OperatorHelper
|
|||
|
||||
auto inStride0Iter = inputStride0.rbegin();
|
||||
auto inStride1Iter = inputStride1.rbegin();
|
||||
|
||||
|
||||
while (rank-- > 0)
|
||||
{
|
||||
DimensionType inDimension0 = *inDim0Iter;
|
||||
|
|
@ -1503,18 +1503,21 @@ namespace OperatorHelper
|
|||
};
|
||||
|
||||
const RecognizedOperatorInfo recognizedOperators[] = {
|
||||
{RecognizedOperatorType::MatMul, {2,2,2},{0,1, 1,2, 0,2}}, // ij,jk->ik
|
||||
{RecognizedOperatorType::MatMul, {3,3,3},{0,1,2, 0,2,3, 0,1,3}}, // bij,bjk->bik
|
||||
{RecognizedOperatorType::MatMul, {4,4,4},{0,1,2,3, 0,1,3,4, 0,1,2,4}}, // abij,abjk->abik
|
||||
{RecognizedOperatorType::MatMulTransposeA, {2,2,2},{0,1, 0,2, 1,2}}, // ji,jk->ik
|
||||
{RecognizedOperatorType::MatMulTransposeA, {3,3,3},{0,1,2, 0,1,3, 0,2,3}}, // bji,bjk->bik
|
||||
{RecognizedOperatorType::MatMulTransposeA, {4,4,4},{0,1,2,3, 0,1,2,4, 0,1,3,4}}, // abji,abjk->abik
|
||||
{RecognizedOperatorType::MatMulTransposeB, {2,2,2},{0,1, 2,1, 0,2}}, // ij,kj->ik
|
||||
{RecognizedOperatorType::MatMulTransposeB, {3,3,3},{0,1,2, 0,3,2, 0,1,3}}, // bij,bkj->bik
|
||||
{RecognizedOperatorType::MatMulTransposeB, {4,4,4},{0,1,2,3, 0,1,4,3, 0,1,2,4}}, // abij,abkj->abik
|
||||
{RecognizedOperatorType::MatMulTransposeB, {1,1,0},{0,0,}}, // i,i-> (1D inner_prod)
|
||||
{RecognizedOperatorType::ReduceSum, {2,1 },{0,1, 0}}, // ij->i
|
||||
{RecognizedOperatorType::ReduceSum, {2,1 },{0,1, 1}}, // ij->j
|
||||
{RecognizedOperatorType::MatMul, {2,2,2},{0,1, 1,2, 0,2}}, // ij,jk->ik
|
||||
{RecognizedOperatorType::MatMul, {3,3,3},{0,1,2, 0,2,3, 0,1,3}}, // bij,bjk->bik
|
||||
{RecognizedOperatorType::MatMul, {4,4,4},{0,1,2,3, 0,1,3,4, 0,1,2,4}}, // abij,abjk->abik
|
||||
{RecognizedOperatorType::MatMulTransposeA, {2,2,2},{0,1, 0,2, 1,2}}, // ji,jk->ik
|
||||
{RecognizedOperatorType::MatMulTransposeA, {3,3,3},{0,1,2, 0,1,3, 0,2,3}}, // bji,bjk->bik
|
||||
{RecognizedOperatorType::MatMulTransposeA, {4,4,4},{0,1,2,3, 0,1,2,4, 0,1,3,4}}, // abji,abjk->abik
|
||||
{RecognizedOperatorType::MatMulTransposeB, {2,2,2},{0,1, 2,1, 0,2}}, // ij,kj->ik
|
||||
{RecognizedOperatorType::MatMulTransposeB, {3,3,3},{0,1,2, 0,3,2, 0,1,3}}, // bij,bkj->bik
|
||||
{RecognizedOperatorType::MatMulTransposeB, {4,4,4},{0,1,2,3, 0,1,4,3, 0,1,2,4}}, // abij,abkj->abik
|
||||
{RecognizedOperatorType::MatMulTransposeB, {1,1,0},{0,0,}}, // i,i-> (1D inner_prod)
|
||||
{RecognizedOperatorType::MatMulNhcw, {4,4,4},{0,1,2,3, 0,3,2,4, 0,1,2,4}}, // aibj,ajbk->aibk
|
||||
{RecognizedOperatorType::MatMulNhcwTransposeA, {4,4,4},{0,1,2,3, 0,1,2,4, 0,3,2,4}}, // ajbi,ajbk->aibk
|
||||
{RecognizedOperatorType::MatMulNhcwTransposeB, {4,4,4},{0,1,2,3, 0,4,2,3, 0,1,2,4}}, // aibj,akbj->aibk
|
||||
{RecognizedOperatorType::ReduceSum, {2,1 },{0,1, 0}}, // ij->i
|
||||
{RecognizedOperatorType::ReduceSum, {2,1 },{0,1, 1}}, // ij->j
|
||||
};
|
||||
|
||||
// For each recognized operator, compare the labels-per-component and label indices.
|
||||
|
|
@ -1595,7 +1598,10 @@ namespace OperatorHelper
|
|||
{
|
||||
return m_recognizedOperatorType == RecognizedOperatorType::MatMul ||
|
||||
m_recognizedOperatorType == RecognizedOperatorType::MatMulTransposeA ||
|
||||
m_recognizedOperatorType == RecognizedOperatorType::MatMulTransposeB;
|
||||
m_recognizedOperatorType == RecognizedOperatorType::MatMulTransposeB ||
|
||||
m_recognizedOperatorType == RecognizedOperatorType::MatMulNhcw ||
|
||||
m_recognizedOperatorType == RecognizedOperatorType::MatMulNhcwTransposeA ||
|
||||
m_recognizedOperatorType == RecognizedOperatorType::MatMulNhcwTransposeB;
|
||||
}
|
||||
|
||||
std::vector<EdgeShapes> MatMulHelperBase::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
|
||||
|
|
|
|||
|
|
@ -234,7 +234,7 @@ void FusedMatMulShapeMapping(
|
|||
std::vector<DimensionType>& outputShape);
|
||||
|
||||
std::pair<std::vector<uint32_t>, std::vector<uint32_t>> GetFusedMatMulSizesAndStrides(
|
||||
gsl::span<const uint32_t> sizes,
|
||||
gsl::span<const uint32_t> sizes,
|
||||
int32_t transBatch = 0,
|
||||
int32_t transpose = 0);
|
||||
|
||||
|
|
@ -437,7 +437,7 @@ public:
|
|||
enum InputDims { N, C, H, W };
|
||||
|
||||
public:
|
||||
// Info_t is used to obtain attributes which will be used for calculating the output shape later.
|
||||
// Info_t is used to obtain attributes which will be used for calculating the output shape later.
|
||||
template<typename Info_t, typename Shape_t>
|
||||
ConvolutionHelperBase(const Info_t& info, const Shape_t& shape, bool transpose, bool hasDynamicPads, uint32_t inputTensorIndex, uint32_t filterTensorIndex) :
|
||||
m_inputTensorIndex(inputTensorIndex),
|
||||
|
|
@ -445,7 +445,7 @@ public:
|
|||
m_kernel(InitializeKernel(info, shape.GetInputTensorDimensionCount(inputTensorIndex), shape.GetInputTensorShape(filterTensorIndex)))
|
||||
{
|
||||
m_groupCount = info.template GetOptionalAttribute<uint32_t>(AttrName::Group, 1);
|
||||
|
||||
|
||||
if (!transpose)
|
||||
{
|
||||
InitializeKernelAndShapes(ShapeInformationAdapter(shape));
|
||||
|
|
@ -507,8 +507,8 @@ public:
|
|||
class GemmHelper
|
||||
{
|
||||
public:
|
||||
// Info_t is used to obtain attributes which will be used for calculating the output shape later.
|
||||
// Shape_t is used to obtain input shape which will be used for adjusting attribute value.
|
||||
// Info_t is used to obtain attributes which will be used for calculating the output shape later.
|
||||
// Shape_t is used to obtain input shape which will be used for adjusting attribute value.
|
||||
template<typename Info_t, typename Shape_t>
|
||||
GemmHelper(const Info_t& info, const Shape_t& shape)
|
||||
{
|
||||
|
|
@ -591,8 +591,8 @@ class SliceHelper
|
|||
);
|
||||
|
||||
public:
|
||||
// Info_t is used to obtain attributes which will be used for calculating the output shape later.
|
||||
// Shape_t is used to obtain input shape which will be used for adjusting attribute value.
|
||||
// Info_t is used to obtain attributes which will be used for calculating the output shape later.
|
||||
// Shape_t is used to obtain input shape which will be used for adjusting attribute value.
|
||||
template<typename Info_t, typename Shape_t>
|
||||
SliceHelper(const Info_t& info, const Shape_t& shape, uint32_t opsetVersion)
|
||||
{
|
||||
|
|
@ -722,6 +722,9 @@ public:
|
|||
MatMul,
|
||||
MatMulTransposeA,
|
||||
MatMulTransposeB,
|
||||
MatMulNhcw,
|
||||
MatMulNhcwTransposeA,
|
||||
MatMulNhcwTransposeB,
|
||||
ReduceSum,
|
||||
Transpose,
|
||||
Total,
|
||||
|
|
@ -740,7 +743,7 @@ protected:
|
|||
{
|
||||
uint32_t labelIndexBegin;
|
||||
uint32_t labelIndexEnd;
|
||||
|
||||
|
||||
uint32_t GetDimensionCount() const noexcept
|
||||
{
|
||||
return labelIndexEnd - labelIndexBegin;
|
||||
|
|
@ -1037,8 +1040,8 @@ protected:
|
|||
class UnpoolingHelper
|
||||
{
|
||||
public:
|
||||
// Info_t is used to obtain attributes which will be used for calculating the output shape later.
|
||||
// Shape_t is used to obtain input shape which will be used for adjusting attribute value.
|
||||
// Info_t is used to obtain attributes which will be used for calculating the output shape later.
|
||||
// Shape_t is used to obtain input shape which will be used for adjusting attribute value.
|
||||
template<typename Info_t, typename Shape_t>
|
||||
UnpoolingHelper(
|
||||
const Info_t& info,
|
||||
|
|
|
|||
|
|
@ -179,6 +179,33 @@ TEST(Einsum, ExplicitEinsumAsMatmul) {
|
|||
test.Run();
|
||||
}
|
||||
|
||||
TEST(Einsum, ExplicitEinsumAsMatmulNhcw) {
|
||||
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
|
||||
test.AddAttribute<std::string>("equation", "aibj,ajbk->aibk");
|
||||
test.AddInput<float>("x", {1, 3, 1, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
|
||||
test.AddInput<float>("y", {1, 2, 1, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
|
||||
test.AddOutput<float>("o", {1, 3, 1, 3}, {9.f, 12.f, 15.f, 19.f, 26.f, 33.f, 29.f, 40.f, 51.f});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(Einsum, ExplicitEinsumAsMatmulNhcwTransposeA) {
|
||||
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
|
||||
test.AddAttribute<std::string>("equation", "ajbi,ajbk->aibk");
|
||||
test.AddInput<float>("x", {1, 2, 1, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
|
||||
test.AddInput<float>("y", {1, 2, 1, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
|
||||
test.AddOutput<float>("o", {1, 3, 1, 3}, {17.f, 22.f, 27.f, 22.f, 29.f, 36.f, 27.f, 36.f, 45.f});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(Einsum, ExplicitEinsumAsMatmulNhcwTransposeB) {
|
||||
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
|
||||
test.AddAttribute<std::string>("equation", "aibj,akbj->aibk");
|
||||
test.AddInput<float>("x", {1, 3, 1, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
|
||||
test.AddInput<float>("y", {1, 3, 1, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
|
||||
test.AddOutput<float>("o", {1, 3, 1, 3}, {5.f, 11.f, 17.f, 11.f, 25.f, 39.f, 17.f, 39.f, 61.f});
|
||||
test.Run();
|
||||
}
|
||||
|
||||
TEST(Einsum, ExplicitEinsumAsMatmulWithUpperCasedLabel) {
|
||||
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
|
||||
// 'K' != 'k' (and dim values differ too) and Einsum should handle be able to handle that
|
||||
|
|
|
|||
Loading…
Reference in a new issue