Add QLinearConcat for DML EP (#16971)

This commit is contained in:
Xiang Zhang 2023-08-17 15:15:25 -07:00 committed by GitHub
parent 4c8ef4080d
commit d3345f3680
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 286 additions and 14 deletions

View file

@ -0,0 +1,236 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
// QLinearConcat = Dequantize + Join + Quantize
class DmlOperatorQLinearConcat : public DmlOperator, public QLinearConcatHelper
{
// This order matches the ONNX schema.
enum OnnxInputIndex
{
YScale,
YZeroPoint,
Count,
};
public:
DmlOperatorQLinearConcat(const MLOperatorKernelCreationContext& kernelCreationContext)
: DmlOperator(kernelCreationContext),
QLinearConcatHelper(kernelCreationContext, kernelCreationContext.GetTensorShapeDescription())
{
DmlOperator::Initialize(kernelCreationContext);
auto outputShape = kernelCreationContext.GetTensorShapeDescription().GetOutputTensorShape(0);
// inputs: {y_scale, y_zero_point, tuple(x_tensor, x_scale, x_zero_point)}
uint32_t inputDefinitionCount = kernelCreationContext.GetInputCount();
ML_CHECK_VALID_ARGUMENT(inputDefinitionCount >= 5, "Require at least 5 inputs.");
ML_CHECK_VALID_ARGUMENT((inputDefinitionCount - 2) % 3 == 0, "Each input must be (tensor, scale, zero_point) tuple!");
uint32_t inputCount = (inputDefinitionCount - 2) / 3;
auto yScaleDataType = kernelCreationContext.GetInputEdgeDescription(OnnxInputIndex::YScale).tensorDataType;
auto yZeroPointDataType = kernelCreationContext.GetInputEdgeDescription(OnnxInputIndex::YZeroPoint).tensorDataType;
// broadcast y_scale and y_zero_point to output shape
m_inputTensorDescs[OnnxInputIndex::YScale] = TensorDesc(
yScaleDataType,
outputShape,
kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(OnnxInputIndex::YScale),
TensorAxis::DoNotCoerce,
TensorAxis::W,
TensorAxis::RightAligned,
NchwDimensionCount, // minDimensionCount
0 // guaranteedBaseOffsetAlignment
);
m_inputTensorDescs[OnnxInputIndex::YZeroPoint] = TensorDesc(
yZeroPointDataType,
outputShape,
kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(OnnxInputIndex::YZeroPoint),
TensorAxis::DoNotCoerce,
TensorAxis::W,
TensorAxis::RightAligned,
NchwDimensionCount, // minDimensionCount
0 // guaranteedBaseOffsetAlignment
);
// Validate input tensors
for (uint32_t inputIndex = 0; inputIndex < inputCount; ++inputIndex)
{
// Inputs(input tensor, scale, zero_point) are in tuple and starting from index 2
auto tupleStartIndex = 2 + inputIndex * 3;
auto xScaleDataType = kernelCreationContext.GetInputEdgeDescription(tupleStartIndex + 1).tensorDataType;
auto xZeroPointDataType = kernelCreationContext.GetInputEdgeDescription(tupleStartIndex + 2).tensorDataType;
ML_CHECK_VALID_ARGUMENT(xScaleDataType == yScaleDataType, "Wrong input type encountered for scale");
ML_CHECK_VALID_ARGUMENT(xZeroPointDataType == yZeroPointDataType, "Wrong input type encountered for zero point");
// broadcast x_scale and x_zero_point to shape of corresponding x
m_inputTensorDescs[tupleStartIndex + 1] = TensorDesc(
xScaleDataType,
kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex),
kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex + 1),
TensorAxis::DoNotCoerce,
TensorAxis::W,
TensorAxis::RightAligned,
NchwDimensionCount, // minDimensionCount
0 // guaranteedBaseOffsetAlignment
);
m_inputTensorDescs[tupleStartIndex + 2] = TensorDesc(
xZeroPointDataType,
kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex),
kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex + 2),
TensorAxis::DoNotCoerce,
TensorAxis::W,
TensorAxis::RightAligned,
NchwDimensionCount, // minDimensionCount
0 // guaranteedBaseOffsetAlignment
);
}
uint32_t dmlAxis = GetDmlAdjustedAxis(m_axis, kernelCreationContext, m_inputTensorDescs.front().GetDimensionCount(), 2);
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
// 1. output edges between Dequantize and Join node
// 2. input edge between Join and Quantize node
std::vector<TensorDesc> intermediateOutputTensorDescs(inputCount);
std::vector<DML_TENSOR_DESC> namedDequantizeOperatorDescs(inputCount);
std::vector<DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_DESC> dequantizeOperatorDescs(inputCount);
std::vector<DML_OPERATOR_DESC> dmlOpDesc(inputCount);
std::vector<const DML_OPERATOR_DESC*> opDescs;
for (uint32_t inputIndex = 0; inputIndex < inputCount; ++inputIndex)
{
auto tupleStartIndex = 2 + inputIndex * 3;
intermediateOutputTensorDescs[inputIndex] = TensorDesc(
MLOperatorTensorDataType::Float,
kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex),
kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex),
TensorAxis::DoNotCoerce,
TensorAxis::W,
TensorAxis::RightAligned,
NchwDimensionCount, // minDimensionCount
0 // guaranteedBaseOffsetAlignment)
);
namedDequantizeOperatorDescs[inputIndex] = intermediateOutputTensorDescs[inputIndex].GetDmlDesc();
dequantizeOperatorDescs[inputIndex].InputTensor = &inputDescs[tupleStartIndex];
dequantizeOperatorDescs[inputIndex].ScaleTensor = &inputDescs[tupleStartIndex + 1];
dequantizeOperatorDescs[inputIndex].ZeroPointTensor = &inputDescs[tupleStartIndex + 2];
dequantizeOperatorDescs[inputIndex].OutputTensor = &namedDequantizeOperatorDescs[inputIndex];
dmlOpDesc[inputIndex] = {DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR, &dequantizeOperatorDescs[inputIndex]};
opDescs.push_back(&dmlOpDesc[inputIndex]);
}
TensorDesc joinOutputTensorDesc = TensorDesc(
MLOperatorTensorDataType::Float,
outputShape,
outputShape,
TensorAxis::DoNotCoerce,
TensorAxis::W,
TensorAxis::RightAligned,
NchwDimensionCount, // minDimensionCount
0 // guaranteedBaseOffsetAlignment
);
DML_TENSOR_DESC namedJoinOutputTensorDesc = joinOutputTensorDesc.GetDmlDesc();
DML_JOIN_OPERATOR_DESC joinDesc = {};
joinDesc.InputCount = gsl::narrow_cast<uint32_t>(namedDequantizeOperatorDescs.size());
joinDesc.InputTensors = namedDequantizeOperatorDescs.data();
joinDesc.OutputTensor = &namedJoinOutputTensorDesc;
joinDesc.Axis = dmlAxis;
const DML_OPERATOR_DESC opJoinDesc = {DML_OPERATOR_JOIN, &joinDesc};
opDescs.push_back(&opJoinDesc);
DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_DESC quantizeOperatorDesc = {};
quantizeOperatorDesc.InputTensor = joinDesc.OutputTensor;
quantizeOperatorDesc.ScaleTensor = &inputDescs[OnnxInputIndex::YScale];
quantizeOperatorDesc.ZeroPointTensor = &inputDescs[OnnxInputIndex::YZeroPoint];
quantizeOperatorDesc.OutputTensor = &outputDescs[0];
const DML_OPERATOR_DESC opQuantizeDesc = {DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR, &quantizeOperatorDesc};
opDescs.push_back(&opQuantizeDesc);
MLOperatorGraphDesc operatorGraphDesc = {};
operatorGraphDesc.nodeCount = static_cast<uint32_t>(opDescs.size());
operatorGraphDesc.nodesAsOpDesc = opDescs.data();
uint32_t joinNodeIndex = operatorGraphDesc.nodeCount - 2;
uint32_t quantizeNodeIndex = operatorGraphDesc.nodeCount - 1;
std::vector<DML_INPUT_GRAPH_EDGE_DESC> inputEdges;
// Input edges to Dequantize nodes
for (uint32_t inputIndex = 0; inputIndex < inputCount; ++inputIndex)
{
auto tupleStartIndex = 2 + inputIndex * 3;
for (auto edge_index = 0; edge_index < 3; ++edge_index)
{
DML_INPUT_GRAPH_EDGE_DESC inputEdge = {};
inputEdge.GraphInputIndex = tupleStartIndex + edge_index;
inputEdge.ToNodeIndex = inputIndex;
inputEdge.ToNodeInputIndex = edge_index;
inputEdges.push_back(inputEdge);
}
}
// Input edge from y_scale to quantize node
DML_INPUT_GRAPH_EDGE_DESC yScaleInputEdge = {};
yScaleInputEdge.GraphInputIndex = 0; // Y_scale
yScaleInputEdge.ToNodeIndex = quantizeNodeIndex;
yScaleInputEdge.ToNodeInputIndex = 1;
inputEdges.push_back(yScaleInputEdge);
// Input edge from y_zero_point to quantize node
DML_INPUT_GRAPH_EDGE_DESC yZeroPointInputEdge = {};
yZeroPointInputEdge.GraphInputIndex = 1; // Y_zero_point
yZeroPointInputEdge.ToNodeIndex = quantizeNodeIndex;
yZeroPointInputEdge.ToNodeInputIndex = 2;
inputEdges.push_back(yZeroPointInputEdge);
operatorGraphDesc.inputEdgeCount = gsl::narrow_cast<uint32_t>(inputEdges.size());
operatorGraphDesc.inputEdges = inputEdges.data();
// set intermediate edges
std::vector<DML_INTERMEDIATE_GRAPH_EDGE_DESC> intermediateEdges;
for (uint32_t inputIndex = 0; inputIndex < inputCount; ++inputIndex)
{
DML_INTERMEDIATE_GRAPH_EDGE_DESC dequantizeToJoinEdge = {};
dequantizeToJoinEdge.FromNodeIndex = inputIndex;
dequantizeToJoinEdge.FromNodeOutputIndex = 0;
dequantizeToJoinEdge.ToNodeIndex = joinNodeIndex; // The second last node Join
dequantizeToJoinEdge.ToNodeInputIndex = inputIndex;
intermediateEdges.push_back(dequantizeToJoinEdge);
}
DML_INTERMEDIATE_GRAPH_EDGE_DESC joinToQuantizeEdge = {};
joinToQuantizeEdge.FromNodeIndex = joinNodeIndex;
joinToQuantizeEdge.FromNodeOutputIndex = 0;
joinToQuantizeEdge.ToNodeIndex = quantizeNodeIndex; // The second last node Join
joinToQuantizeEdge.ToNodeInputIndex = 0;
intermediateEdges.push_back(joinToQuantizeEdge);
operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast<uint32_t>(intermediateEdges.size());
operatorGraphDesc.intermediateEdges = intermediateEdges.data();
// set the output edges
std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> outputEdges;
DML_OUTPUT_GRAPH_EDGE_DESC outputEdge = {};
outputEdge.FromNodeIndex = quantizeNodeIndex;
outputEdge.FromNodeOutputIndex = 0;
outputEdge.GraphOutputIndex = 0;
outputEdges.push_back(outputEdge);
operatorGraphDesc.outputEdgeCount = gsl::narrow_cast<uint32_t>(outputEdges.size());
operatorGraphDesc.outputEdges = outputEdges.data();
SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext);
};
};
DML_OP_DEFINE_CREATION_FUNCTION(QLinearConcat, DmlOperatorQLinearConcat);
} // namespace Dml

