mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-01 23:30:35 +00:00
Add QLinearConcat for DML EP (#16971)
This commit is contained in:
parent
4c8ef4080d
commit
d3345f3680
7 changed files with 286 additions and 14 deletions
|
|
@ -0,0 +1,236 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "precomp.h"
|
||||
|
||||
namespace Dml
|
||||
{
|
||||
// QLinearConcat = Dequantize + Join + Quantize
|
||||
class DmlOperatorQLinearConcat : public DmlOperator, public QLinearConcatHelper
|
||||
{
|
||||
// This order matches the ONNX schema.
|
||||
enum OnnxInputIndex
|
||||
{
|
||||
YScale,
|
||||
YZeroPoint,
|
||||
Count,
|
||||
};
|
||||
|
||||
public:
|
||||
DmlOperatorQLinearConcat(const MLOperatorKernelCreationContext& kernelCreationContext)
|
||||
: DmlOperator(kernelCreationContext),
|
||||
QLinearConcatHelper(kernelCreationContext, kernelCreationContext.GetTensorShapeDescription())
|
||||
{
|
||||
DmlOperator::Initialize(kernelCreationContext);
|
||||
|
||||
auto outputShape = kernelCreationContext.GetTensorShapeDescription().GetOutputTensorShape(0);
|
||||
|
||||
// inputs: {y_scale, y_zero_point, tuple(x_tensor, x_scale, x_zero_point)}
|
||||
uint32_t inputDefinitionCount = kernelCreationContext.GetInputCount();
|
||||
ML_CHECK_VALID_ARGUMENT(inputDefinitionCount >= 5, "Require at least 5 inputs.");
|
||||
ML_CHECK_VALID_ARGUMENT((inputDefinitionCount - 2) % 3 == 0, "Each input must be (tensor, scale, zero_point) tuple!");
|
||||
|
||||
uint32_t inputCount = (inputDefinitionCount - 2) / 3;
|
||||
|
||||
auto yScaleDataType = kernelCreationContext.GetInputEdgeDescription(OnnxInputIndex::YScale).tensorDataType;
|
||||
auto yZeroPointDataType = kernelCreationContext.GetInputEdgeDescription(OnnxInputIndex::YZeroPoint).tensorDataType;
|
||||
|
||||
// broadcast y_scale and y_zero_point to output shape
|
||||
m_inputTensorDescs[OnnxInputIndex::YScale] = TensorDesc(
|
||||
yScaleDataType,
|
||||
outputShape,
|
||||
kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(OnnxInputIndex::YScale),
|
||||
TensorAxis::DoNotCoerce,
|
||||
TensorAxis::W,
|
||||
TensorAxis::RightAligned,
|
||||
NchwDimensionCount, // minDimensionCount
|
||||
0 // guaranteedBaseOffsetAlignment
|
||||
);
|
||||
|
||||
m_inputTensorDescs[OnnxInputIndex::YZeroPoint] = TensorDesc(
|
||||
yZeroPointDataType,
|
||||
outputShape,
|
||||
kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(OnnxInputIndex::YZeroPoint),
|
||||
TensorAxis::DoNotCoerce,
|
||||
TensorAxis::W,
|
||||
TensorAxis::RightAligned,
|
||||
NchwDimensionCount, // minDimensionCount
|
||||
0 // guaranteedBaseOffsetAlignment
|
||||
);
|
||||
|
||||
// Validate input tensors
|
||||
for (uint32_t inputIndex = 0; inputIndex < inputCount; ++inputIndex)
|
||||
{
|
||||
// Inputs(input tensor, scale, zero_point) are in tuple and starting from index 2
|
||||
auto tupleStartIndex = 2 + inputIndex * 3;
|
||||
auto xScaleDataType = kernelCreationContext.GetInputEdgeDescription(tupleStartIndex + 1).tensorDataType;
|
||||
auto xZeroPointDataType = kernelCreationContext.GetInputEdgeDescription(tupleStartIndex + 2).tensorDataType;
|
||||
ML_CHECK_VALID_ARGUMENT(xScaleDataType == yScaleDataType, "Wrong input type encountered for scale");
|
||||
ML_CHECK_VALID_ARGUMENT(xZeroPointDataType == yZeroPointDataType, "Wrong input type encountered for zero point");
|
||||
|
||||
// broadcast x_scale and x_zero_point to shape of corresponding x
|
||||
m_inputTensorDescs[tupleStartIndex + 1] = TensorDesc(
|
||||
xScaleDataType,
|
||||
kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex),
|
||||
kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex + 1),
|
||||
TensorAxis::DoNotCoerce,
|
||||
TensorAxis::W,
|
||||
TensorAxis::RightAligned,
|
||||
NchwDimensionCount, // minDimensionCount
|
||||
0 // guaranteedBaseOffsetAlignment
|
||||
);
|
||||
|
||||
m_inputTensorDescs[tupleStartIndex + 2] = TensorDesc(
|
||||
xZeroPointDataType,
|
||||
kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex),
|
||||
kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex + 2),
|
||||
TensorAxis::DoNotCoerce,
|
||||
TensorAxis::W,
|
||||
TensorAxis::RightAligned,
|
||||
NchwDimensionCount, // minDimensionCount
|
||||
0 // guaranteedBaseOffsetAlignment
|
||||
);
|
||||
}
|
||||
|
||||
uint32_t dmlAxis = GetDmlAdjustedAxis(m_axis, kernelCreationContext, m_inputTensorDescs.front().GetDimensionCount(), 2);
|
||||
|
||||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
|
||||
// 1. output edges between Dequantize and Join node
|
||||
// 2. input edge between Join and Quantize node
|
||||
std::vector<TensorDesc> intermediateOutputTensorDescs(inputCount);
|
||||
std::vector<DML_TENSOR_DESC> namedDequantizeOperatorDescs(inputCount);
|
||||
std::vector<DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_DESC> dequantizeOperatorDescs(inputCount);
|
||||
std::vector<DML_OPERATOR_DESC> dmlOpDesc(inputCount);
|
||||
std::vector<const DML_OPERATOR_DESC*> opDescs;
|
||||
for (uint32_t inputIndex = 0; inputIndex < inputCount; ++inputIndex)
|
||||
{
|
||||
auto tupleStartIndex = 2 + inputIndex * 3;
|
||||
intermediateOutputTensorDescs[inputIndex] = TensorDesc(
|
||||
MLOperatorTensorDataType::Float,
|
||||
kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex),
|
||||
kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex),
|
||||
TensorAxis::DoNotCoerce,
|
||||
TensorAxis::W,
|
||||
TensorAxis::RightAligned,
|
||||
NchwDimensionCount, // minDimensionCount
|
||||
0 // guaranteedBaseOffsetAlignment)
|
||||
);
|
||||
namedDequantizeOperatorDescs[inputIndex] = intermediateOutputTensorDescs[inputIndex].GetDmlDesc();
|
||||
|
||||
dequantizeOperatorDescs[inputIndex].InputTensor = &inputDescs[tupleStartIndex];
|
||||
dequantizeOperatorDescs[inputIndex].ScaleTensor = &inputDescs[tupleStartIndex + 1];
|
||||
dequantizeOperatorDescs[inputIndex].ZeroPointTensor = &inputDescs[tupleStartIndex + 2];
|
||||
dequantizeOperatorDescs[inputIndex].OutputTensor = &namedDequantizeOperatorDescs[inputIndex];
|
||||
|
||||
dmlOpDesc[inputIndex] = {DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR, &dequantizeOperatorDescs[inputIndex]};
|
||||
opDescs.push_back(&dmlOpDesc[inputIndex]);
|
||||
}
|
||||
|
||||
TensorDesc joinOutputTensorDesc = TensorDesc(
|
||||
MLOperatorTensorDataType::Float,
|
||||
outputShape,
|
||||
outputShape,
|
||||
TensorAxis::DoNotCoerce,
|
||||
TensorAxis::W,
|
||||
TensorAxis::RightAligned,
|
||||
NchwDimensionCount, // minDimensionCount
|
||||
0 // guaranteedBaseOffsetAlignment
|
||||
);
|
||||
DML_TENSOR_DESC namedJoinOutputTensorDesc = joinOutputTensorDesc.GetDmlDesc();
|
||||
|
||||
DML_JOIN_OPERATOR_DESC joinDesc = {};
|
||||
joinDesc.InputCount = gsl::narrow_cast<uint32_t>(namedDequantizeOperatorDescs.size());
|
||||
joinDesc.InputTensors = namedDequantizeOperatorDescs.data();
|
||||
joinDesc.OutputTensor = &namedJoinOutputTensorDesc;
|
||||
joinDesc.Axis = dmlAxis;
|
||||
|
||||
const DML_OPERATOR_DESC opJoinDesc = {DML_OPERATOR_JOIN, &joinDesc};
|
||||
opDescs.push_back(&opJoinDesc);
|
||||
|
||||
DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_DESC quantizeOperatorDesc = {};
|
||||
quantizeOperatorDesc.InputTensor = joinDesc.OutputTensor;
|
||||
quantizeOperatorDesc.ScaleTensor = &inputDescs[OnnxInputIndex::YScale];
|
||||
quantizeOperatorDesc.ZeroPointTensor = &inputDescs[OnnxInputIndex::YZeroPoint];
|
||||
quantizeOperatorDesc.OutputTensor = &outputDescs[0];
|
||||
const DML_OPERATOR_DESC opQuantizeDesc = {DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR, &quantizeOperatorDesc};
|
||||
opDescs.push_back(&opQuantizeDesc);
|
||||
|
||||
MLOperatorGraphDesc operatorGraphDesc = {};
|
||||
operatorGraphDesc.nodeCount = static_cast<uint32_t>(opDescs.size());
|
||||
operatorGraphDesc.nodesAsOpDesc = opDescs.data();
|
||||
|
||||
uint32_t joinNodeIndex = operatorGraphDesc.nodeCount - 2;
|
||||
uint32_t quantizeNodeIndex = operatorGraphDesc.nodeCount - 1;
|
||||
|
||||
std::vector<DML_INPUT_GRAPH_EDGE_DESC> inputEdges;
|
||||
// Input edges to Dequantize nodes
|
||||
for (uint32_t inputIndex = 0; inputIndex < inputCount; ++inputIndex)
|
||||
{
|
||||
auto tupleStartIndex = 2 + inputIndex * 3;
|
||||
for (auto edge_index = 0; edge_index < 3; ++edge_index)
|
||||
{
|
||||
DML_INPUT_GRAPH_EDGE_DESC inputEdge = {};
|
||||
inputEdge.GraphInputIndex = tupleStartIndex + edge_index;
|
||||
inputEdge.ToNodeIndex = inputIndex;
|
||||
inputEdge.ToNodeInputIndex = edge_index;
|
||||
inputEdges.push_back(inputEdge);
|
||||
}
|
||||
}
|
||||
|
||||
// Input edge from y_scale to quantize node
|
||||
DML_INPUT_GRAPH_EDGE_DESC yScaleInputEdge = {};
|
||||
yScaleInputEdge.GraphInputIndex = 0; // Y_scale
|
||||
yScaleInputEdge.ToNodeIndex = quantizeNodeIndex;
|
||||
yScaleInputEdge.ToNodeInputIndex = 1;
|
||||
inputEdges.push_back(yScaleInputEdge);
|
||||
|
||||
// Input edge from y_zero_point to quantize node
|
||||
DML_INPUT_GRAPH_EDGE_DESC yZeroPointInputEdge = {};
|
||||
yZeroPointInputEdge.GraphInputIndex = 1; // Y_zero_point
|
||||
yZeroPointInputEdge.ToNodeIndex = quantizeNodeIndex;
|
||||
yZeroPointInputEdge.ToNodeInputIndex = 2;
|
||||
inputEdges.push_back(yZeroPointInputEdge);
|
||||
|
||||
operatorGraphDesc.inputEdgeCount = gsl::narrow_cast<uint32_t>(inputEdges.size());
|
||||
operatorGraphDesc.inputEdges = inputEdges.data();
|
||||
|
||||
// set intermediate edges
|
||||
std::vector<DML_INTERMEDIATE_GRAPH_EDGE_DESC> intermediateEdges;
|
||||
for (uint32_t inputIndex = 0; inputIndex < inputCount; ++inputIndex)
|
||||
{
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC dequantizeToJoinEdge = {};
|
||||
dequantizeToJoinEdge.FromNodeIndex = inputIndex;
|
||||
dequantizeToJoinEdge.FromNodeOutputIndex = 0;
|
||||
dequantizeToJoinEdge.ToNodeIndex = joinNodeIndex; // The second last node Join
|
||||
dequantizeToJoinEdge.ToNodeInputIndex = inputIndex;
|
||||
intermediateEdges.push_back(dequantizeToJoinEdge);
|
||||
}
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC joinToQuantizeEdge = {};
|
||||
joinToQuantizeEdge.FromNodeIndex = joinNodeIndex;
|
||||
joinToQuantizeEdge.FromNodeOutputIndex = 0;
|
||||
joinToQuantizeEdge.ToNodeIndex = quantizeNodeIndex; // The second last node Join
|
||||
joinToQuantizeEdge.ToNodeInputIndex = 0;
|
||||
intermediateEdges.push_back(joinToQuantizeEdge);
|
||||
|
||||
operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast<uint32_t>(intermediateEdges.size());
|
||||
operatorGraphDesc.intermediateEdges = intermediateEdges.data();
|
||||
|
||||
// set the output edges
|
||||
std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> outputEdges;
|
||||
DML_OUTPUT_GRAPH_EDGE_DESC outputEdge = {};
|
||||
outputEdge.FromNodeIndex = quantizeNodeIndex;
|
||||
outputEdge.FromNodeOutputIndex = 0;
|
||||
outputEdge.GraphOutputIndex = 0;
|
||||
outputEdges.push_back(outputEdge);
|
||||
operatorGraphDesc.outputEdgeCount = gsl::narrow_cast<uint32_t>(outputEdges.size());
|
||||
operatorGraphDesc.outputEdges = outputEdges.data();
|
||||
|
||||
SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext);
|
||||
};
|
||||
};
|
||||
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(QLinearConcat, DmlOperatorQLinearConcat);
|
||||
} // namespace Dml
|
||||
|
|
@ -434,6 +434,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(ScatterND);
|
|||
DML_OP_EXTERN_CREATION_FUNCTION(QLinearAdd);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(QLinearConv);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(QLinearMatMul);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(QLinearConcat);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(DynamicQuantizeLinear);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(MatMulInteger);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(MatMulIntegerToFloat);
|
||||
|
|
@ -486,6 +487,7 @@ constexpr static std::array<const char*, 2> typeNameListEyeLike = { "T1", "T2" }
|
|||
constexpr static std::array<const char*, 2> typeNameShape = { "T", "T1" };
|
||||
constexpr static std::array<const char*, 2> typeNameSize = { "T", "T1" };
|
||||
constexpr static std::array<const char*, 2> typeNameListGroupNorm = {"T", "M"};
|
||||
constexpr static std::array<const char*, 3> typeNameListQLinearConcat= {"TF", "T8", "TV"};
|
||||
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListAll = {SupportedTensorDataTypes::All};
|
||||
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat32 = {SupportedTensorDataTypes::Float32};
|
||||
|
|
@ -571,12 +573,18 @@ constexpr static std::array<SupportedTensorDataTypes, 4> supportedTypeListQLinea
|
|||
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListDynamicQuantizeLinear = {
|
||||
SupportedTensorDataTypes::Float32,
|
||||
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8
|
||||
SupportedTensorDataTypes::Ints8Bit
|
||||
};
|
||||
|
||||
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListDynamicQuantizeMatMul= {
|
||||
SupportedTensorDataTypes::Float32,
|
||||
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
|
||||
SupportedTensorDataTypes::Ints8Bit,
|
||||
};
|
||||
|
||||
constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListQLinearConcat= {
|
||||
SupportedTensorDataTypes::Float32,
|
||||
SupportedTensorDataTypes::Ints8Bit,
|
||||
SupportedTensorDataTypes::Ints8Bit|SupportedTensorDataTypes::Float32,
|
||||
};
|
||||
|
||||
template<typename... Args>
|
||||
|
|
@ -969,6 +977,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
|
|||
{REG_INFO_MS( 1, QAttention, typeNameListFour, supportedTypeListQAttention, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryQAttention)},
|
||||
{REG_INFO_MS( 1, Attention, typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryAttention)},
|
||||
{REG_INFO_MS( 1, MultiHeadAttention, typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MS( 1, QLinearConcat, typeNameListQLinearConcat, supportedTypeListQLinearConcat, DmlGraphSupport::Supported)},
|
||||
|
||||
{REG_INFO( 10, IsInf, typeNameListTwo, supportedTypeListIsInf, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 10, Mod, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)},
|
||||
|
|
|
|||
|
|
@ -419,9 +419,9 @@ namespace Dml
|
|||
|
||||
} // namespace FusionHelpers
|
||||
|
||||
uint32_t GetDmlAdjustedAxis(int32_t onnxAxis, const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t dmlDimCount)
|
||||
uint32_t GetDmlAdjustedAxis(int32_t onnxAxis, const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t dmlDimCount, uint32_t firstInputIndex)
|
||||
{
|
||||
const std::vector<DimensionType> inputDimensions = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(0);
|
||||
const std::vector<DimensionType> inputDimensions = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(firstInputIndex);
|
||||
uint32_t onnxDimCount = gsl::narrow_cast<uint32_t>(inputDimensions.size());
|
||||
onnxAxis = HandleNegativeAxis(onnxAxis, onnxDimCount);
|
||||
return GetDmlAdjustedAxis(onnxAxis, onnxDimCount, dmlDimCount);
|
||||
|
|
|
|||
|
|
@ -64,8 +64,7 @@ namespace Dml
|
|||
} // namespace FusionHelpers
|
||||
|
||||
// Given an axis in ONNX axis numbering, return the axis adjusted for DML based on how the sizes have been coerced.
|
||||
// Note this function presumes the axis attribute is relative to the first input tensor (which is always the case).
|
||||
uint32_t GetDmlAdjustedAxis(int32_t onnxAxis, const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t dmlDimCount);
|
||||
uint32_t GetDmlAdjustedAxis(int32_t onnxAxis, const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t dmlDimCount, uint32_t firstInputIndex = 0);
|
||||
|
||||
uint32_t GetDmlAdjustedAxis(int32_t onnxAxis, uint32_t onnxDimCount, uint32_t dmlDimCount);
|
||||
|
||||
|
|
|
|||
|
|
@ -1862,7 +1862,7 @@ namespace OperatorHelper
|
|||
return { std::move(outputShape) };
|
||||
}
|
||||
|
||||
void ConcatHelper::Initialize(
|
||||
void ConcatHelperBase::Initialize(
|
||||
const MLOperatorAttributes& operatorAttributes,
|
||||
gsl::span<const DimensionType> inputDimensions
|
||||
)
|
||||
|
|
@ -1872,13 +1872,13 @@ namespace OperatorHelper
|
|||
ML_CHECK_VALID_ARGUMENT(m_axis < static_cast<int>(inputDimensions.size()));
|
||||
}
|
||||
|
||||
std::vector<EdgeShapes> ConcatHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
|
||||
std::vector<EdgeShapes> ConcatHelperBase::GetOutputShapes(const MLShapeInferenceContext& shapeInfo, uint32_t firstInputIndex, uint32_t step) const
|
||||
{
|
||||
auto outputShape = shapeInfo.GetInputTensorShape(0);
|
||||
auto outputShape = shapeInfo.GetInputTensorShape(firstInputIndex);
|
||||
|
||||
uint32_t inputCount = shapeInfo.GetInputCount();
|
||||
|
||||
for (uint32_t i = 1; i < inputCount; ++i)
|
||||
for (uint32_t i = firstInputIndex + step; i < inputCount; i += step)
|
||||
{
|
||||
auto inputShape = shapeInfo.GetInputTensorShape(i);
|
||||
for (size_t j = 0; j < outputShape.size(); ++j)
|
||||
|
|
@ -1893,6 +1893,16 @@ namespace OperatorHelper
|
|||
return { EdgeShapes(outputShape) };
|
||||
}
|
||||
|
||||
std::vector<EdgeShapes> ConcatHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
|
||||
{
|
||||
return ConcatHelperBase::GetOutputShapes(shapeInfo, 0, 1);
|
||||
}
|
||||
|
||||
std::vector<EdgeShapes> QLinearConcatHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
|
||||
{
|
||||
return ConcatHelperBase::GetOutputShapes(shapeInfo, 2, 3);
|
||||
}
|
||||
|
||||
void CropHelper::Initialize(
|
||||
const MLOperatorAttributes& operatorAttributes,
|
||||
gsl::span<const DimensionType> inputDimensions
|
||||
|
|
|
|||
|
|
@ -871,7 +871,7 @@ protected:
|
|||
int m_hiddenSize = 0;
|
||||
};
|
||||
|
||||
class ConcatHelper
|
||||
class ConcatHelperBase
|
||||
{
|
||||
public:
|
||||
void Initialize(
|
||||
|
|
@ -882,17 +882,33 @@ public:
|
|||
// Info_t is used to obtain attributes which will be used for calculating the output shape later.
|
||||
// Shape_t is used to obtain input shape which will be used for adjusting attribute value.
|
||||
template <typename Info_t, typename Shape_t>
|
||||
ConcatHelper(const Info_t& info, const Shape_t& shape)
|
||||
ConcatHelperBase(const Info_t& info, const Shape_t& shape, uint32_t firstInputIndex)
|
||||
{
|
||||
Initialize(info, shape.GetInputTensorShape(0));
|
||||
Initialize(info, shape.GetInputTensorShape(firstInputIndex));
|
||||
}
|
||||
|
||||
std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
|
||||
std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo, uint32_t firstInputIndex, uint32_t step) const;
|
||||
|
||||
protected:
|
||||
int m_axis;
|
||||
};
|
||||
|
||||
class ConcatHelper: public ConcatHelperBase
|
||||
{
|
||||
public:
|
||||
template<typename Info_t, typename Shape_t>
|
||||
ConcatHelper(const Info_t& info, const Shape_t& shape) : ConcatHelperBase(info, shape, 0) {}
|
||||
std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
|
||||
};
|
||||
|
||||
class QLinearConcatHelper: public ConcatHelperBase
|
||||
{
|
||||
public:
|
||||
template<typename Info_t, typename Shape_t>
|
||||
QLinearConcatHelper(const Info_t& info, const Shape_t& shape) : ConcatHelperBase(info, shape, 2) {}
|
||||
std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
|
||||
};
|
||||
|
||||
class CropHelper
|
||||
{
|
||||
public:
|
||||
|
|
@ -1519,6 +1535,7 @@ using ShapeInferenceHelper_Split13 = VersionedOpsetHelper<SplitHelper, 13>;
|
|||
using ShapeInferenceHelper_Split18 = VersionedOpsetHelper<SplitHelper, 18>;
|
||||
using ShapeInferenceHelper_Transpose = TransposeHelper;
|
||||
using ShapeInferenceHelper_Concat = ConcatHelper;
|
||||
using ShapeInferenceHelper_QLinearConcat = QLinearConcatHelper;
|
||||
using ShapeInferenceHelper_Slice7 = VersionedOpsetHelper<SliceHelper, 7>;
|
||||
using ShapeInferenceHelper_Slice10 = VersionedOpsetHelper<SliceHelper, 10>;
|
||||
using ShapeInferenceHelper_Slice11 = VersionedOpsetHelper<SliceHelper, 11>; // Note 11 and 10 are identical - no functional change.
|
||||
|
|
|
|||
|
|
@ -440,6 +440,7 @@ namespace OperatorHelper
|
|||
static const int sc_sinceVer_QuickGelu = 1;
|
||||
static const int sc_sinceVer_GroupNorm = 1;
|
||||
static const int sc_sinceVer_DynamicQuantizeMatMul = 1;
|
||||
static const int sc_sinceVer_QLinearConcat = 1;
|
||||
} // namespace MsftOperatorSet1
|
||||
|
||||
} // namespace OperatorHelper
|
||||
|
|
|
|||
Loading…
Reference in a new issue