From d3345f36802d9a3a17828224ff38f03bcf615b41 Mon Sep 17 00:00:00 2001 From: Xiang Zhang Date: Thu, 17 Aug 2023 15:15:25 -0700 Subject: [PATCH] Add QLinearConcat for DML EP (#16971) --- .../Operators/DmlOperatorQLinearConcat.cpp | 236 ++++++++++++++++++ .../src/Operators/OperatorRegistration.cpp | 13 +- .../src/Operators/OperatorUtility.cpp | 4 +- .../src/Operators/OperatorUtility.h | 3 +- .../OperatorAuthorHelper/OperatorHelper.cpp | 18 +- .../dml/OperatorAuthorHelper/OperatorHelper.h | 25 +- .../OperatorAuthorHelper/OperatorVersions.h | 1 + 7 files changed, 286 insertions(+), 14 deletions(-) create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp new file mode 100644 index 0000000000..67711fdc28 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp @@ -0,0 +1,236 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "precomp.h" + +namespace Dml +{ +// QLinearConcat = Dequantize + Join + Quantize +class DmlOperatorQLinearConcat : public DmlOperator, public QLinearConcatHelper +{ + // This order matches the ONNX schema. 
+ enum OnnxInputIndex + { + YScale, + YZeroPoint, + Count, + }; + +public: + DmlOperatorQLinearConcat(const MLOperatorKernelCreationContext& kernelCreationContext) + : DmlOperator(kernelCreationContext), + QLinearConcatHelper(kernelCreationContext, kernelCreationContext.GetTensorShapeDescription()) + { + DmlOperator::Initialize(kernelCreationContext); + + auto outputShape = kernelCreationContext.GetTensorShapeDescription().GetOutputTensorShape(0); + + // inputs: {y_scale, y_zero_point, tuple(x_tensor, x_scale, x_zero_point)} + uint32_t inputDefinitionCount = kernelCreationContext.GetInputCount(); + ML_CHECK_VALID_ARGUMENT(inputDefinitionCount >= 5, "Require at least 5 inputs."); + ML_CHECK_VALID_ARGUMENT((inputDefinitionCount - 2) % 3 == 0, "Each input must be (tensor, scale, zero_point) tuple!"); + + uint32_t inputCount = (inputDefinitionCount - 2) / 3; + + auto yScaleDataType = kernelCreationContext.GetInputEdgeDescription(OnnxInputIndex::YScale).tensorDataType; + auto yZeroPointDataType = kernelCreationContext.GetInputEdgeDescription(OnnxInputIndex::YZeroPoint).tensorDataType; + + // broadcast y_scale and y_zero_point to output shape + m_inputTensorDescs[OnnxInputIndex::YScale] = TensorDesc( + yScaleDataType, + outputShape, + kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(OnnxInputIndex::YScale), + TensorAxis::DoNotCoerce, + TensorAxis::W, + TensorAxis::RightAligned, + NchwDimensionCount, // minDimensionCount + 0 // guaranteedBaseOffsetAlignment + ); + + m_inputTensorDescs[OnnxInputIndex::YZeroPoint] = TensorDesc( + yZeroPointDataType, + outputShape, + kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(OnnxInputIndex::YZeroPoint), + TensorAxis::DoNotCoerce, + TensorAxis::W, + TensorAxis::RightAligned, + NchwDimensionCount, // minDimensionCount + 0 // guaranteedBaseOffsetAlignment + ); + + // Validate input tensors + for (uint32_t inputIndex = 0; inputIndex < inputCount; ++inputIndex) + { + // Inputs(input tensor, scale, 
zero_point) are in tuple and starting from index 2 + auto tupleStartIndex = 2 + inputIndex * 3; + auto xScaleDataType = kernelCreationContext.GetInputEdgeDescription(tupleStartIndex + 1).tensorDataType; + auto xZeroPointDataType = kernelCreationContext.GetInputEdgeDescription(tupleStartIndex + 2).tensorDataType; + ML_CHECK_VALID_ARGUMENT(xScaleDataType == yScaleDataType, "Wrong input type encountered for scale"); + ML_CHECK_VALID_ARGUMENT(xZeroPointDataType == yZeroPointDataType, "Wrong input type encountered for zero point"); + + // broadcast x_scale and x_zero_point to shape of corresponding x + m_inputTensorDescs[tupleStartIndex + 1] = TensorDesc( + xScaleDataType, + kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex), + kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex + 1), + TensorAxis::DoNotCoerce, + TensorAxis::W, + TensorAxis::RightAligned, + NchwDimensionCount, // minDimensionCount + 0 // guaranteedBaseOffsetAlignment + ); + + m_inputTensorDescs[tupleStartIndex + 2] = TensorDesc( + xZeroPointDataType, + kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex), + kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex + 2), + TensorAxis::DoNotCoerce, + TensorAxis::W, + TensorAxis::RightAligned, + NchwDimensionCount, // minDimensionCount + 0 // guaranteedBaseOffsetAlignment + ); + } + + uint32_t dmlAxis = GetDmlAdjustedAxis(m_axis, kernelCreationContext, m_inputTensorDescs.front().GetDimensionCount(), 2); + + std::vector inputDescs = GetDmlInputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); + + // 1. output edges between Dequantize and Join node + // 2. 
input edge between Join and Quantize node + std::vector intermediateOutputTensorDescs(inputCount); + std::vector namedDequantizeOperatorDescs(inputCount); + std::vector dequantizeOperatorDescs(inputCount); + std::vector dmlOpDesc(inputCount); + std::vector opDescs; + for (uint32_t inputIndex = 0; inputIndex < inputCount; ++inputIndex) + { + auto tupleStartIndex = 2 + inputIndex * 3; + intermediateOutputTensorDescs[inputIndex] = TensorDesc( + MLOperatorTensorDataType::Float, + kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex), + kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex), + TensorAxis::DoNotCoerce, + TensorAxis::W, + TensorAxis::RightAligned, + NchwDimensionCount, // minDimensionCount + 0 // guaranteedBaseOffsetAlignment + ); + namedDequantizeOperatorDescs[inputIndex] = intermediateOutputTensorDescs[inputIndex].GetDmlDesc(); + + dequantizeOperatorDescs[inputIndex].InputTensor = &inputDescs[tupleStartIndex]; + dequantizeOperatorDescs[inputIndex].ScaleTensor = &inputDescs[tupleStartIndex + 1]; + dequantizeOperatorDescs[inputIndex].ZeroPointTensor = &inputDescs[tupleStartIndex + 2]; + dequantizeOperatorDescs[inputIndex].OutputTensor = &namedDequantizeOperatorDescs[inputIndex]; + + dmlOpDesc[inputIndex] = {DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR, &dequantizeOperatorDescs[inputIndex]}; + opDescs.push_back(&dmlOpDesc[inputIndex]); + } + + TensorDesc joinOutputTensorDesc = TensorDesc( + MLOperatorTensorDataType::Float, + outputShape, + outputShape, + TensorAxis::DoNotCoerce, + TensorAxis::W, + TensorAxis::RightAligned, + NchwDimensionCount, // minDimensionCount + 0 // guaranteedBaseOffsetAlignment + ); + DML_TENSOR_DESC namedJoinOutputTensorDesc = joinOutputTensorDesc.GetDmlDesc(); + + DML_JOIN_OPERATOR_DESC joinDesc = {}; + joinDesc.InputCount = gsl::narrow_cast(namedDequantizeOperatorDescs.size()); + joinDesc.InputTensors = namedDequantizeOperatorDescs.data(); + joinDesc.OutputTensor = 
&namedJoinOutputTensorDesc; + joinDesc.Axis = dmlAxis; + + const DML_OPERATOR_DESC opJoinDesc = {DML_OPERATOR_JOIN, &joinDesc}; + opDescs.push_back(&opJoinDesc); + + DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_DESC quantizeOperatorDesc = {}; + quantizeOperatorDesc.InputTensor = joinDesc.OutputTensor; + quantizeOperatorDesc.ScaleTensor = &inputDescs[OnnxInputIndex::YScale]; + quantizeOperatorDesc.ZeroPointTensor = &inputDescs[OnnxInputIndex::YZeroPoint]; + quantizeOperatorDesc.OutputTensor = &outputDescs[0]; + const DML_OPERATOR_DESC opQuantizeDesc = {DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR, &quantizeOperatorDesc}; + opDescs.push_back(&opQuantizeDesc); + + MLOperatorGraphDesc operatorGraphDesc = {}; + operatorGraphDesc.nodeCount = static_cast(opDescs.size()); + operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + + uint32_t joinNodeIndex = operatorGraphDesc.nodeCount - 2; + uint32_t quantizeNodeIndex = operatorGraphDesc.nodeCount - 1; + + std::vector inputEdges; + // Input edges to Dequantize nodes + for (uint32_t inputIndex = 0; inputIndex < inputCount; ++inputIndex) + { + auto tupleStartIndex = 2 + inputIndex * 3; + for (auto edge_index = 0; edge_index < 3; ++edge_index) + { + DML_INPUT_GRAPH_EDGE_DESC inputEdge = {}; + inputEdge.GraphInputIndex = tupleStartIndex + edge_index; + inputEdge.ToNodeIndex = inputIndex; + inputEdge.ToNodeInputIndex = edge_index; + inputEdges.push_back(inputEdge); + } + } + + // Input edge from y_scale to quantize node + DML_INPUT_GRAPH_EDGE_DESC yScaleInputEdge = {}; + yScaleInputEdge.GraphInputIndex = 0; // Y_scale + yScaleInputEdge.ToNodeIndex = quantizeNodeIndex; + yScaleInputEdge.ToNodeInputIndex = 1; + inputEdges.push_back(yScaleInputEdge); + + // Input edge from y_zero_point to quantize node + DML_INPUT_GRAPH_EDGE_DESC yZeroPointInputEdge = {}; + yZeroPointInputEdge.GraphInputIndex = 1; // Y_zero_point + yZeroPointInputEdge.ToNodeIndex = quantizeNodeIndex; + yZeroPointInputEdge.ToNodeInputIndex = 2; + 
inputEdges.push_back(yZeroPointInputEdge); + + operatorGraphDesc.inputEdgeCount = gsl::narrow_cast(inputEdges.size()); + operatorGraphDesc.inputEdges = inputEdges.data(); + + // set intermediate edges + std::vector intermediateEdges; + for (uint32_t inputIndex = 0; inputIndex < inputCount; ++inputIndex) + { + DML_INTERMEDIATE_GRAPH_EDGE_DESC dequantizeToJoinEdge = {}; + dequantizeToJoinEdge.FromNodeIndex = inputIndex; + dequantizeToJoinEdge.FromNodeOutputIndex = 0; + dequantizeToJoinEdge.ToNodeIndex = joinNodeIndex; // The second-to-last node (Join) + dequantizeToJoinEdge.ToNodeInputIndex = inputIndex; + intermediateEdges.push_back(dequantizeToJoinEdge); + } + + DML_INTERMEDIATE_GRAPH_EDGE_DESC joinToQuantizeEdge = {}; + joinToQuantizeEdge.FromNodeIndex = joinNodeIndex; + joinToQuantizeEdge.FromNodeOutputIndex = 0; + joinToQuantizeEdge.ToNodeIndex = quantizeNodeIndex; // The last node (Quantize) + joinToQuantizeEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(joinToQuantizeEdge); + + operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast(intermediateEdges.size()); + operatorGraphDesc.intermediateEdges = intermediateEdges.data(); + + // set the output edges + std::vector outputEdges; + DML_OUTPUT_GRAPH_EDGE_DESC outputEdge = {}; + outputEdge.FromNodeIndex = quantizeNodeIndex; + outputEdge.FromNodeOutputIndex = 0; + outputEdge.GraphOutputIndex = 0; + outputEdges.push_back(outputEdge); + operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); + operatorGraphDesc.outputEdges = outputEdges.data(); + + SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext); + }; +}; + +DML_OP_DEFINE_CREATION_FUNCTION(QLinearConcat, DmlOperatorQLinearConcat); +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 1c90a9988a..07ff4f3145 100644 --- 
a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -434,6 +434,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(ScatterND); DML_OP_EXTERN_CREATION_FUNCTION(QLinearAdd); DML_OP_EXTERN_CREATION_FUNCTION(QLinearConv); DML_OP_EXTERN_CREATION_FUNCTION(QLinearMatMul); +DML_OP_EXTERN_CREATION_FUNCTION(QLinearConcat); DML_OP_EXTERN_CREATION_FUNCTION(DynamicQuantizeLinear); DML_OP_EXTERN_CREATION_FUNCTION(MatMulInteger); DML_OP_EXTERN_CREATION_FUNCTION(MatMulIntegerToFloat); @@ -486,6 +487,7 @@ constexpr static std::array typeNameListEyeLike = { "T1", "T2" } constexpr static std::array typeNameShape = { "T", "T1" }; constexpr static std::array typeNameSize = { "T", "T1" }; constexpr static std::array typeNameListGroupNorm = {"T", "M"}; +constexpr static std::array typeNameListQLinearConcat= {"TF", "T8", "TV"}; constexpr static std::array supportedTypeListAll = {SupportedTensorDataTypes::All}; constexpr static std::array supportedTypeListFloat32 = {SupportedTensorDataTypes::Float32}; @@ -571,12 +573,18 @@ constexpr static std::array supportedTypeListQLinea constexpr static std::array supportedTypeListDynamicQuantizeLinear = { SupportedTensorDataTypes::Float32, - SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8 + SupportedTensorDataTypes::Ints8Bit }; constexpr static std::array supportedTypeListDynamicQuantizeMatMul= { SupportedTensorDataTypes::Float32, - SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, + SupportedTensorDataTypes::Ints8Bit, +}; + +constexpr static std::array supportedTypeListQLinearConcat= { + SupportedTensorDataTypes::Float32, + SupportedTensorDataTypes::Ints8Bit, + SupportedTensorDataTypes::Ints8Bit|SupportedTensorDataTypes::Float32, }; template @@ -969,6 +977,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO_MS( 1, QAttention, typeNameListFour, 
supportedTypeListQAttention, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryQAttention)}, {REG_INFO_MS( 1, Attention, typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryAttention)}, {REG_INFO_MS( 1, MultiHeadAttention, typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported)}, + {REG_INFO_MS( 1, QLinearConcat, typeNameListQLinearConcat, supportedTypeListQLinearConcat, DmlGraphSupport::Supported)}, {REG_INFO( 10, IsInf, typeNameListTwo, supportedTypeListIsInf, DmlGraphSupport::Supported)}, {REG_INFO( 10, Mod, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)}, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp index d8290bbdae..2965fa32ce 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp @@ -419,9 +419,9 @@ namespace Dml } // namespace FusionHelpers - uint32_t GetDmlAdjustedAxis(int32_t onnxAxis, const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t dmlDimCount) + uint32_t GetDmlAdjustedAxis(int32_t onnxAxis, const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t dmlDimCount, uint32_t firstInputIndex) { - const std::vector inputDimensions = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(0); + const std::vector inputDimensions = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(firstInputIndex); uint32_t onnxDimCount = gsl::narrow_cast(inputDimensions.size()); onnxAxis = HandleNegativeAxis(onnxAxis, onnxDimCount); return GetDmlAdjustedAxis(onnxAxis, onnxDimCount, dmlDimCount); diff --git 
a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.h index f0fad6a05f..8b2da60842 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.h @@ -64,8 +64,7 @@ namespace Dml } // namespace FusionHelpers // Given an axis in ONNX axis numbering, return the axis adjusted for DML based on how the sizes have been coerced. - // Note this function presumes the axis attribute is relative to the first input tensor (which is always the case). - uint32_t GetDmlAdjustedAxis(int32_t onnxAxis, const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t dmlDimCount); + uint32_t GetDmlAdjustedAxis(int32_t onnxAxis, const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t dmlDimCount, uint32_t firstInputIndex = 0); uint32_t GetDmlAdjustedAxis(int32_t onnxAxis, uint32_t onnxDimCount, uint32_t dmlDimCount); diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp index 370f336ff5..4d59964dcc 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp @@ -1862,7 +1862,7 @@ namespace OperatorHelper return { std::move(outputShape) }; } - void ConcatHelper::Initialize( + void ConcatHelperBase::Initialize( const MLOperatorAttributes& operatorAttributes, gsl::span inputDimensions ) @@ -1872,13 +1872,13 @@ namespace OperatorHelper ML_CHECK_VALID_ARGUMENT(m_axis < static_cast(inputDimensions.size())); } - std::vector ConcatHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const + std::vector ConcatHelperBase::GetOutputShapes(const MLShapeInferenceContext& shapeInfo, uint32_t firstInputIndex, uint32_t step) const 
{ - auto outputShape = shapeInfo.GetInputTensorShape(0); + auto outputShape = shapeInfo.GetInputTensorShape(firstInputIndex); uint32_t inputCount = shapeInfo.GetInputCount(); - for (uint32_t i = 1; i < inputCount; ++i) + for (uint32_t i = firstInputIndex + step; i < inputCount; i += step) { auto inputShape = shapeInfo.GetInputTensorShape(i); for (size_t j = 0; j < outputShape.size(); ++j) @@ -1893,6 +1893,16 @@ namespace OperatorHelper return { EdgeShapes(outputShape) }; } + std::vector ConcatHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const + { + return ConcatHelperBase::GetOutputShapes(shapeInfo, 0, 1); + } + + std::vector QLinearConcatHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const + { + return ConcatHelperBase::GetOutputShapes(shapeInfo, 2, 3); + } + void CropHelper::Initialize( const MLOperatorAttributes& operatorAttributes, gsl::span inputDimensions diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index c23f4fe355..5add951dcc 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -871,7 +871,7 @@ protected: int m_hiddenSize = 0; }; -class ConcatHelper +class ConcatHelperBase { public: void Initialize( @@ -882,17 +882,33 @@ public: // Info_t is used to obtain attributes which will be used for calculating the output shape later. // Shape_t is used to obtain input shape which will be used for adjusting attribute value. 
template - ConcatHelper(const Info_t& info, const Shape_t& shape) + ConcatHelperBase(const Info_t& info, const Shape_t& shape, uint32_t firstInputIndex) { - Initialize(info, shape.GetInputTensorShape(0)); + Initialize(info, shape.GetInputTensorShape(firstInputIndex)); } - std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo, uint32_t firstInputIndex, uint32_t step) const; protected: int m_axis; }; +class ConcatHelper: public ConcatHelperBase +{ +public: + template + ConcatHelper(const Info_t& info, const Shape_t& shape) : ConcatHelperBase(info, shape, 0) {} + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; +}; + +class QLinearConcatHelper: public ConcatHelperBase +{ +public: + template + QLinearConcatHelper(const Info_t& info, const Shape_t& shape) : ConcatHelperBase(info, shape, 2) {} + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; +}; + class CropHelper { public: @@ -1519,6 +1535,7 @@ using ShapeInferenceHelper_Split13 = VersionedOpsetHelper; using ShapeInferenceHelper_Split18 = VersionedOpsetHelper; using ShapeInferenceHelper_Transpose = TransposeHelper; using ShapeInferenceHelper_Concat = ConcatHelper; +using ShapeInferenceHelper_QLinearConcat = QLinearConcatHelper; using ShapeInferenceHelper_Slice7 = VersionedOpsetHelper; using ShapeInferenceHelper_Slice10 = VersionedOpsetHelper; using ShapeInferenceHelper_Slice11 = VersionedOpsetHelper; // Note 11 and 10 are identical - no functional change. 
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h index a469a3e06e..d785f77e24 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h @@ -440,6 +440,7 @@ namespace OperatorHelper static const int sc_sinceVer_QuickGelu = 1; static const int sc_sinceVer_GroupNorm = 1; static const int sc_sinceVer_DynamicQuantizeMatMul = 1; + static const int sc_sinceVer_QLinearConcat = 1; } // namespace MsftOperatorSet1 } // namespace OperatorHelper