View file

@ -434,6 +434,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(ScatterND);
DML_OP_EXTERN_CREATION_FUNCTION(QLinearAdd);
DML_OP_EXTERN_CREATION_FUNCTION(QLinearConv);
DML_OP_EXTERN_CREATION_FUNCTION(QLinearMatMul);
DML_OP_EXTERN_CREATION_FUNCTION(QLinearConcat);
DML_OP_EXTERN_CREATION_FUNCTION(DynamicQuantizeLinear);
DML_OP_EXTERN_CREATION_FUNCTION(MatMulInteger);
DML_OP_EXTERN_CREATION_FUNCTION(MatMulIntegerToFloat);
@ -486,6 +487,7 @@ constexpr static std::array<const char*, 2> typeNameListEyeLike = { "T1", "T2" }
constexpr static std::array<const char*, 2> typeNameShape = { "T", "T1" };
constexpr static std::array<const char*, 2> typeNameSize = { "T", "T1" };
constexpr static std::array<const char*, 2> typeNameListGroupNorm = {"T", "M"};
constexpr static std::array<const char*, 3> typeNameListQLinearConcat= {"TF", "T8", "TV"};
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListAll = {SupportedTensorDataTypes::All};
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat32 = {SupportedTensorDataTypes::Float32};
@ -571,12 +573,18 @@ constexpr static std::array<SupportedTensorDataTypes, 4> supportedTypeListQLinea
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListDynamicQuantizeLinear = {
SupportedTensorDataTypes::Float32,
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8
SupportedTensorDataTypes::Ints8Bit
};
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListDynamicQuantizeMatMul= {
SupportedTensorDataTypes::Float32,
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
SupportedTensorDataTypes::Ints8Bit,
};
// Allowed tensor data types for QLinearConcat, one entry per type constraint.
// NOTE(review): presumably matched by position with typeNameListQLinearConcat ("TF", "T8", "TV") by the
// registration macro - confirm before reordering either list.
constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListQLinearConcat= {
SupportedTensorDataTypes::Float32, // first constraint: float32 only
SupportedTensorDataTypes::Ints8Bit, // second constraint: 8-bit integers only
SupportedTensorDataTypes::Ints8Bit|SupportedTensorDataTypes::Float32, // third constraint: either of the above
};
template<typename... Args>
@ -969,6 +977,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO_MS( 1, QAttention, typeNameListFour, supportedTypeListQAttention, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryQAttention)},
{REG_INFO_MS( 1, Attention, typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryAttention)},
{REG_INFO_MS( 1, MultiHeadAttention, typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, QLinearConcat, typeNameListQLinearConcat, supportedTypeListQLinearConcat, DmlGraphSupport::Supported)},
{REG_INFO( 10, IsInf, typeNameListTwo, supportedTypeListIsInf, DmlGraphSupport::Supported)},
{REG_INFO( 10, Mod, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)},

View file

@ -419,9 +419,9 @@ namespace Dml
} // namespace FusionHelpers
uint32_t GetDmlAdjustedAxis(int32_t onnxAxis, const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t dmlDimCount)
uint32_t GetDmlAdjustedAxis(int32_t onnxAxis, const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t dmlDimCount, uint32_t firstInputIndex)
{
const std::vector<DimensionType> inputDimensions = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(0);
const std::vector<DimensionType> inputDimensions = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(firstInputIndex);
uint32_t onnxDimCount = gsl::narrow_cast<uint32_t>(inputDimensions.size());
onnxAxis = HandleNegativeAxis(onnxAxis, onnxDimCount);
return GetDmlAdjustedAxis(onnxAxis, onnxDimCount, dmlDimCount);

View file

@ -64,8 +64,7 @@ namespace Dml
} // namespace FusionHelpers
// Given an axis in ONNX axis numbering, return the axis adjusted for DML based on how the sizes have been coerced.
// Note this function presumes the axis attribute is relative to the first input tensor (which is always the case).
uint32_t GetDmlAdjustedAxis(int32_t onnxAxis, const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t dmlDimCount);
uint32_t GetDmlAdjustedAxis(int32_t onnxAxis, const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t dmlDimCount, uint32_t firstInputIndex = 0);
uint32_t GetDmlAdjustedAxis(int32_t onnxAxis, uint32_t onnxDimCount, uint32_t dmlDimCount);

