mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-25 02:50:42 +00:00
MatrixMultiplyIntegerToFloat (#16804)
### Description
Implementation for [com.microsoft.MatMulIntegerToFloat](https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.MatMulIntegerToFloat)

```
C:\workspace\ORT\onnxruntime\build\test\RelWithDebInfo\RelWithDebInfo>.\onnxruntime_test_all.exe --gtest_filter="*MatMulIntegerToFloat*"
Note: Google Test filter = *MatMulIntegerToFloat*
[==========] Running 8 tests from 4 test suites.
[----------] Global test environment set-up.
[----------] 1 test from CPU_U8S8_Precision_Tests
[ RUN ] CPU_U8S8_Precision_Tests.MatMulIntegerToFloat
[ OK ] CPU_U8S8_Precision_Tests.MatMulIntegerToFloat (31 ms)
[----------] 1 test from CPU_U8S8_Precision_Tests (31 ms total)
[----------] 1 test from GraphTransformationTests
[ RUN ] GraphTransformationTests.MatMulIntegerToFloatTest
[ OK ] GraphTransformationTests.MatMulIntegerToFloatTest (0 ms)
[----------] 1 test from GraphTransformationTests (1 ms total)
[----------] 1 test from QDQTransformerTests
[ RUN ] QDQTransformerTests.MatMulIntegerToFloat
[ OK ] QDQTransformerTests.MatMulIntegerToFloat (12 ms)
[----------] 1 test from QDQTransformerTests (12 ms total)
[----------] 5 tests from MatMulIntegerToFloat
[ RUN ] MatMulIntegerToFloat.HasZeroPoint_NoBias_test_U8X8
[ OK ] MatMulIntegerToFloat.HasZeroPoint_NoBias_test_U8X8 (801 ms)
[ RUN ] MatMulIntegerToFloat.NoZeroPoint_HasBias_test_U8X8
[ OK ] MatMulIntegerToFloat.NoZeroPoint_HasBias_test_U8X8 (804 ms)
[ RUN ] MatMulIntegerToFloat.HasZeroPoint_NoBias_test_S8S8
[ OK ] MatMulIntegerToFloat.HasZeroPoint_NoBias_test_S8S8 (378 ms)
[ RUN ] MatMulIntegerToFloat.NoZeroPoint_HasBias_test_S8S8
[ OK ] MatMulIntegerToFloat.NoZeroPoint_HasBias_test_S8S8 (382 ms)
[ RUN ] MatMulIntegerToFloat.MatMulInteger_With_ZeroPoint
[ OK ] MatMulIntegerToFloat.MatMulInteger_With_ZeroPoint (31 ms)
[----------] 5 tests from MatMulIntegerToFloat (2398 ms total)
[----------] Global test environment tear-down
[==========] 8 tests from 4 test suites ran. (2455 ms total)
[ PASSED ] 8 tests.
```
This commit is contained in:
parent
4336d6422d
commit
14298e9c02
8 changed files with 186 additions and 50 deletions
|
|
@ -52,7 +52,7 @@ struct EnumTraits<DML_TENSOR_TYPE>
|
|||
template <>
|
||||
struct EnumTraits<DML_OPERATOR_TYPE>
|
||||
{
|
||||
static constexpr auto ValueCount = 160;
|
||||
static constexpr auto ValueCount = 161;
|
||||
static constexpr size_t ActivationFunctionCount = 24;
|
||||
};
|
||||
|
||||
|
|
@ -889,6 +889,12 @@ struct OperatorDescTraits<DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_DESC>
|
|||
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = (DML_OPERATOR_TYPE) DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_CONVOLUTION_INTEGER_OPERATOR_DESC>
|
||||
{
|
||||
|
|
@ -1446,12 +1452,6 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_DEQUANTIZ
|
|||
using DescType = DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorDescTraits<DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC>
|
||||
{
|
||||
static constexpr DML_OPERATOR_TYPE Type = (DML_OPERATOR_TYPE) DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_CONVOLUTION>
|
||||
{
|
||||
|
|
@ -1890,6 +1890,12 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_QUANTIZED_LINEAR_MATRI
|
|||
using DescType = DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT>
|
||||
{
|
||||
using DescType = DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_CONVOLUTION_INTEGER>
|
||||
{
|
||||
|
|
@ -2189,11 +2195,6 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_QUANTIZED
|
|||
{
|
||||
using DescType = DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD1_OPERATOR_DESC;
|
||||
};
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT>
|
||||
{
|
||||
using DescType = DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MULTIHEAD_ATTENTION>
|
||||
|
|
@ -2444,6 +2445,8 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
|
|||
return std::invoke(std::forward<Visitor>(visitor), DML_MATRIX_MULTIPLY_INTEGER_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_CONVOLUTION_INTEGER:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_CONVOLUTION_INTEGER_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION:
|
||||
|
|
@ -2546,8 +2549,6 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
|
|||
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_GELU_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD1:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD1_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT:
|
||||
return std::invoke(std::forward<Visitor>(visitor), DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC{}, std::forward<Ts>(args)...);
|
||||
|
||||
default:
|
||||
ORT_THROW_HR(E_INVALIDARG);
|
||||
|
|
|
|||
|
|
@ -1825,25 +1825,6 @@ constexpr DML_OPERATOR_SCHEMA DML_MATRIX_MULTIPLY_INTEGER_OPERATOR_SCHEMA {
|
|||
DML_MATRIX_MULTIPLY_INTEGER_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
// Field table for the MatMulIntegerToFloat operator schema: seven input tensor
// descs followed by one output tensor desc. The trailing bool is true only for
// the two zero-point tensors and the bias (presumably the "optional" flag —
// confirm against the DML_SCHEMA_FIELD declaration).
constexpr DML_SCHEMA_FIELD DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_FIELDS[8] = {
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor",          false },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AScaleTensor",     false },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AZeroPointTensor", true },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor",          false },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BScaleTensor",     false },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BZeroPointTensor", true },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor",       true },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor",     false },
};
|
||||
|
||||
// Schema record for the MatMulIntegerToFloat operator: its name string, type
// value, support flags, and the field table describing its tensors.
constexpr DML_OPERATOR_SCHEMA DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA = {
    "DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT",
    static_cast<DML_OPERATOR_TYPE>(DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT),
    DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
    8, // matches the length of the field table named on the next line
    DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_FIELDS,
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA_FIELDS[9] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AScaleTensor", false },
|
||||
|
|
@ -1864,6 +1845,25 @@ constexpr DML_OPERATOR_SCHEMA DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHE
|
|||
DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA_FIELDS,
|
||||
};
|
||||
|
||||
// Field table for the MatMulIntegerToFloat operator schema: seven input tensor
// descs followed by one output tensor desc. The trailing bool is true only for
// the two zero-point tensors and the bias (presumably the "optional" flag —
// confirm against the DML_SCHEMA_FIELD declaration).
constexpr DML_SCHEMA_FIELD DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_FIELDS[8] = {
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor",          false },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AScaleTensor",     false },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AZeroPointTensor", true },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor",          false },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BScaleTensor",     false },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BZeroPointTensor", true },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor",       true },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor",     false },
};
|
||||
|
||||
// Schema record for the MatMulIntegerToFloat operator: its name string, type
// value, support flags, and the field table describing its tensors.
constexpr DML_OPERATOR_SCHEMA DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA = {
    "DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT",
    static_cast<DML_OPERATOR_TYPE>(DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT),
    DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
    8, // matches the length of the field table named on the next line
    DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_FIELDS,
};
|
||||
|
||||
constexpr DML_SCHEMA_FIELD DML_CONVOLUTION_INTEGER_OPERATOR_SCHEMA_FIELDS[11] {
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
|
||||
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputZeroPointTensor", true },
|
||||
|
|
|
|||
|
|
@ -1111,6 +1111,19 @@ inline std::vector<OperatorField> GetFields(const DML_QUANTIZED_LINEAR_MATRIX_MU
|
|||
OperatorField(&DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.ATensor))),
|
||||
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.AScaleTensor))),
|
||||
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.AZeroPointTensor))),
|
||||
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BTensor))),
|
||||
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BScaleTensor))),
|
||||
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BZeroPointTensor))),
|
||||
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BiasTensor))),
|
||||
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_CONVOLUTION_INTEGER_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
|
|
@ -1626,19 +1639,6 @@ inline std::vector<OperatorField> GetFields(const DML_ACTIVATION_SHRINK_OPERATOR
|
|||
OperatorField(&DML_ACTIVATION_SHRINK_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<FLOAT>(desc.Threshold))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.ATensor))),
|
||||
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.AScaleTensor))),
|
||||
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.AZeroPointTensor))),
|
||||
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BTensor))),
|
||||
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BScaleTensor))),
|
||||
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BZeroPointTensor))),
|
||||
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BiasTensor))),
|
||||
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
|
||||
};
|
||||
}
|
||||
inline std::vector<OperatorField> GetFields(const DML_ACTIVATION_GELU_OPERATOR_DESC& desc)
|
||||
{
|
||||
return {
|
||||
|
|
@ -1779,6 +1779,7 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType)
|
|||
case DML_OPERATOR_RESAMPLE1: return DML_RESAMPLE1_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER: return DML_MATRIX_MULTIPLY_INTEGER_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY: return DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: return DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_CONVOLUTION_INTEGER: return DML_CONVOLUTION_INTEGER_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION: return DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA;
|
||||
case DML_OPERATOR_ELEMENT_WISE_BIT_AND: return DML_ELEMENT_WISE_BIT_AND_OPERATOR_SCHEMA;
|
||||
|
|
@ -2299,6 +2300,10 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
|
|||
return AbstractOperatorDesc(
|
||||
&DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_CONVOLUTION_INTEGER:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_CONVOLUTION_INTEGER_OPERATOR_SCHEMA,
|
||||
|
|
@ -2379,10 +2384,6 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
|
|||
return AbstractOperatorDesc(
|
||||
&DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA,
|
||||
GetFields(*static_cast<const DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC*>(opDesc.Desc)));
|
||||
case DML_OPERATOR_ROI_ALIGN_GRAD:
|
||||
return AbstractOperatorDesc(
|
||||
&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,111 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "precomp.h"
|
||||
|
||||
namespace Dml
|
||||
{
|
||||
|
||||
class DmlOperatorMatMulIntegerToFloat : public DmlOperator
|
||||
{
|
||||
enum InputTensors : uint32_t {
|
||||
IN_A,
|
||||
IN_B,
|
||||
IN_A_SCALE,
|
||||
IN_B_SCALE,
|
||||
IN_A_ZERO_POINT,
|
||||
IN_B_ZERO_POINT,
|
||||
IN_BIAS,
|
||||
IN_COUNT
|
||||
};
|
||||
|
||||
enum DmlInputIndex : uint32_t
|
||||
{
|
||||
dmlA,
|
||||
dmlAScale,
|
||||
dmlAZeroPoint,
|
||||
dmlB,
|
||||
dmlBScale,
|
||||
dmlBZeroPoint,
|
||||
dmlBias,
|
||||
dmlInputCount,
|
||||
};
|
||||
|
||||
public:
|
||||
DmlOperatorMatMulIntegerToFloat(const MLOperatorKernelCreationContext& kernelInfo)
|
||||
: DmlOperator(kernelInfo)
|
||||
{
|
||||
std::vector<std::optional<uint32_t>> inputIndices = { InputTensors::IN_A, InputTensors::IN_A_SCALE, InputTensors::IN_A_ZERO_POINT, InputTensors::IN_B, InputTensors::IN_B_SCALE, InputTensors::IN_B_ZERO_POINT, InputTensors::IN_BIAS };
|
||||
DmlOperator::Initialize(kernelInfo, inputIndices);
|
||||
|
||||
std::vector<DimensionType> inputShape0 = kernelInfo.GetTensorShapeDescription().GetInputTensorShape(InputTensors::IN_A);
|
||||
std::vector<DimensionType> inputShape1 = kernelInfo.GetTensorShapeDescription().GetInputTensorShape(InputTensors::IN_B);
|
||||
std::vector<DimensionType> outputShape = kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0);
|
||||
|
||||
OperatorHelper::MatMulShapeMapping(inputShape0, inputShape1, outputShape);
|
||||
|
||||
// Initialize the input descriptions with broadcasting
|
||||
m_inputTensorDescs[DmlInputIndex::dmlA] = CreateTensorDescFromInput(kernelInfo, InputTensors::IN_A, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, inputShape0);
|
||||
m_inputTensorDescs[DmlInputIndex::dmlB] = CreateTensorDescFromInput(kernelInfo, InputTensors::IN_B, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, inputShape1);
|
||||
|
||||
// Broadcast Bias tensor to the shape of the output tensor.
|
||||
if(kernelInfo.IsInputValid(InputTensors::IN_BIAS)) {
|
||||
|
||||
m_inputTensorDescs[DmlInputIndex::dmlBias] = CreateTensorDescFromInput(kernelInfo, InputTensors::IN_BIAS, TensorAxis::DoNotCoerce,
|
||||
TensorAxis::W, TensorAxis::RightAligned, outputShape);
|
||||
}
|
||||
|
||||
uint32_t dmlDimSize = m_inputTensorDescs[DmlInputIndex::dmlA].GetDimensionCount();
|
||||
// Resize the A Scale to be the same dimension as the input tensor.
|
||||
// The 1D tensor needs to be moved to the H channel.
|
||||
m_inputTensorDescs[DmlInputIndex::dmlAScale] = CreateTensorDescFromInput(
|
||||
kernelInfo,
|
||||
InputTensors::IN_A_SCALE,
|
||||
TensorAxis::DoNotCoerce,
|
||||
TensorAxis::H,
|
||||
TensorAxis::LeftAligned,
|
||||
std::nullopt,
|
||||
dmlDimSize
|
||||
);
|
||||
|
||||
// Resize the A ZeroPoint to be the same dimension as the input tensor.
|
||||
// The 1D tensor needs to be moved to the H channel.
|
||||
if(kernelInfo.IsInputValid(InputTensors::IN_A_ZERO_POINT)) {
|
||||
|
||||
m_inputTensorDescs[DmlInputIndex::dmlAZeroPoint] = CreateTensorDescFromInput(
|
||||
kernelInfo,
|
||||
InputTensors::IN_A_ZERO_POINT,
|
||||
TensorAxis::DoNotCoerce,
|
||||
TensorAxis::H,
|
||||
TensorAxis::LeftAligned,
|
||||
std::nullopt,
|
||||
dmlDimSize
|
||||
);
|
||||
}
|
||||
|
||||
// B Zeropoint and BScale are already aligned in the W dimension so no need to align them
|
||||
|
||||
// Initialize the output description while overriding the shape
|
||||
m_outputTensorDescs[0] = CreateTensorDescFromOutput(kernelInfo, 0, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, outputShape);
|
||||
|
||||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
|
||||
DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC matMulDesc = {};
|
||||
matMulDesc.ATensor = &inputDescs[DmlInputIndex::dmlA];
|
||||
matMulDesc.AScaleTensor = &inputDescs[DmlInputIndex::dmlAScale];
|
||||
matMulDesc.AZeroPointTensor = inputDescs[DmlInputIndex::dmlAZeroPoint].Desc != nullptr ? &inputDescs[DmlInputIndex::dmlAZeroPoint] : nullptr;
|
||||
matMulDesc.BTensor = &inputDescs[DmlInputIndex::dmlB];
|
||||
matMulDesc.BScaleTensor = &inputDescs[DmlInputIndex::dmlBScale];
|
||||
matMulDesc.BZeroPointTensor = inputDescs[DmlInputIndex::dmlBZeroPoint].Desc != nullptr ? &inputDescs[DmlInputIndex::dmlBZeroPoint] : nullptr;
|
||||
matMulDesc.BiasTensor = inputDescs[DmlInputIndex::dmlBias].Desc != nullptr ? &inputDescs[DmlInputIndex::dmlBias] : nullptr;
|
||||
matMulDesc.OutputTensor = &outputDescs[0];
|
||||
|
||||
DML_OPERATOR_DESC opDesc = { (DML_OPERATOR_TYPE) DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT, &matMulDesc };
|
||||
SetDmlOperatorDesc(opDesc, kernelInfo);
|
||||
}
|
||||
};
|
||||
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(MatMulIntegerToFloat, DmlOperatorMatMulIntegerToFloat);
|
||||
|
||||
} // namespace Dml
|
||||
|
|
@ -436,6 +436,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(QLinearConv);
|
|||
DML_OP_EXTERN_CREATION_FUNCTION(QLinearMatMul);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(DynamicQuantizeLinear);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(MatMulInteger);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(MatMulIntegerToFloat);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(ConvInteger);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Trilu);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Shape);
|
||||
|
|
@ -543,6 +544,13 @@ constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListQLinea
|
|||
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
|
||||
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8
|
||||
};
|
||||
|
||||
// Supported tensor data types for MatMulIntegerToFloat. It is registered together
// with typeNameListThree, so presumably one entry per type name in that list:
// two 8-bit integer slots followed by a Float32 slot.
static constexpr std::array<SupportedTensorDataTypes, 3> supportedTypeListMatMulIntegerToFloat =
{
    SupportedTensorDataTypes::Int8 | SupportedTensorDataTypes::UInt8,
    SupportedTensorDataTypes::Int8 | SupportedTensorDataTypes::UInt8,
    SupportedTensorDataTypes::Float32,
};
|
||||
|
||||
constexpr static std::array<SupportedTensorDataTypes, 4> supportedTypeListQLinearConv = {
|
||||
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
|
||||
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
|
||||
|
|
@ -968,6 +976,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
|
|||
{REG_INFO( 10, QLinearConv, typeNameListFour, supportedTypeListQLinearConv, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 10, QLinearMatMul, typeNameListThree, supportedTypeListQLinearMatMul, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 10, MatMulInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MS( 1, MatMulIntegerToFloat, typeNameListThree, supportedTypeListMatMulIntegerToFloat, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 10, ConvInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, DynamicQuantizeLinear, typeNameListTwo, supportedTypeListDynamicQuantizeLinear, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 7, LayerNormalization, typeNameListLayerNormContrib, supportedTypeListLayerNormalizationContrib, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryLayerNormalization)},
|
||||
|
|
|
|||
|
|
@ -822,6 +822,13 @@ public:
|
|||
QLinearMatMulHelper(const Info_t& info, const Shape_t& shape) : MatMulHelperBase(info, shape, 0, 3) {}
|
||||
};
|
||||
|
||||
class MatMulIntegerToFloatHelper : public MatMulHelperBase
|
||||
{
|
||||
public:
|
||||
template<typename Info_t, typename Shape_t>
|
||||
MatMulIntegerToFloatHelper(const Info_t& info, const Shape_t& shape) : MatMulHelperBase(info, shape, 0, 1) {}
|
||||
};
|
||||
|
||||
|
||||
class TopKHelper
|
||||
{
|
||||
|
|
@ -1657,6 +1664,7 @@ using ShapeInferenceHelper_Identity16 = GetOutputShapeAsInputShapeHelper;
|
|||
using ShapeInferenceHelper_MatMul = MatMulHelper;
|
||||
using ShapeInferenceHelper_MatMulInteger = MatMulHelper;
|
||||
using ShapeInferenceHelper_DynamicQuantizeMatMul = MatMulHelper;
|
||||
using ShapeInferenceHelper_MatMulIntegerToFloat = MatMulIntegerToFloatHelper;
|
||||
using ShapeInferenceHelper_QLinearMatMul = QLinearMatMulHelper;
|
||||
using ShapeInferenceHelper_QLinearAdd = GetBroadcastedOutputShapeHelper;
|
||||
using ShapeInferenceHelper_DynamicQuantizeLinear = GetOutputShapeAsInputShapeHelper;
|
||||
|
|
|
|||
|
|
@ -429,6 +429,7 @@ namespace OperatorHelper
|
|||
static const int sc_sinceVer_FusedMatMulActivation = 1;
|
||||
static const int sc_sinceVer_QLinearSigmoid = 1;
|
||||
static const int sc_sinceVer_Attention = 1;
|
||||
static const int sc_sinceVer_MatMulIntegerToFloat = 1;
|
||||
static const int sc_sinceVer_MultiHeadAttention = 1;
|
||||
static const int sc_sinceVer_SkipLayerNormalization = 1;
|
||||
static const int sc_sinceVer_EmbedLayerNormalization = 1;
|
||||
|
|
|
|||
|
|
@ -88,7 +88,12 @@ void TestMatMulIntegerToFloat(const std::vector<int64_t>& A_dims,
|
|||
}
|
||||
|
||||
test.AddReferenceOutputs(reference_model);
|
||||
#if defined(USE_DML)
|
||||
test.SetOutputRelErr("Y", 2e-2f);
|
||||
#else
|
||||
test.SetOutputRelErr("Y", 1e-4f);
|
||||
#endif
|
||||
|
||||
test.Run();
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue