MatrixMultiplyIntegerToFloat (#16804)

### Description
Implementation for[
com.microsoft.MatMulIntegerToFloat](https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.MatMulIntegerToFloat)

```
C:\workspace\ORT\onnxruntime\build\test\RelWithDebInfo\RelWithDebInfo>.\onnxruntime_test_all.exe --gtest_filter="*MatMulIntegerToFloat*"
Note: Google Test filter = *MatMulIntegerToFloat*
[==========] Running 8 tests from 4 test suites.
[----------] Global test environment set-up.
[----------] 1 test from CPU_U8S8_Precision_Tests
[ RUN      ] CPU_U8S8_Precision_Tests.MatMulIntegerToFloat
[       OK ] CPU_U8S8_Precision_Tests.MatMulIntegerToFloat (31 ms)
[----------] 1 test from CPU_U8S8_Precision_Tests (31 ms total)

[----------] 1 test from GraphTransformationTests
[ RUN      ] GraphTransformationTests.MatMulIntegerToFloatTest
[       OK ] GraphTransformationTests.MatMulIntegerToFloatTest (0 ms)
[----------] 1 test from GraphTransformationTests (1 ms total)

[----------] 1 test from QDQTransformerTests
[ RUN      ] QDQTransformerTests.MatMulIntegerToFloat
[       OK ] QDQTransformerTests.MatMulIntegerToFloat (12 ms)
[----------] 1 test from QDQTransformerTests (12 ms total)

[----------] 5 tests from MatMulIntegerToFloat
[ RUN      ] MatMulIntegerToFloat.HasZeroPoint_NoBias_test_U8X8
[       OK ] MatMulIntegerToFloat.HasZeroPoint_NoBias_test_U8X8 (801 ms)
[ RUN      ] MatMulIntegerToFloat.NoZeroPoint_HasBias_test_U8X8
[       OK ] MatMulIntegerToFloat.NoZeroPoint_HasBias_test_U8X8 (804 ms)
[ RUN      ] MatMulIntegerToFloat.HasZeroPoint_NoBias_test_S8S8
[       OK ] MatMulIntegerToFloat.HasZeroPoint_NoBias_test_S8S8 (378 ms)
[ RUN      ] MatMulIntegerToFloat.NoZeroPoint_HasBias_test_S8S8
[       OK ] MatMulIntegerToFloat.NoZeroPoint_HasBias_test_S8S8 (382 ms)
[ RUN      ] MatMulIntegerToFloat.MatMulInteger_With_ZeroPoint
[       OK ] MatMulIntegerToFloat.MatMulInteger_With_ZeroPoint (31 ms)
[----------] 5 tests from MatMulIntegerToFloat (2398 ms total)

[----------] Global test environment tear-down
[==========] 8 tests from 4 test suites ran. (2455 ms total)
[  PASSED  ] 8 tests.
```
This commit is contained in:
raoanag 2023-07-24 16:05:23 -07:00 committed by Jeff Bloomfield
parent 4336d6422d
commit 14298e9c02
8 changed files with 186 additions and 50 deletions

View file

@ -52,7 +52,7 @@ struct EnumTraits<DML_TENSOR_TYPE>
template <>
struct EnumTraits<DML_OPERATOR_TYPE>
{
static constexpr auto ValueCount = 160;
static constexpr auto ValueCount = 161;
static constexpr size_t ActivationFunctionCount = 24;
};
@ -889,6 +889,12 @@ struct OperatorDescTraits<DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_DESC>
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY;
};
template <>
struct OperatorDescTraits<DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = (DML_OPERATOR_TYPE) DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT;
};
template <>
struct OperatorDescTraits<DML_CONVOLUTION_INTEGER_OPERATOR_DESC>
{
@ -1446,12 +1452,6 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_DEQUANTIZ
using DescType = DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_DESC;
};
template <>
struct OperatorDescTraits<DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = (DML_OPERATOR_TYPE) DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_CONVOLUTION>
{
@ -1890,6 +1890,12 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_QUANTIZED_LINEAR_MATRI
using DescType = DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT>
{
using DescType = DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_CONVOLUTION_INTEGER>
{
@ -2189,11 +2195,6 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_QUANTIZED
{
using DescType = DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD1_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT>
{
using DescType = DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MULTIHEAD_ATTENTION>
@ -2444,6 +2445,8 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
return std::invoke(std::forward<Visitor>(visitor), DML_MATRIX_MULTIPLY_INTEGER_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY:
return std::invoke(std::forward<Visitor>(visitor), DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT:
return std::invoke(std::forward<Visitor>(visitor), DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_CONVOLUTION_INTEGER:
return std::invoke(std::forward<Visitor>(visitor), DML_CONVOLUTION_INTEGER_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION:
@ -2546,8 +2549,6 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_GELU_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD1:
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD1_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT:
return std::invoke(std::forward<Visitor>(visitor), DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC{}, std::forward<Ts>(args)...);
default:
ORT_THROW_HR(E_INVALIDARG);

View file

@ -1825,25 +1825,6 @@ constexpr DML_OPERATOR_SCHEMA DML_MATRIX_MULTIPLY_INTEGER_OPERATOR_SCHEMA {
DML_MATRIX_MULTIPLY_INTEGER_OPERATOR_SCHEMA_FIELDS,
};
constexpr DML_SCHEMA_FIELD DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_FIELDS[8] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AScaleTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AZeroPointTensor", true },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BScaleTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BZeroPointTensor", true },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", true },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
};
constexpr DML_OPERATOR_SCHEMA DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA {
"DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT",
static_cast<DML_OPERATOR_TYPE>(DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT),
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
8,
DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_FIELDS,
};
constexpr DML_SCHEMA_FIELD DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA_FIELDS[9] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AScaleTensor", false },
@ -1864,6 +1845,25 @@ constexpr DML_OPERATOR_SCHEMA DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHE
DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA_FIELDS,
};
constexpr DML_SCHEMA_FIELD DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_FIELDS[8] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AScaleTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AZeroPointTensor", true },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BScaleTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BZeroPointTensor", true },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", true },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
};
constexpr DML_OPERATOR_SCHEMA DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA {
"DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT",
static_cast<DML_OPERATOR_TYPE>(DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT),
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
8,
DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_FIELDS,
};
constexpr DML_SCHEMA_FIELD DML_CONVOLUTION_INTEGER_OPERATOR_SCHEMA_FIELDS[11] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputZeroPointTensor", true },

View file

@ -1111,6 +1111,19 @@ inline std::vector<OperatorField> GetFields(const DML_QUANTIZED_LINEAR_MATRIX_MU
OperatorField(&DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
};
}
inline std::vector<OperatorField> GetFields(const DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC& desc)
{
return {
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.ATensor))),
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.AScaleTensor))),
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.AZeroPointTensor))),
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BTensor))),
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BScaleTensor))),
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BZeroPointTensor))),
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BiasTensor))),
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
};
}
inline std::vector<OperatorField> GetFields(const DML_CONVOLUTION_INTEGER_OPERATOR_DESC& desc)
{
return {
@ -1626,19 +1639,6 @@ inline std::vector<OperatorField> GetFields(const DML_ACTIVATION_SHRINK_OPERATOR
OperatorField(&DML_ACTIVATION_SHRINK_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<FLOAT>(desc.Threshold))),
};
}
inline std::vector<OperatorField> GetFields(const DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC& desc)
{
return {
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.ATensor))),
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.AScaleTensor))),
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.AZeroPointTensor))),
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BTensor))),
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BScaleTensor))),
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BZeroPointTensor))),
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BiasTensor))),
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
};
}
inline std::vector<OperatorField> GetFields(const DML_ACTIVATION_GELU_OPERATOR_DESC& desc)
{
return {
@ -1779,6 +1779,7 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType)
case DML_OPERATOR_RESAMPLE1: return DML_RESAMPLE1_OPERATOR_SCHEMA;
case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER: return DML_MATRIX_MULTIPLY_INTEGER_OPERATOR_SCHEMA;
case DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY: return DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA;
case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: return DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA;
case DML_OPERATOR_CONVOLUTION_INTEGER: return DML_CONVOLUTION_INTEGER_OPERATOR_SCHEMA;
case DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION: return DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA;
case DML_OPERATOR_ELEMENT_WISE_BIT_AND: return DML_ELEMENT_WISE_BIT_AND_OPERATOR_SCHEMA;
@ -2299,6 +2300,10 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
return AbstractOperatorDesc(
&DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT:
return AbstractOperatorDesc(
&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_CONVOLUTION_INTEGER:
return AbstractOperatorDesc(
&DML_CONVOLUTION_INTEGER_OPERATOR_SCHEMA,
@ -2379,10 +2384,6 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
return AbstractOperatorDesc(
&DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT:
return AbstractOperatorDesc(
&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_ROI_ALIGN_GRAD:
return AbstractOperatorDesc(
&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA,

View file

@ -0,0 +1,111 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
class DmlOperatorMatMulIntegerToFloat : public DmlOperator
{
enum InputTensors : uint32_t {
IN_A,
IN_B,
IN_A_SCALE,
IN_B_SCALE,
IN_A_ZERO_POINT,
IN_B_ZERO_POINT,
IN_BIAS,
IN_COUNT
};
enum DmlInputIndex : uint32_t
{
dmlA,
dmlAScale,
dmlAZeroPoint,
dmlB,
dmlBScale,
dmlBZeroPoint,
dmlBias,
dmlInputCount,
};
public:
DmlOperatorMatMulIntegerToFloat(const MLOperatorKernelCreationContext& kernelInfo)
: DmlOperator(kernelInfo)
{
std::vector<std::optional<uint32_t>> inputIndices = { InputTensors::IN_A, InputTensors::IN_A_SCALE, InputTensors::IN_A_ZERO_POINT, InputTensors::IN_B, InputTensors::IN_B_SCALE, InputTensors::IN_B_ZERO_POINT, InputTensors::IN_BIAS };
DmlOperator::Initialize(kernelInfo, inputIndices);
std::vector<DimensionType> inputShape0 = kernelInfo.GetTensorShapeDescription().GetInputTensorShape(InputTensors::IN_A);
std::vector<DimensionType> inputShape1 = kernelInfo.GetTensorShapeDescription().GetInputTensorShape(InputTensors::IN_B);
std::vector<DimensionType> outputShape = kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0);
OperatorHelper::MatMulShapeMapping(inputShape0, inputShape1, outputShape);
// Initialize the input descriptions with broadcasting
m_inputTensorDescs[DmlInputIndex::dmlA] = CreateTensorDescFromInput(kernelInfo, InputTensors::IN_A, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, inputShape0);
m_inputTensorDescs[DmlInputIndex::dmlB] = CreateTensorDescFromInput(kernelInfo, InputTensors::IN_B, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, inputShape1);
// Broadcast Bias tensor to the shape of the output tensor.
if(kernelInfo.IsInputValid(InputTensors::IN_BIAS)) {
m_inputTensorDescs[DmlInputIndex::dmlBias] = CreateTensorDescFromInput(kernelInfo, InputTensors::IN_BIAS, TensorAxis::DoNotCoerce,
TensorAxis::W, TensorAxis::RightAligned, outputShape);
}
uint32_t dmlDimSize = m_inputTensorDescs[DmlInputIndex::dmlA].GetDimensionCount();
// Resize the A Scale to be the same dimension as the input tensor.
// The 1D tensor needs to be moved to the H channel.
m_inputTensorDescs[DmlInputIndex::dmlAScale] = CreateTensorDescFromInput(
kernelInfo,
InputTensors::IN_A_SCALE,
TensorAxis::DoNotCoerce,
TensorAxis::H,
TensorAxis::LeftAligned,
std::nullopt,
dmlDimSize
);
// Resize the A ZeroPoint to be the same dimension as the input tensor.
// The 1D tensor needs to be moved to the H channel.
if(kernelInfo.IsInputValid(InputTensors::IN_A_ZERO_POINT)) {
m_inputTensorDescs[DmlInputIndex::dmlAZeroPoint] = CreateTensorDescFromInput(
kernelInfo,
InputTensors::IN_A_ZERO_POINT,
TensorAxis::DoNotCoerce,
TensorAxis::H,
TensorAxis::LeftAligned,
std::nullopt,
dmlDimSize
);
}
// B Zeropoint and BScale are already aligned in the W dimension so no need to align them
// Initialize the output description while overriding the shape
m_outputTensorDescs[0] = CreateTensorDescFromOutput(kernelInfo, 0, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, outputShape);
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC matMulDesc = {};
matMulDesc.ATensor = &inputDescs[DmlInputIndex::dmlA];
matMulDesc.AScaleTensor = &inputDescs[DmlInputIndex::dmlAScale];
matMulDesc.AZeroPointTensor = inputDescs[DmlInputIndex::dmlAZeroPoint].Desc != nullptr ? &inputDescs[DmlInputIndex::dmlAZeroPoint] : nullptr;
matMulDesc.BTensor = &inputDescs[DmlInputIndex::dmlB];
matMulDesc.BScaleTensor = &inputDescs[DmlInputIndex::dmlBScale];
matMulDesc.BZeroPointTensor = inputDescs[DmlInputIndex::dmlBZeroPoint].Desc != nullptr ? &inputDescs[DmlInputIndex::dmlBZeroPoint] : nullptr;
matMulDesc.BiasTensor = inputDescs[DmlInputIndex::dmlBias].Desc != nullptr ? &inputDescs[DmlInputIndex::dmlBias] : nullptr;
matMulDesc.OutputTensor = &outputDescs[0];
DML_OPERATOR_DESC opDesc = { (DML_OPERATOR_TYPE) DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT, &matMulDesc };
SetDmlOperatorDesc(opDesc, kernelInfo);
}
};
DML_OP_DEFINE_CREATION_FUNCTION(MatMulIntegerToFloat, DmlOperatorMatMulIntegerToFloat);
} // namespace Dml

View file

@ -436,6 +436,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(QLinearConv);
DML_OP_EXTERN_CREATION_FUNCTION(QLinearMatMul);
DML_OP_EXTERN_CREATION_FUNCTION(DynamicQuantizeLinear);
DML_OP_EXTERN_CREATION_FUNCTION(MatMulInteger);
DML_OP_EXTERN_CREATION_FUNCTION(MatMulIntegerToFloat);
DML_OP_EXTERN_CREATION_FUNCTION(ConvInteger);
DML_OP_EXTERN_CREATION_FUNCTION(Trilu);
DML_OP_EXTERN_CREATION_FUNCTION(Shape);
@ -543,6 +544,13 @@ constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListQLinea
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8
};
constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListMatMulIntegerToFloat = {
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
SupportedTensorDataTypes::Float32
};
constexpr static std::array<SupportedTensorDataTypes, 4> supportedTypeListQLinearConv = {
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
@ -968,6 +976,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 10, QLinearConv, typeNameListFour, supportedTypeListQLinearConv, DmlGraphSupport::Supported)},
{REG_INFO( 10, QLinearMatMul, typeNameListThree, supportedTypeListQLinearMatMul, DmlGraphSupport::Supported)},
{REG_INFO( 10, MatMulInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, MatMulIntegerToFloat, typeNameListThree, supportedTypeListMatMulIntegerToFloat, DmlGraphSupport::Supported)},
{REG_INFO( 10, ConvInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)},
{REG_INFO( 11, DynamicQuantizeLinear, typeNameListTwo, supportedTypeListDynamicQuantizeLinear, DmlGraphSupport::Supported)},
{REG_INFO( 7, LayerNormalization, typeNameListLayerNormContrib, supportedTypeListLayerNormalizationContrib, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryLayerNormalization)},

View file

@ -822,6 +822,13 @@ public:
QLinearMatMulHelper(const Info_t& info, const Shape_t& shape) : MatMulHelperBase(info, shape, 0, 3) {}
};
class MatMulIntegerToFloatHelper : public MatMulHelperBase
{
public:
template<typename Info_t, typename Shape_t>
MatMulIntegerToFloatHelper(const Info_t& info, const Shape_t& shape) : MatMulHelperBase(info, shape, 0, 1) {}
};
class TopKHelper
{
@ -1657,6 +1664,7 @@ using ShapeInferenceHelper_Identity16 = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_MatMul = MatMulHelper;
using ShapeInferenceHelper_MatMulInteger = MatMulHelper;
using ShapeInferenceHelper_DynamicQuantizeMatMul = MatMulHelper;
using ShapeInferenceHelper_MatMulIntegerToFloat = MatMulIntegerToFloatHelper;
using ShapeInferenceHelper_QLinearMatMul = QLinearMatMulHelper;
using ShapeInferenceHelper_QLinearAdd = GetBroadcastedOutputShapeHelper;
using ShapeInferenceHelper_DynamicQuantizeLinear = GetOutputShapeAsInputShapeHelper;

View file

@ -429,6 +429,7 @@ namespace OperatorHelper
static const int sc_sinceVer_FusedMatMulActivation = 1;
static const int sc_sinceVer_QLinearSigmoid = 1;
static const int sc_sinceVer_Attention = 1;
static const int sc_sinceVer_MatMulIntegerToFloat = 1;
static const int sc_sinceVer_MultiHeadAttention = 1;
static const int sc_sinceVer_SkipLayerNormalization = 1;
static const int sc_sinceVer_EmbedLayerNormalization = 1;

View file

@ -88,7 +88,12 @@ void TestMatMulIntegerToFloat(const std::vector<int64_t>& A_dims,
}
test.AddReferenceOutputs(reference_model);
#if defined(USE_DML)
test.SetOutputRelErr("Y", 2e-2f);
#else
test.SetOutputRelErr("Y", 1e-4f);
#endif
test.Run();
}