View file

@ -1862,7 +1862,7 @@ namespace OperatorHelper
return { std::move(outputShape) };
}
void ConcatHelper::Initialize(
void ConcatHelperBase::Initialize(
const MLOperatorAttributes& operatorAttributes,
gsl::span<const DimensionType> inputDimensions
)
@ -1872,13 +1872,13 @@ namespace OperatorHelper
ML_CHECK_VALID_ARGUMENT(m_axis < static_cast<int>(inputDimensions.size()));
}
std::vector<EdgeShapes> ConcatHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
std::vector<EdgeShapes> ConcatHelperBase::GetOutputShapes(const MLShapeInferenceContext& shapeInfo, uint32_t firstInputIndex, uint32_t step) const
{
auto outputShape = shapeInfo.GetInputTensorShape(0);
auto outputShape = shapeInfo.GetInputTensorShape(firstInputIndex);
uint32_t inputCount = shapeInfo.GetInputCount();
for (uint32_t i = 1; i < inputCount; ++i)
for (uint32_t i = firstInputIndex + step; i < inputCount; i += step)
{
auto inputShape = shapeInfo.GetInputTensorShape(i);
for (size_t j = 0; j < outputShape.size(); ++j)
@ -1893,6 +1893,16 @@ namespace OperatorHelper
return { EdgeShapes(outputShape) };
}
// Output shape for plain Concat: every operator input is a tensor to concatenate, so the shared
// implementation starts at input 0 and visits each input in turn.
std::vector<EdgeShapes> ConcatHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
{
    constexpr uint32_t firstTensorInput = 0;
    constexpr uint32_t tensorInputStride = 1;
    return ConcatHelperBase::GetOutputShapes(shapeInfo, firstTensorInput, tensorInputStride);
}
// Output shape for QLinearConcat: inputs are {y_scale, y_zero_point, (x, x_scale, x_zero_point)...},
// so the tensors to concatenate sit at inputs 2, 5, 8, ... - start at 2 and advance 3 at a time.
std::vector<EdgeShapes> QLinearConcatHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
{
    constexpr uint32_t firstTensorInput = 2;
    constexpr uint32_t tensorInputStride = 3;
    return ConcatHelperBase::GetOutputShapes(shapeInfo, firstTensorInput, tensorInputStride);
}
void CropHelper::Initialize(
const MLOperatorAttributes& operatorAttributes,
gsl::span<const DimensionType> inputDimensions

View file

@ -871,7 +871,7 @@ protected:
int m_hiddenSize = 0;
};
class ConcatHelper
class ConcatHelperBase
{
public:
void Initialize(
@ -882,17 +882,33 @@ public:
// Info_t is used to obtain attributes which will be used for calculating the output shape later.
// Shape_t is used to obtain input shape which will be used for adjusting attribute value.
template <typename Info_t, typename Shape_t>
ConcatHelper(const Info_t& info, const Shape_t& shape)
ConcatHelperBase(const Info_t& info, const Shape_t& shape, uint32_t firstInputIndex)
{
Initialize(info, shape.GetInputTensorShape(0));
Initialize(info, shape.GetInputTensorShape(firstInputIndex));
}
std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo, uint32_t firstInputIndex, uint32_t step) const;
protected:
int m_axis;
};
// Shape helper for the regular Concat operator, where the first tensor to concatenate is input 0.
class ConcatHelper : public ConcatHelperBase
{
public:
    // Info_t supplies the operator attributes and Shape_t the input shapes; both are forwarded to
    // ConcatHelperBase together with the index of the first concatenated tensor (0).
    template <typename Info_t, typename Shape_t>
    ConcatHelper(const Info_t& info, const Shape_t& shape)
    :   ConcatHelperBase(info, shape, /*firstInputIndex*/ 0)
    {
    }

    std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
};
// Shape helper for QLinearConcat, where the first tensor to concatenate is input 2.
class QLinearConcatHelper : public ConcatHelperBase
{
public:
    // Info_t supplies the operator attributes and Shape_t the input shapes; both are forwarded to
    // ConcatHelperBase together with the index of the first concatenated tensor (2).
    template <typename Info_t, typename Shape_t>
    QLinearConcatHelper(const Info_t& info, const Shape_t& shape)
    :   ConcatHelperBase(info, shape, /*firstInputIndex*/ 2)
    {
    }

    std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
};
class CropHelper
{
public:
@ -1519,6 +1535,7 @@ using ShapeInferenceHelper_Split13 = VersionedOpsetHelper<SplitHelper, 13>;
using ShapeInferenceHelper_Split18 = VersionedOpsetHelper<SplitHelper, 18>;
using ShapeInferenceHelper_Transpose = TransposeHelper;
using ShapeInferenceHelper_Concat = ConcatHelper;
using ShapeInferenceHelper_QLinearConcat = QLinearConcatHelper;
using ShapeInferenceHelper_Slice7 = VersionedOpsetHelper<SliceHelper, 7>;
using ShapeInferenceHelper_Slice10 = VersionedOpsetHelper<SliceHelper, 10>;
using ShapeInferenceHelper_Slice11 = VersionedOpsetHelper<SliceHelper, 11>; // Note 11 and 10 are identical - no functional change.

View file

@ -440,6 +440,7 @@ namespace OperatorHelper
static const int sc_sinceVer_QuickGelu = 1;
static const int sc_sinceVer_GroupNorm = 1;
static const int sc_sinceVer_DynamicQuantizeMatMul = 1;
static const int sc_sinceVer_QLinearConcat = 1;
} // namespace MsftOperatorSet1
} // namespace OperatorHelper