From 897a0b9839daee8bf64e927abbe53a6304415396 Mon Sep 17 00:00:00 2001 From: Jeff Bloomfield Date: Thu, 18 Mar 2021 00:15:46 +0000 Subject: [PATCH] Merged PR 5807395: Add DML kernels for QLinearAdd (com.microsoft namespace) and DynamicQuantizeLinear Related work items: #32165595 --- .../src/External/DirectMLHelpers/ApiTraits.h | 30 ++ .../External/DirectMLHelpers/DirectMLSchema.h | 35 ++ .../DirectMLHelpers/GeneratedSchemaHelpers.h | 33 ++ .../DmlOperatorDynamicQuantizeLinear.cpp | 35 +- .../src/Operators/DmlOperatorQLinearAdd.cpp | 69 +++ .../src/Operators/OperatorRegistration.cpp | 454 +++++++++--------- .../dml/OperatorAuthorHelper/OperatorHelper.h | 2 + .../OperatorRegistration.h | 2 + 8 files changed, 429 insertions(+), 231 deletions(-) create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAdd.cpp diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h index 2b177faeda..642fc1db60 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h @@ -897,6 +897,18 @@ struct OperatorDescTraits static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_GATHER_ND1; }; +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD; +}; + template <> struct OperatorDescTraits { @@ -1743,6 +1755,18 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_GATHER_ND1> using DescType = DML_GATHER_ND1_OPERATOR_DESC; }; +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR> +{ + using DescType = DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD> +{ + using DescType = DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_DESC; +}; + template <> struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_ELU> { @@ -2116,6 +2140,10 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args return std::invoke(std::forward(visitor), DML_ROI_ALIGN_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_GATHER_ND1: return std::invoke(std::forward(visitor), DML_GATHER_ND1_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR: + return std::invoke(std::forward(visitor), DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ACTIVATION_ELU: return std::invoke(std::forward(visitor), DML_ACTIVATION_ELU_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ACTIVATION_CELU: @@ -2288,6 +2316,8 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value) case DML_OPERATOR_ADAM_OPTIMIZER: return "DML_OPERATOR_ADAM_OPTIMIZER"; case DML_OPERATOR_ROI_ALIGN: return "DML_OPERATOR_ROI_ALIGN"; case DML_OPERATOR_GATHER_ND1: return "DML_OPERATOR_GATHER_ND1"; + case DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR: return "DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR"; + case DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD: return "DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD"; default: assert(false); return ""; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h index ec3d24070b..fa83ba16e4 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h @@ -1949,6 +1949,41 @@ constexpr DML_OPERATOR_SCHEMA DML_GATHER_ND1_OPERATOR_SCHEMA { DML_GATHER_ND1_OPERATOR_SCHEMA_FIELDS, }; +constexpr DML_SCHEMA_FIELD DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA_FIELDS[4] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputZeroPointTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA { + "DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR", + DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 4, + DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA_FIELDS[9] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AZeroPointTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BZeroPointTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputZeroPointTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD", + DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 9, + DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA_FIELDS, +}; + constexpr DML_SCHEMA_FIELD DML_ACTIVATION_ELU_OPERATOR_SCHEMA_FIELDS[3] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h index 7630e98489..b698a2d6d1 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h @@ -1180,6 +1180,29 @@ inline std::vector GetFields(const DML_GATHER_ND1_OPERATOR_DESC& OperatorField(&DML_GATHER_ND1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.BatchDimensionCount))), }; } +inline std::vector GetFields(const DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputScaleTensor))), + OperatorField(&DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.OutputZeroPointTensor))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), + OperatorField(&DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.AScaleTensor))), + OperatorField(&DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.AZeroPointTensor))), + OperatorField(&DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.BTensor))), + OperatorField(&DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.BScaleTensor))), + OperatorField(&DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.BZeroPointTensor))), + OperatorField(&DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.OutputScaleTensor))), + OperatorField(&DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.OutputZeroPointTensor))), + OperatorField(&DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} inline std::vector GetFields(const DML_ACTIVATION_ELU_OPERATOR_DESC& desc) { return { @@ -1463,6 +1486,8 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType) case DML_OPERATOR_ADAM_OPTIMIZER: return DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA; case DML_OPERATOR_ROI_ALIGN: return DML_ROI_ALIGN_OPERATOR_SCHEMA; case DML_OPERATOR_GATHER_ND1: return DML_GATHER_ND1_OPERATOR_SCHEMA; + case DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR: return DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD: return DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_ELU: return DML_ACTIVATION_ELU_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_CELU: return DML_ACTIVATION_CELU_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_HARDMAX: return DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA; @@ -1972,6 +1997,14 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc) return AbstractOperatorDesc( &DML_GATHER_ND1_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR: + return AbstractOperatorDesc( + &DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); case DML_OPERATOR_ACTIVATION_ELU: return AbstractOperatorDesc( &DML_ACTIVATION_ELU_OPERATOR_SCHEMA, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorDynamicQuantizeLinear.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorDynamicQuantizeLinear.cpp index b727817e98..75a19968e8 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorDynamicQuantizeLinear.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorDynamicQuantizeLinear.cpp @@ -7,30 +7,45 @@ namespace Dml { class DmlOperatorDynamicQuantizeLinear : public DmlOperator -{ +{ + enum DmlInputTensors { + IN_A, + }; + + enum DmlOutputTensors { + OUT_Y, + OUT_Y_SCALE, + OUT_Y_ZERO_POINT + }; + public: using Self = DmlOperatorDynamicQuantizeLinear; DmlOperatorDynamicQuantizeLinear(const MLOperatorKernelCreationContext& kernelCreationContext) : DmlOperator(kernelCreationContext) { -#if 0 // TODO:NickFe - https://github.com/onnx/onnx/blob/master/docs/Operators.md#dynamicquantizelinear ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 1); ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 3); - DmlOperator::Initialize(kernelCreationContext, inputIndices, outputIndices); + + DmlOperator::Initialize(kernelCreationContext); + + m_inputTensorDescs[IN_A] = CreateTensorDescFromInput(kernelCreationContext, 0/*A OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned); + + m_outputTensorDescs[OUT_Y] = CreateTensorDescFromOutput(kernelCreationContext, 0/*Y OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned); + m_outputTensorDescs[OUT_Y_SCALE] = CreateTensorDescFromOutput(kernelCreationContext, 1/*Y Scale OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned); + m_outputTensorDescs[OUT_Y_ZERO_POINT] = CreateTensorDescFromOutput(kernelCreationContext, 2/*Y Zero point OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned); std::vector inputDescs = GetDmlInputDescs(); std::vector outputDescs = GetDmlOutputDescs(); - DML_PLACEHOLDER_OPERATOR_DESC operatorDesc = {}; - operatorDesc.IndicesTensor = &inputDescs[0]; - operatorDesc.ValuesTensor = &inputDescs[1]; - operatorDesc.OutputTensor = outputDescs.data(); - operatorDesc.Axis = dmlAxis; + DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_DESC operatorDesc = {}; + operatorDesc.InputTensor = &inputDescs[IN_A]; + operatorDesc.OutputTensor = &outputDescs[OUT_Y]; + operatorDesc.OutputScaleTensor = &outputDescs[OUT_Y_SCALE]; + operatorDesc.OutputZeroPointTensor = &outputDescs[OUT_Y_ZERO_POINT]; - DML_OPERATOR_DESC opDesc = { DML_OPERATOR_PLACEHOLDER, &operatorDesc }; + DML_OPERATOR_DESC opDesc = { DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR, &operatorDesc }; SetDmlOperatorDesc(opDesc, kernelCreationContext); -#endif } }; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAdd.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAdd.cpp new file mode 100644 index 0000000000..7b50dfb9ff --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAdd.cpp @@ -0,0 +1,69 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "precomp.h" + +namespace Dml +{ + +class DmlOperatorQLinearAdd : public DmlOperator +{ + enum InputTensors { + IN_A, + IN_A_SCALE, + IN_A_ZERO_POINT, + IN_B, + IN_B_SCALE, + IN_B_ZERO_POINT, + IN_C_SCALE, + IN_C_ZERO_POINT + }; + +public: + DmlOperatorQLinearAdd(const MLOperatorKernelCreationContext& kernelInfo) + : DmlOperator(kernelInfo) + { + DmlOperator::Initialize(kernelInfo); + + std::vector outputShape = kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0); + + uint32_t dmlDimSize = m_inputTensorDescs[0].GetDimensionCount(); + + // Initialize the input descriptions with broadcasting + m_inputTensorDescs[IN_A] = CreateTensorDescFromInput(kernelInfo, 0/*A OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, outputShape); + m_inputTensorDescs[IN_B] = CreateTensorDescFromInput(kernelInfo, 3/*B OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, outputShape); + + m_inputTensorDescs[IN_A_SCALE] = CreateTensorDescFromInput(kernelInfo, 1/*A Scale OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, std::nullopt, dmlDimSize); + m_inputTensorDescs[IN_A_ZERO_POINT] = CreateTensorDescFromInput(kernelInfo, 2/*A Zero point OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, std::nullopt, dmlDimSize); + + m_inputTensorDescs[IN_B_SCALE] = CreateTensorDescFromInput(kernelInfo, 4/*B Scale OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, std::nullopt, dmlDimSize); + m_inputTensorDescs[IN_B_ZERO_POINT] = CreateTensorDescFromInput(kernelInfo, 5/*B Zero point OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, std::nullopt, dmlDimSize); + + m_inputTensorDescs[IN_C_SCALE] = CreateTensorDescFromInput(kernelInfo, 6/*C Zero point OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, std::nullopt, dmlDimSize); + m_inputTensorDescs[IN_C_ZERO_POINT] = CreateTensorDescFromInput(kernelInfo, 7/*C Zero point OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, std::nullopt, dmlDimSize); + + // Initialize the output description while overriding the shape + m_outputTensorDescs[0] = CreateTensorDescFromOutput(kernelInfo, 0, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, outputShape); + + std::vector inputDescs = GetDmlInputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); + + DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_DESC AddDesc = {}; + AddDesc.ATensor = &inputDescs[IN_A]; + AddDesc.AScaleTensor = &inputDescs[IN_A_SCALE]; + AddDesc.AZeroPointTensor = inputDescs[IN_A_ZERO_POINT].Desc != nullptr ? &inputDescs[IN_A_ZERO_POINT] : nullptr; + AddDesc.BTensor = &inputDescs[IN_B]; + AddDesc.BScaleTensor = &inputDescs[IN_B_SCALE]; + AddDesc.BZeroPointTensor = inputDescs[IN_B_ZERO_POINT].Desc != nullptr ? &inputDescs[IN_B_ZERO_POINT] : nullptr; + AddDesc.OutputScaleTensor = &inputDescs[IN_C_SCALE]; + AddDesc.OutputZeroPointTensor = inputDescs[IN_C_ZERO_POINT].Desc != nullptr ? &inputDescs[IN_C_ZERO_POINT] : nullptr; + AddDesc.OutputTensor = &outputDescs[0]; + + DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD, &AddDesc }; + SetDmlOperatorDesc(opDesc, kernelInfo); + } +}; + +DML_OP_DEFINE_CREATION_FUNCTION(QLinearAdd, DmlOperatorQLinearAdd); + +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 413971166c..a3ad2f5755 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -236,6 +236,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(ReverseSequence); DML_OP_EXTERN_CREATION_FUNCTION(Round); DML_OP_EXTERN_CREATION_FUNCTION(ScatterElements); DML_OP_EXTERN_CREATION_FUNCTION(ScatterND); +DML_OP_EXTERN_CREATION_FUNCTION(QLinearAdd); DML_OP_EXTERN_CREATION_FUNCTION(QLinearConv); DML_OP_EXTERN_CREATION_FUNCTION(QLinearMatMul); DML_OP_EXTERN_CREATION_FUNCTION(DynamicQuantizeLinear); @@ -260,6 +261,7 @@ constexpr static std::array typeNameListScatterGatherND = { "T" constexpr static std::array typeNameListSlice10 = { "T", "Tind" }; constexpr static std::array typeNameListWhere = { "B", "T" }; constexpr static std::array typeNameListEyeLike = { "T2" }; + constexpr static std::array supportedTypeListAll = {SupportedTensorDataTypes::All}; constexpr static std::array supportedTypeListFloat32 = {SupportedTensorDataTypes::Float32}; constexpr static std::array supportedTypeListFloat16to32 = {SupportedTensorDataTypes::Float16to32}; @@ -292,6 +294,7 @@ constexpr static std::array supportedTypeListLogica constexpr static std::array supportedTypeListSigned = { SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int16 | SupportedTensorDataTypes::Int8 }; constexpr static std::array supportedTypeListRange = {SupportedTensorDataTypes::Int16|SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Float32}; constexpr static std::array supportedTypeListInteger = {SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int32 }; +constexpr static std::array supportedTypeListInteger8 = {SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8 }; constexpr static std::array supportedTypeListRoiAlign = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64 }; constexpr static std::array supportedTypeListQLinearMatMul = { SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, @@ -305,6 +308,12 @@ constexpr static std::array supportedTypeListQLinea SupportedTensorDataTypes::Int32 }; + +constexpr static std::array supportedTypeListDynamicQuantizeLinear = { + SupportedTensorDataTypes::Float32, + SupportedTensorDataTypes::UInt8, +}; + template constexpr auto requiredConstantCpuInputs(Args... args) { @@ -337,249 +346,252 @@ constexpr auto requiredConstantCpuInputs(Args... args) constexpr static OperatorRegistrationInformation operatorRegistrationInformationTable[] = { -/// Domain/Type, Ver, Name, TypeNames, Types, Graph Support, Required const CPU inputs, -/// Input count required for graph support, -/// Support query function +/// Domain/Type, Ver, Name, TypeNames, Types, Graph Support, Required const CPU inputs, +/// Input count required for graph support, +/// Support query function // Deep Learning Standard Layers - {REG_INFO( 7, Conv, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 11, Conv, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, ConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 11, ConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 10, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 11, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, GlobalAveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO( 8, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, - {REG_INFO( 10, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, - {REG_INFO( 11, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, - {REG_INFO( 12, MaxPool, typeNameListDefault, supportedTypeListFloat16to32Int8, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, + {REG_INFO( 7, Conv, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, Conv, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, ConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, ConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 10, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, GlobalAveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO( 8, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, + {REG_INFO( 10, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, + {REG_INFO( 11, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, + {REG_INFO( 12, MaxPool, typeNameListDefault, supportedTypeListFloat16to32Int8, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, - {REG_INFO( 7, GlobalMaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 11, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, GlobalLpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, MaxRoiPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO_VER( 10, RoiAlign, typeNameListTwo, supportedTypeListRoiAlign, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, - {REG_INFO( 7, InstanceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 9, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, // v9 just removes 'spatial' attribute. - {REG_INFO( 7, LRN, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, MeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 9, MeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, LpNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, RNN, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::NotSupported)}, - {REG_INFO( 7, GRU, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::NotSupported)}, - {REG_INFO( 7, LSTM, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::NotSupported)}, - {REG_INFO_MS( 1, ConvTransposeWithDynamicPads, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(2))}, + {REG_INFO( 7, GlobalMaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, GlobalLpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, MaxRoiPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO_VER( 10, RoiAlign, typeNameListTwo, supportedTypeListRoiAlign, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, + {REG_INFO( 7, InstanceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 9, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, // v9 just removes 'spatial' attribute. + {REG_INFO( 7, LRN, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, MeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 9, MeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, LpNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, RNN, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::NotSupported)}, + {REG_INFO( 7, GRU, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::NotSupported)}, + {REG_INFO( 7, LSTM, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::NotSupported)}, + {REG_INFO_MS( 1, ConvTransposeWithDynamicPads, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(2))}, // Data Reorganization Layers - {REG_INFO( 7, Split, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO( 11, Split, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, // Adds negative axis. - {REG_INFO( 7, Transpose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO( 7, Concat, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO( 11, Concat, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, // Adds negative axis. - {REG_INFO_VER( 7, Slice, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO_VER( 10, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)}, // Adds negative axes. - {REG_INFO_VER( 11, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)}, - {REG_INFO_VER( 7, Pad, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO_VER( 11, Pad, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1, 2) /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728 - {REG_INFO( 7, SpaceToDepth, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO( 7, DepthToSpace, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO( 11, DepthToSpace, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO( 7, Tile, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))}, - {REG_INFO( 8, Expand, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))}, - {REG_INFO( 9, ConstantOfShape, typeNameListConstantOfShape, supportedTypeListConstantOfShape, DmlGraphSupport::NotSupported, requiredConstantCpuInputs(0))}, - {REG_INFO( 7, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, - {REG_INFO( 11, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, - {REG_INFO( 11, GatherElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, - {REG_INFO( 11, GatherND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, - {REG_INFO( 12, GatherND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, - {REG_INFO_VER( 9, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, - {REG_INFO_VER( 11, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, - {REG_INFO( 11, ScatterElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, - {REG_INFO( 11, ScatterND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, - {REG_INFO( 9, EyeLike, typeNameListEyeLike, supportedTypeListScalars8to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Split, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO( 11, Split, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, // Adds negative axis. + {REG_INFO( 7, Transpose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO( 7, Concat, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO( 11, Concat, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, // Adds negative axis. + {REG_INFO_VER( 7, Slice, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO_VER( 10, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)}, // Adds negative axes. + {REG_INFO_VER( 11, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)}, + {REG_INFO_VER( 7, Pad, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO_VER( 11, Pad, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1, 2) /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728 + {REG_INFO( 7, SpaceToDepth, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO( 7, DepthToSpace, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO( 11, DepthToSpace, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO( 7, Tile, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))}, + {REG_INFO( 8, Expand, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))}, + {REG_INFO( 9, ConstantOfShape, typeNameListConstantOfShape, supportedTypeListConstantOfShape, DmlGraphSupport::NotSupported, requiredConstantCpuInputs(0))}, + {REG_INFO( 7, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, + {REG_INFO( 11, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, + {REG_INFO( 11, GatherElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, + {REG_INFO( 11, GatherND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, + {REG_INFO( 12, GatherND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, + {REG_INFO_VER( 9, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, + {REG_INFO_VER( 11, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, + {REG_INFO( 11, ScatterElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, + {REG_INFO( 11, ScatterND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, + {REG_INFO( 9, EyeLike, typeNameListEyeLike, supportedTypeListScalars8to32, DmlGraphSupport::Supported)}, // Data reorganization that merely changes the dimensions while keeping the data identical. - {REG_INFO_ID( 7, Identity, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO_ID( 7, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO_ID( 9, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO_ID( 11, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO_ID( 7, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO_ID( 11, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO_ID( 7, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO_ID( 11, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO_ID( 7, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))}, + {REG_INFO_ID( 7, Identity, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO_ID( 7, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO_ID( 9, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO_ID( 11, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO_ID( 7, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO_ID( 11, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO_ID( 7, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO_ID( 11, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO_ID( 7, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))}, // Elementwise - {REG_INFO( 7, Sqrt, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Reciprocal, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Pow, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 12, Pow, typeNameListPow12, supportedTypeListPow12, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Exp, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Log, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Abs, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Ceil, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Floor, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO_VER( 7, Clip, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO_VER( 11, Clip, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1,2))}, - {REG_INFO_VER( 12, Clip, typeNameListDefault, supportedTypeListFloat16to32Int8to32,DmlGraphSupport::Supported, requiredConstantCpuInputs(1,2))}, - {REG_INFO( 7, Add, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Sub, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Mul, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Div, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Sum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, - {REG_INFO( 8, Sum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, - {REG_INFO( 7, Mean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, - {REG_INFO( 8, Mean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, - {REG_INFO( 7, Max, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, - {REG_INFO( 8, Max, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, - {REG_INFO( 12, Max, typeNameListDefault, supportedTypeListFloat16to32Int8to32,DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, - {REG_INFO( 7, Min, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, - {REG_INFO( 8, Min, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, - {REG_INFO( 12, Min, typeNameListDefault, supportedTypeListFloat16to32Int8to32,DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, - {REG_INFO( 7, Cos, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Sin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Tan, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Acos, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Asin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Atan, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Affine, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 10, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear, DmlGraphSupport::Supported)}, - {REG_INFO( 10, DequantizeLinear, typeNameListTwo, supportedTypeListDequantizeLinear, DmlGraphSupport::Supported)}, - {REG_INFO_MS( 1, QuantizeLinear, typeNameListTwo, supportedTypeListQuantize, DmlGraphSupport::Supported)}, - {REG_INFO_MS( 1, DequantizeLinear, typeNameListTwo, supportedTypeListQuantize, DmlGraphSupport::Supported)}, - {REG_INFO( 9, Sign, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)}, - {REG_INFO( 9, IsNan, typeNameListTwo, supportedTypeListIsNan, DmlGraphSupport::Supported)}, - {REG_INFO( 9, Sinh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 9, Cosh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 9, Asinh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 9, Acosh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 9, Atanh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 9, Erf, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 9, Where, typeNameListWhere, supportedTypeListWhere, DmlGraphSupport::Supported)}, - {REG_INFO( 7, ReduceSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 11, ReduceSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO_VER( 12, Einsum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryEinSum )}, - {REG_INFO( 7, ReduceMean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 11, ReduceMean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, ReduceProd, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 11, ReduceProd, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, ReduceLogSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 11, ReduceLogSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, ReduceLogSumExp, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 11, ReduceLogSumExp, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, ReduceSumSquare, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 11, ReduceSumSquare, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, ReduceL1, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 11, ReduceL1, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, ReduceL2, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 11, ReduceL2, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 11, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 12, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32Int8, DmlGraphSupport::Supported)}, - {REG_INFO( 7, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 11, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 12, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32Int8, DmlGraphSupport::Supported)}, - {REG_INFO( 7, ArgMax, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO( 11, ArgMax, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO( 12, ArgMax, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO( 7, ArgMin, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO( 11, ArgMin, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO( 12, ArgMin, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO( 7, Gemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 11, Gemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 9, Gemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Neg, typeNameListDefault, supportedTypeListSigned, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Greater, typeNameListLogicalComparison, supportedTypeListLogicalComparison7,DmlGraphSupport::Supported)}, - {REG_INFO( 9, Greater, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmlGraphSupport::Supported)}, - {REG_INFO( 7, Less, typeNameListLogicalComparison, supportedTypeListLogicalComparison7,DmlGraphSupport::Supported)}, - {REG_INFO( 9, Less, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmlGraphSupport::Supported)}, - {REG_INFO( 12, GreaterOrEqual, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmlGraphSupport::Supported)}, - {REG_INFO( 12, LessOrEqual, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmlGraphSupport::Supported)}, - {REG_INFO( 7, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison7,DmlGraphSupport::Supported)}, - {REG_INFO( 11, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmlGraphSupport::Supported)}, - {REG_INFO( 7, Not, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)}, - {REG_INFO( 7, And, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Or, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Xor, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Sqrt, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Reciprocal, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Pow, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 12, Pow, typeNameListPow12, supportedTypeListPow12, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Exp, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Log, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Abs, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Ceil, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Floor, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO_VER( 7, Clip, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO_VER( 11, Clip, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1,2))}, + {REG_INFO_VER( 12, Clip, typeNameListDefault, supportedTypeListFloat16to32Int8to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1,2))}, + {REG_INFO( 7, Add, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Sub, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Mul, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Div, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Sum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, + {REG_INFO( 8, Sum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, + {REG_INFO( 7, Mean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, + {REG_INFO( 8, Mean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, + {REG_INFO( 7, Max, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, + {REG_INFO( 8, Max, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, + {REG_INFO( 12, Max, typeNameListDefault, supportedTypeListFloat16to32Int8to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, + {REG_INFO( 7, Min, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, + {REG_INFO( 8, Min, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, + {REG_INFO( 12, Min, typeNameListDefault, supportedTypeListFloat16to32Int8to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, + {REG_INFO( 7, Cos, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Sin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Tan, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Acos, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Asin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Atan, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Affine, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 10, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear, DmlGraphSupport::Supported)}, + {REG_INFO( 10, DequantizeLinear, typeNameListTwo, supportedTypeListDequantizeLinear, DmlGraphSupport::Supported)}, + {REG_INFO_MS( 1, QuantizeLinear, typeNameListTwo, supportedTypeListQuantize, DmlGraphSupport::Supported)}, + {REG_INFO_MS( 1, DequantizeLinear, typeNameListTwo, supportedTypeListQuantize, DmlGraphSupport::Supported)}, + {REG_INFO( 9, Sign, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)}, + {REG_INFO( 9, IsNan, typeNameListTwo, supportedTypeListIsNan, DmlGraphSupport::Supported)}, + {REG_INFO( 9, Sinh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 9, Cosh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 9, Asinh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 9, Acosh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 9, Atanh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 9, Erf, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 9, Where, typeNameListWhere, supportedTypeListWhere, DmlGraphSupport::Supported)}, + {REG_INFO( 7, ReduceSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, ReduceSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO_VER( 12, Einsum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryEinSum )}, + {REG_INFO( 7, ReduceMean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, ReduceMean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, ReduceProd, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, ReduceProd, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, ReduceLogSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, ReduceLogSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, ReduceLogSumExp, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, ReduceLogSumExp, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, ReduceSumSquare, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, ReduceSumSquare, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, ReduceL1, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, ReduceL1, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, ReduceL2, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, ReduceL2, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 12, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32Int8, DmlGraphSupport::Supported)}, + {REG_INFO( 7, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 12, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32Int8, DmlGraphSupport::Supported)}, + {REG_INFO( 7, ArgMax, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO( 11, ArgMax, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO( 12, ArgMax, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO( 7, ArgMin, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO( 11, ArgMin, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO( 12, ArgMin, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO( 7, Gemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, Gemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 9, Gemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Neg, typeNameListDefault, supportedTypeListSigned, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Greater, typeNameListLogicalComparison, supportedTypeListLogicalComparison7, DmlGraphSupport::Supported)}, + {REG_INFO( 9, Greater, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Less, typeNameListLogicalComparison, supportedTypeListLogicalComparison7, DmlGraphSupport::Supported)}, + {REG_INFO( 9, Less, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported)}, + {REG_INFO( 12, GreaterOrEqual, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported)}, + {REG_INFO( 12, LessOrEqual, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison7, DmlGraphSupport::Supported)}, + {REG_INFO( 11, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Not, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)}, + {REG_INFO( 7, And, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Or, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Xor, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)}, // Imaging Operators - {REG_INFO( 7, Crop, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, ImageScaler, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO_VER( 7, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO_VER( 9, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)}, - {REG_INFO_VER( 10, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)}, - {REG_INFO_VER( 10, Resize, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)}, - {REG_INFO_VER( 11, Resize, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*roi, scales, sizes*/, std::nullopt, QueryResize)}, + {REG_INFO( 7, Crop, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, ImageScaler, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO_VER( 7, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO_VER( 9, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)}, + {REG_INFO_VER( 10, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)}, + {REG_INFO_VER( 10, Resize, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)}, + {REG_INFO_VER( 11, Resize, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*roi, scales, sizes*/, std::nullopt, QueryResize)}, // Activation Functions - {REG_INFO( 7, Sigmoid, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, HardSigmoid, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Tanh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, ScaledTanh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Relu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, LeakyRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, PRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 9, PRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, ThresholdedRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 10, ThresholdedRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Elu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 12, Celu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Selu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Softmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 11, Softmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, LogSoftmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 11, LogSoftmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Hardmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 11, Hardmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Softsign, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Softplus, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, ParametricSoftplus, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Dropout, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 9, Shrink, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Sigmoid, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, HardSigmoid, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Tanh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, ScaledTanh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Relu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, LeakyRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, PRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 9, PRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, ThresholdedRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 10, ThresholdedRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Elu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 12, Celu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Selu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Softmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, Softmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, LogSoftmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, LogSoftmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Hardmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, Hardmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Softsign, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Softplus, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, ParametricSoftplus, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Dropout, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 9, Shrink, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)}, // Uncategorized - {REG_INFO( 7, MatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 9, MatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO( 9, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO( 7, MatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 9, MatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO( 9, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, {REG_INFO( 7, MemcpyFromHost, typeNameListDefault, supportedTypeListAll)}, {REG_INFO( 7, MemcpyToHost, typeNameListDefault, supportedTypeListAll)}, - {REG_INFO_VER( 7, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO_VER( 10, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))}, - {REG_INFO_VER( 11, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))}, - {REG_INFO( 9, OneHot, typeNameListThree, supportedTypeListOneHot, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp, requiredConstantCpuInputs(1))}, - {REG_INFO( 11, OneHot, typeNameListThree, supportedTypeListOneHot, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp, requiredConstantCpuInputs(1))}, + {REG_INFO_VER( 7, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO_VER( 10, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))}, + {REG_INFO_VER( 11, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))}, + {REG_INFO( 9, OneHot, typeNameListThree, supportedTypeListOneHot, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp, requiredConstantCpuInputs(1))}, + {REG_INFO( 11, OneHot, typeNameListThree, supportedTypeListOneHot, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp, requiredConstantCpuInputs(1))}, // Fused operators - {REG_INFO_MSDML(1, FusedConv, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO_MSDML(1, FusedConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO_MSDML(1, FusedInstanceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO_MSDML(1, FusedBatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO_MSDML(1, FusedMeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO_MSDML(1, FusedGemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO_MSDML(1, FusedMatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO_MSDML(1, FusedAdd, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO_MSDML(1, FusedSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, + {REG_INFO_MSDML(1, FusedConv, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO_MSDML(1, FusedConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO_MSDML(1, FusedInstanceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO_MSDML(1, FusedBatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO_MSDML(1, FusedMeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO_MSDML(1, FusedGemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO_MSDML(1, FusedMatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO_MSDML(1, FusedAdd, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO_MSDML(1, FusedSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, - {REG_INFO( 10, IsInf, typeNameListTwo, supportedTypeListIsInf, DmlGraphSupport::Supported)}, - {REG_INFO( 10, Mod, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)}, - {REG_INFO( 11, BitShift, typeNameListDefault, supportedTypeListInt8to32, DmlGraphSupport::Supported)}, - {REG_INFO( 11, Round, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 10, ReverseSequence, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, - {REG_INFO( 11, CumSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, - {REG_INFO( 11, Range, typeNameListDefault, supportedTypeListRange, DmlGraphSupport::Supported, requiredConstantCpuInputs(0,1,2))}, + {REG_INFO( 10, IsInf, typeNameListTwo, supportedTypeListIsInf, DmlGraphSupport::Supported)}, + {REG_INFO( 10, Mod, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)}, + {REG_INFO( 11, BitShift, typeNameListDefault, supportedTypeListInt8to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, Round, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 10, ReverseSequence, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, + {REG_INFO( 11, CumSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, + {REG_INFO( 11, Range, typeNameListDefault, supportedTypeListRange, DmlGraphSupport::Supported, requiredConstantCpuInputs(0,1,2))}, - {REG_INFO( 9, MaxUnpool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp, requiredConstantCpuInputs(2))}, - {REG_INFO( 11, MaxUnpool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp, requiredConstantCpuInputs(2))}, // 11 is identical to 9. + {REG_INFO( 9, MaxUnpool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp, requiredConstantCpuInputs(2))}, + {REG_INFO( 11, MaxUnpool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp, requiredConstantCpuInputs(2))}, // 11 is identical to 9. + + {REG_INFO_MS( 1, QLinearAdd, typeNameListDefault, supportedTypeListInteger8, DmlGraphSupport::Supported)}, + {REG_INFO( 10, QLinearConv, typeNameListFour, supportedTypeListQLinearConv, DmlGraphSupport::Supported)}, + {REG_INFO( 10, QLinearMatMul, typeNameListThree, supportedTypeListQLinearMatMul, DmlGraphSupport::Supported)}, + {REG_INFO( 10, MatMulInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)}, + {REG_INFO( 10, ConvInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)}, + {REG_INFO( 11, DynamicQuantizeLinear, typeNameListTwo, supportedTypeListDynamicQuantizeLinear, DmlGraphSupport::Supported)}, - {REG_INFO( 10, QLinearConv, typeNameListFour, supportedTypeListQLinearConv, DmlGraphSupport::Supported)}, - {REG_INFO( 10, QLinearMatMul, typeNameListThree, supportedTypeListQLinearMatMul, DmlGraphSupport::Supported)}, - {REG_INFO( 10, MatMulInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)}, - {REG_INFO( 10, ConvInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)}, }; template diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index 06250fddce..5772cceb01 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -1565,6 +1565,8 @@ using ShapeInferenceHelper_Identity = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_MatMul = MatMulHelper; using ShapeInferenceHelper_MatMulInteger = MatMulHelper; using ShapeInferenceHelper_QLinearMatMul = QLinearMatMulHelper; +using ShapeInferenceHelper_QLinearAdd = GetBroadcastedOutputShapeHelper; +using ShapeInferenceHelper_DynamicQuantizeLinear = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Cast = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_MemcpyFromHost = GetOutputShapeAsInputShapeHelper; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorRegistration.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorRegistration.h index 9d5c248559..4411a638b5 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorRegistration.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorRegistration.h @@ -233,6 +233,7 @@ namespace OperatorHelper static const int sc_sinceVer_ReduceSumSquare = 11; static const int sc_sinceVer_Resize = 11; static const int sc_sinceVer_Round = 11; + static const int sc_sinceVer_DynamicQuantizeLinear = 11; static const int sc_sinceVer_Scan = 11; static const int sc_sinceVer_Scatter = 11; // Deprecated alias static const int sc_sinceVer_ScatterElements = 11; @@ -277,6 +278,7 @@ namespace OperatorHelper static const int sc_sinceVer_QuantizeLinear = 1; static const int sc_sinceVer_DequantizeLinear = 1; static const int sc_sinceVer_ConvTransposeWithDynamicPads = 1; + static const int sc_sinceVer_QLinearAdd = 1; } // namespace MsftOperatorSet1 } // namespace OperatorHelper