diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h index da11068d34..d5c651a0c4 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h @@ -122,6 +122,10 @@ namespace Windows::AI::MachineLearning::Adapter // Operator supports true 64-bit tensors directly, no strides needed. // So fallback to strided 32-bit only occurs when the device lacks 64-bit support. bool prefer64BitTensorsDirectly = false; + + // The operator supports emulation for uint64/int64 even if the hardware doesn't + // support native uint64/int64 data types. + bool support64BitTensorsViaEmulation = false; }; using InternalRegistrationInfoMap = std::unordered_map>; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp index ff7ab0ea03..02733453f7 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp @@ -345,6 +345,7 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( bool supportedWith64BitTensorsVia32BitStrides, bool supportedWith64BitTensorsVia32BitStridesFromAnyEp, bool prefer64BitTensorsDirectly, + bool support64BitTensorsViaEmulation, _In_reads_(constantCpuInputCount) const uint32_t* requiredConstantCpuInputs, uint32_t constantCpuInputCount) const noexcept { @@ -472,6 +473,7 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( regInfo->supportedWith64BitTensorsVia32BitStrides = supportedWith64BitTensorsVia32BitStrides; regInfo->supportedWith64BitTensorsVia32BitStridesFromAnyEp = supportedWith64BitTensorsVia32BitStridesFromAnyEp; regInfo->prefer64BitTensorsDirectly = prefer64BitTensorsDirectly; + regInfo->support64BitTensorsViaEmulation = support64BitTensorsViaEmulation; // Only internal operators support usage in DML graphs if (supportsGraph) @@ -546,7 +548,8 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( requiredConstantCpuInputs || supportedWith64BitTensorsVia32BitStrides || supportedWith64BitTensorsVia32BitStridesFromAnyEp || - prefer64BitTensorsDirectly) + prefer64BitTensorsDirectly || + support64BitTensorsViaEmulation) { ORT_THROW_HR(E_INVALIDARG); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h index 78d17418ef..2482d3af8b 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h @@ -44,6 +44,7 @@ class AbiCustomRegistry : public WRL::Base(i) }; DML_FEATURE_DATA_TENSOR_DATA_TYPE_SUPPORT dataTypeSupport = {}; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommon.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommon.h index 0f0c533558..e43b982364 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommon.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommon.h @@ -96,4 +96,24 @@ namespace Dml return minimumImpliedSizeInBytes; } + + template + void CastToClampedScalarUnion(DML_TENSOR_DATA_TYPE dataType, T value, DML_SCALAR_UNION* outputValue) + { + switch (dataType) + { + case DML_TENSOR_DATA_TYPE_UINT8: outputValue->UInt8 = clamp_cast(value); break; + case DML_TENSOR_DATA_TYPE_UINT16: outputValue->UInt16 = clamp_cast(value); break; + case DML_TENSOR_DATA_TYPE_UINT32: outputValue->UInt32 = clamp_cast(value); break; + case DML_TENSOR_DATA_TYPE_UINT64: outputValue->UInt64 = clamp_cast(value); break; + case DML_TENSOR_DATA_TYPE_INT8: outputValue->Int8 = clamp_cast(value); break; + case DML_TENSOR_DATA_TYPE_INT16: outputValue->Int16 = clamp_cast(value); break; + case DML_TENSOR_DATA_TYPE_INT32: outputValue->Int32 = clamp_cast(value); break; + case DML_TENSOR_DATA_TYPE_INT64: outputValue->Int64 = clamp_cast(value); break; + case DML_TENSOR_DATA_TYPE_FLOAT16: outputValue->Float32 = clamp_cast(value); break; + case DML_TENSOR_DATA_TYPE_FLOAT32: outputValue->Float32 = clamp_cast(value); break; + case DML_TENSOR_DATA_TYPE_FLOAT64: outputValue->Float64 = clamp_cast(value); break; + default: assert(false); + } + } } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h index 776bc0569f..dd00eb4a80 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h @@ -24,7 +24,7 @@ struct EnumTraits template <> struct EnumTraits { - static constexpr auto ValueCount = 141; + static constexpr auto ValueCount = 153; static constexpr size_t ActivationFunctionCount = 20; }; @@ -86,7 +86,7 @@ struct EnumTraits template <> struct EnumTraits { - static constexpr auto ValueCount = 4; + static constexpr auto ValueCount = 8; }; template <> @@ -225,6 +225,24 @@ struct OperatorDescTraits static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_CLIP; }; +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_CLIP1; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1; +}; + template <> struct OperatorDescTraits { @@ -363,6 +381,18 @@ struct OperatorDescTraits static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_SQRT; }; +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_ATAN_YX; +}; + template <> struct OperatorDescTraits { @@ -483,6 +513,12 @@ struct OperatorDescTraits static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_PADDING; }; +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_PADDING1; +}; + template <> struct OperatorDescTraits { @@ -531,6 +567,18 @@ struct OperatorDescTraits static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_BATCH_NORMALIZATION; }; +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_BATCH_NORMALIZATION_GRAD; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD; +}; + template <> struct OperatorDescTraits { @@ -543,6 +591,12 @@ struct OperatorDescTraits static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION; }; +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD; +}; + template <> struct OperatorDescTraits { @@ -579,6 +633,12 @@ struct OperatorDescTraits static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_IS_NAN; }; +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_NEGATE; +}; + template <> struct OperatorDescTraits { @@ -717,6 +777,12 @@ struct OperatorDescTraits static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_CUMULATIVE_SUMMATION; }; +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_CUMULATIVE_PRODUCT; +}; + template <> struct OperatorDescTraits { @@ -891,12 +957,42 @@ struct OperatorDescTraits static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ROI_ALIGN; }; +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ROI_ALIGN1; +}; + template <> struct OperatorDescTraits { static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_GATHER_ND1; }; +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ROI_ALIGN_GRAD; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_BATCH_NORMALIZATION_TRAINING; +}; + template <> struct OperatorDescTraits { @@ -1071,6 +1167,24 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_CLIP> using DescType = DML_ELEMENT_WISE_CLIP_OPERATOR_DESC; }; +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_CLIP1> +{ + using DescType = DML_ELEMENT_WISE_CLIP1_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD> +{ + using DescType = DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1> +{ + using DescType = DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_DESC; +}; + template <> struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_COS> { @@ -1209,6 +1323,18 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_SQRT> using DescType = DML_ELEMENT_WISE_SQRT_OPERATOR_DESC; }; +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE> +{ + using DescType = DML_ELEMENT_WISE_DIFFERENCE_SQUARE_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_ATAN_YX> +{ + using DescType = DML_ELEMENT_WISE_ATAN_YX_OPERATOR_DESC; +}; + template <> struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_SUBTRACT> { @@ -1329,6 +1455,12 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_PADDING> using DescType = DML_PADDING_OPERATOR_DESC; }; +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_PADDING1> +{ + using DescType = DML_PADDING1_OPERATOR_DESC; +}; + template <> struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_VALUE_SCALE_2D> { @@ -1377,6 +1509,18 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_BATCH_NORMALIZATION> using DescType = DML_BATCH_NORMALIZATION_OPERATOR_DESC; }; +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_BATCH_NORMALIZATION_GRAD> +{ + using DescType = DML_BATCH_NORMALIZATION_GRAD_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD> +{ + using DescType = DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_DESC; +}; + template <> struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION> { @@ -1389,6 +1533,12 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_LOCAL_RESPONSE_NORMALI using DescType = DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_DESC; }; +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD> +{ + using DescType = DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_DESC; +}; + template <> struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_LP_NORMALIZATION> { @@ -1425,6 +1575,12 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_IS_NAN> using DescType = DML_ELEMENT_WISE_IS_NAN_OPERATOR_DESC; }; +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_NEGATE> +{ + using DescType = DML_ELEMENT_WISE_NEGATE_OPERATOR_DESC; +}; + template <> struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_ERF> { @@ -1563,6 +1719,12 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_CUMULATIVE_SUMMATION> using DescType = DML_CUMULATIVE_SUMMATION_OPERATOR_DESC; }; +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_CUMULATIVE_PRODUCT> +{ + using DescType = DML_CUMULATIVE_PRODUCT_OPERATOR_DESC; +}; + template <> struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_REVERSE_SUBSEQUENCES> { @@ -1737,12 +1899,42 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ROI_ALIGN> using DescType = DML_ROI_ALIGN_OPERATOR_DESC; }; +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ROI_ALIGN1> +{ + using DescType = DML_ROI_ALIGN1_OPERATOR_DESC; +}; + template <> struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_GATHER_ND1> { using DescType = DML_GATHER_ND1_OPERATOR_DESC; }; +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR> +{ + using DescType = DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD> +{ + using DescType = DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ROI_ALIGN_GRAD> +{ + using DescType = DML_ROI_ALIGN_GRAD_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_BATCH_NORMALIZATION_TRAINING> +{ + using DescType = DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_DESC; +}; + template <> struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_ELU> { @@ -1894,6 +2086,12 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_CEIL_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ELEMENT_WISE_CLIP: return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_CLIP_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_CLIP1: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_CLIP1_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ELEMENT_WISE_COS: return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_COS_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ELEMENT_WISE_DIVIDE: @@ -1940,6 +2138,10 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_SIN_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ELEMENT_WISE_SQRT: return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_SQRT_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_DIFFERENCE_SQUARE_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_ATAN_YX: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_ATAN_YX_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ELEMENT_WISE_SUBTRACT: return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_SUBTRACT_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ELEMENT_WISE_TAN: @@ -1980,6 +2182,8 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args return std::invoke(std::forward(visitor), DML_JOIN_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_PADDING: return std::invoke(std::forward(visitor), DML_PADDING_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_PADDING1: + return std::invoke(std::forward(visitor), DML_PADDING1_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_VALUE_SCALE_2D: return std::invoke(std::forward(visitor), DML_VALUE_SCALE_2D_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_UPSAMPLE_2D: @@ -1996,10 +2200,16 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args return std::invoke(std::forward(visitor), DML_TOP_K_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_BATCH_NORMALIZATION: return std::invoke(std::forward(visitor), DML_BATCH_NORMALIZATION_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_BATCH_NORMALIZATION_GRAD: + return std::invoke(std::forward(visitor), DML_BATCH_NORMALIZATION_GRAD_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD: + return std::invoke(std::forward(visitor), DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION: return std::invoke(std::forward(visitor), DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION: return std::invoke(std::forward(visitor), DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD: + return std::invoke(std::forward(visitor), DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_LP_NORMALIZATION: return std::invoke(std::forward(visitor), DML_LP_NORMALIZATION_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_RNN: @@ -2012,6 +2222,8 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_SIGN_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ELEMENT_WISE_IS_NAN: return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_IS_NAN_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_NEGATE: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_NEGATE_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ELEMENT_WISE_ERF: return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_ERF_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ELEMENT_WISE_SINH: @@ -2058,6 +2270,8 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args return std::invoke(std::forward(visitor), DML_FILL_VALUE_SEQUENCE_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_CUMULATIVE_SUMMATION: return std::invoke(std::forward(visitor), DML_CUMULATIVE_SUMMATION_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_CUMULATIVE_PRODUCT: + return std::invoke(std::forward(visitor), DML_CUMULATIVE_PRODUCT_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_REVERSE_SUBSEQUENCES: return std::invoke(std::forward(visitor), DML_REVERSE_SUBSEQUENCES_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_GATHER_ELEMENTS: @@ -2116,8 +2330,18 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args return std::invoke(std::forward(visitor), DML_ADAM_OPTIMIZER_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ROI_ALIGN: return std::invoke(std::forward(visitor), DML_ROI_ALIGN_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ROI_ALIGN1: + return std::invoke(std::forward(visitor), DML_ROI_ALIGN1_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_GATHER_ND1: return std::invoke(std::forward(visitor), DML_GATHER_ND1_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR: + return std::invoke(std::forward(visitor), DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD: + return std::invoke(std::forward(visitor), DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ROI_ALIGN_GRAD: + return std::invoke(std::forward(visitor), DML_ROI_ALIGN_GRAD_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_BATCH_NORMALIZATION_TRAINING: + return std::invoke(std::forward(visitor), DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ACTIVATION_ELU: return std::invoke(std::forward(visitor), DML_ACTIVATION_ELU_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ACTIVATION_CELU: @@ -2180,6 +2404,9 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value) case DML_OPERATOR_ELEMENT_WISE_ATAN: return "DML_OPERATOR_ELEMENT_WISE_ATAN"; case DML_OPERATOR_ELEMENT_WISE_CEIL: return "DML_OPERATOR_ELEMENT_WISE_CEIL"; case DML_OPERATOR_ELEMENT_WISE_CLIP: return "DML_OPERATOR_ELEMENT_WISE_CLIP"; + case DML_OPERATOR_ELEMENT_WISE_CLIP1: return "DML_OPERATOR_ELEMENT_WISE_CLIP1"; + case DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD: return "DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD"; + case DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1: return "DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1"; case DML_OPERATOR_ELEMENT_WISE_COS: return "DML_OPERATOR_ELEMENT_WISE_COS"; case DML_OPERATOR_ELEMENT_WISE_DIVIDE: return "DML_OPERATOR_ELEMENT_WISE_DIVIDE"; case DML_OPERATOR_ELEMENT_WISE_EXP: return "DML_OPERATOR_ELEMENT_WISE_EXP"; @@ -2203,6 +2430,8 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value) case DML_OPERATOR_ELEMENT_WISE_RECIP: return "DML_OPERATOR_ELEMENT_WISE_RECIP"; case DML_OPERATOR_ELEMENT_WISE_SIN: return "DML_OPERATOR_ELEMENT_WISE_SIN"; case DML_OPERATOR_ELEMENT_WISE_SQRT: return "DML_OPERATOR_ELEMENT_WISE_SQRT"; + case DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE: return "DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE"; + case DML_OPERATOR_ELEMENT_WISE_ATAN_YX: return "DML_OPERATOR_ELEMENT_WISE_ATAN_YX"; case DML_OPERATOR_ELEMENT_WISE_SUBTRACT: return "DML_OPERATOR_ELEMENT_WISE_SUBTRACT"; case DML_OPERATOR_ELEMENT_WISE_TAN: return "DML_OPERATOR_ELEMENT_WISE_TAN"; case DML_OPERATOR_ELEMENT_WISE_THRESHOLD: return "DML_OPERATOR_ELEMENT_WISE_THRESHOLD"; @@ -2223,6 +2452,7 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value) case DML_OPERATOR_SPLIT: return "DML_OPERATOR_SPLIT"; case DML_OPERATOR_JOIN: return "DML_OPERATOR_JOIN"; case DML_OPERATOR_PADDING: return "DML_OPERATOR_PADDING"; + case DML_OPERATOR_PADDING1: return "DML_OPERATOR_PADDING1"; case DML_OPERATOR_VALUE_SCALE_2D: return "DML_OPERATOR_VALUE_SCALE_2D"; case DML_OPERATOR_UPSAMPLE_2D: return "DML_OPERATOR_UPSAMPLE_2D"; case DML_OPERATOR_GATHER: return "DML_OPERATOR_GATHER"; @@ -2231,14 +2461,18 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value) case DML_OPERATOR_TILE: return "DML_OPERATOR_TILE"; case DML_OPERATOR_TOP_K: return "DML_OPERATOR_TOP_K"; case DML_OPERATOR_BATCH_NORMALIZATION: return "DML_OPERATOR_BATCH_NORMALIZATION"; + case DML_OPERATOR_BATCH_NORMALIZATION_GRAD: return "DML_OPERATOR_BATCH_NORMALIZATION_GRAD"; + case DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD: return "DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD"; case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION: return "DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION"; case DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION: return "DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION"; + case DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD: return "DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD"; case DML_OPERATOR_LP_NORMALIZATION: return "DML_OPERATOR_LP_NORMALIZATION"; case DML_OPERATOR_RNN: return "DML_OPERATOR_RNN"; case DML_OPERATOR_LSTM: return "DML_OPERATOR_LSTM"; case DML_OPERATOR_GRU: return "DML_OPERATOR_GRU"; case DML_OPERATOR_ELEMENT_WISE_SIGN: return "DML_OPERATOR_ELEMENT_WISE_SIGN"; case DML_OPERATOR_ELEMENT_WISE_IS_NAN: return "DML_OPERATOR_ELEMENT_WISE_IS_NAN"; + case DML_OPERATOR_ELEMENT_WISE_NEGATE: return "DML_OPERATOR_ELEMENT_WISE_NEGATE"; case DML_OPERATOR_ELEMENT_WISE_ERF: return "DML_OPERATOR_ELEMENT_WISE_ERF"; case DML_OPERATOR_ELEMENT_WISE_SINH: return "DML_OPERATOR_ELEMENT_WISE_SINH"; case DML_OPERATOR_ELEMENT_WISE_COSH: return "DML_OPERATOR_ELEMENT_WISE_COSH"; @@ -2262,6 +2496,7 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value) case DML_OPERATOR_FILL_VALUE_CONSTANT: return "DML_OPERATOR_FILL_VALUE_CONSTANT"; case DML_OPERATOR_FILL_VALUE_SEQUENCE: return "DML_OPERATOR_FILL_VALUE_SEQUENCE"; case DML_OPERATOR_CUMULATIVE_SUMMATION: return "DML_OPERATOR_CUMULATIVE_SUMMATION"; + case DML_OPERATOR_CUMULATIVE_PRODUCT: return "DML_OPERATOR_CUMULATIVE_PRODUCT"; case DML_OPERATOR_REVERSE_SUBSEQUENCES: return "DML_OPERATOR_REVERSE_SUBSEQUENCES"; case DML_OPERATOR_GATHER_ELEMENTS: return "DML_OPERATOR_GATHER_ELEMENTS"; case DML_OPERATOR_GATHER_ND: return "DML_OPERATOR_GATHER_ND"; @@ -2291,7 +2526,12 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value) case DML_OPERATOR_SLICE_GRAD: return "DML_OPERATOR_SLICE_GRAD"; case DML_OPERATOR_ADAM_OPTIMIZER: return "DML_OPERATOR_ADAM_OPTIMIZER"; case DML_OPERATOR_ROI_ALIGN: return "DML_OPERATOR_ROI_ALIGN"; + case DML_OPERATOR_ROI_ALIGN1: return "DML_OPERATOR_ROI_ALIGN1"; case DML_OPERATOR_GATHER_ND1: return "DML_OPERATOR_GATHER_ND1"; + case DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR: return "DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR"; + case DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD: return "DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD"; + case DML_OPERATOR_ROI_ALIGN_GRAD: return "DML_OPERATOR_ROI_ALIGN_GRAD"; + case DML_OPERATOR_BATCH_NORMALIZATION_TRAINING: return "DML_OPERATOR_BATCH_NORMALIZATION_TRAINING"; default: assert(false); return ""; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h index ec3d24070b..32d1eda07e 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h @@ -28,6 +28,7 @@ enum DML_SCHEMA_FIELD_TYPE DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, DML_SCHEMA_FIELD_TYPE_SIZE_2D, DML_SCHEMA_FIELD_TYPE_SCALAR_UNION, + DML_SCHEMA_FIELD_TYPE_BOOL, }; enum DML_SCHEMA_OPERATOR_SUPPORT_FLAGS @@ -170,6 +171,56 @@ constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_CLIP_OPERATOR_SCHEMA { DML_ELEMENT_WISE_CLIP_OPERATOR_SCHEMA_FIELDS, }; +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_CLIP1_OPERATOR_SCHEMA_FIELDS[6] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "MinMaxDataType", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALAR_UNION, "Min", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALAR_UNION, "Max", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_CLIP1_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_CLIP1", + DML_OPERATOR_ELEMENT_WISE_CLIP1, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 6, + DML_ELEMENT_WISE_CLIP1_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_SCHEMA_FIELDS[5] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputGradientTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputGradientTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Min", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Max", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD", + DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 5, + DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_SCHEMA_FIELDS[6] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputGradientTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputGradientTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "MinMaxDataType", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALAR_UNION, "Min", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALAR_UNION, "Max", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1", + DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 6, + DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_SCHEMA_FIELDS, +}; + constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_COS_OPERATOR_SCHEMA_FIELDS[3] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, @@ -493,6 +544,34 @@ constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_SQRT_OPERATOR_SCHEMA { DML_ELEMENT_WISE_SQRT_OPERATOR_SCHEMA_FIELDS, }; +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_DIFFERENCE_SQUARE_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_DIFFERENCE_SQUARE_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE", + DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ELEMENT_WISE_DIFFERENCE_SQUARE_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_ATAN_YX_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_ATAN_YX_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_ATAN_YX", + DML_OPERATOR_ELEMENT_WISE_ATAN_YX, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 3, + DML_ELEMENT_WISE_ATAN_YX_OPERATOR_SCHEMA_FIELDS, +}; + constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_SUBTRACT_OPERATOR_SCHEMA_FIELDS[3] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, @@ -828,6 +907,25 @@ constexpr DML_OPERATOR_SCHEMA DML_PADDING_OPERATOR_SCHEMA { DML_PADDING_OPERATOR_SCHEMA_FIELDS, }; +constexpr DML_SCHEMA_FIELD DML_PADDING1_OPERATOR_SCHEMA_FIELDS[8] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "PaddingMode", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "PaddingValueDataType", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALAR_UNION, "PaddingValue", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_PADDING1_OPERATOR_SCHEMA { + "DML_OPERATOR_PADDING1", + DML_OPERATOR_PADDING1, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 8, + DML_PADDING1_OPERATOR_SCHEMA_FIELDS, +}; + constexpr DML_SCHEMA_FIELD DML_VALUE_SCALE_2D_OPERATOR_SCHEMA_FIELDS[5] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, @@ -954,6 +1052,46 @@ constexpr DML_OPERATOR_SCHEMA DML_BATCH_NORMALIZATION_OPERATOR_SCHEMA { DML_BATCH_NORMALIZATION_OPERATOR_SCHEMA_FIELDS, }; +constexpr DML_SCHEMA_FIELD DML_BATCH_NORMALIZATION_GRAD_OPERATOR_SCHEMA_FIELDS[9] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputGradientTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "MeanTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "VarianceTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputGradientTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleGradientTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputBiasGradientTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Epsilon", true }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_BATCH_NORMALIZATION_GRAD_OPERATOR_SCHEMA { + "DML_OPERATOR_BATCH_NORMALIZATION_GRAD", + DML_OPERATOR_BATCH_NORMALIZATION_GRAD, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 9, + DML_BATCH_NORMALIZATION_GRAD_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA_FIELDS[9] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputGradientTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "MeanTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "VarianceTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputGradientTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleGradientTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputBiasGradientTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Epsilon", true }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA { + "DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD", + DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 9, + DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA_FIELDS, +}; + constexpr DML_SCHEMA_FIELD DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_SCHEMA_FIELDS[8] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ScaleTensor", true }, @@ -991,6 +1129,25 @@ constexpr DML_OPERATOR_SCHEMA DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_SCHEMA { DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_SCHEMA_FIELDS, }; +constexpr DML_SCHEMA_FIELD DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_SCHEMA_FIELDS[8] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputGradientTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputGradientTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_BOOL, "CrossChannel", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "LocalSize", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Alpha", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Beta", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Bias", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_SCHEMA { + "DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD", + DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 8, + DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_SCHEMA_FIELDS, +}; + constexpr DML_SCHEMA_FIELD DML_LP_NORMALIZATION_OPERATOR_SCHEMA_FIELDS[5] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, @@ -1106,6 +1263,19 @@ constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_IS_NAN_OPERATOR_SCHEMA { DML_ELEMENT_WISE_IS_NAN_OPERATOR_SCHEMA_FIELDS, }; +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_NEGATE_OPERATOR_SCHEMA_FIELDS[2] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_NEGATE_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_NEGATE", + DML_OPERATOR_ELEMENT_WISE_NEGATE, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 2, + DML_ELEMENT_WISE_NEGATE_OPERATOR_SCHEMA_FIELDS, +}; + constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_ERF_OPERATOR_SCHEMA_FIELDS[3] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, @@ -1426,8 +1596,8 @@ constexpr DML_SCHEMA_FIELD DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA_FIELDS[5] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Axis", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "HasExclusiveSum", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "AxisDirection", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "HasExclusiveSum", false }, }; constexpr DML_OPERATOR_SCHEMA DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA { @@ -1438,6 +1608,22 @@ constexpr DML_OPERATOR_SCHEMA DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA { DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA_FIELDS, }; +constexpr DML_SCHEMA_FIELD DML_CUMULATIVE_PRODUCT_OPERATOR_SCHEMA_FIELDS[5] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Axis", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "AxisDirection", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "HasExclusiveProduct", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_CUMULATIVE_PRODUCT_OPERATOR_SCHEMA { + "DML_OPERATOR_CUMULATIVE_PRODUCT", + DML_OPERATOR_CUMULATIVE_PRODUCT, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + 5, + DML_CUMULATIVE_PRODUCT_OPERATOR_SCHEMA_FIELDS, +}; + constexpr DML_SCHEMA_FIELD DML_REVERSE_SUBSEQUENCES_OPERATOR_SCHEMA_FIELDS[4] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "SequenceLengthsTensor", false }, @@ -1448,7 +1634,7 @@ constexpr DML_SCHEMA_FIELD DML_REVERSE_SUBSEQUENCES_OPERATOR_SCHEMA_FIELDS[4] { constexpr DML_OPERATOR_SCHEMA DML_REVERSE_SUBSEQUENCES_OPERATOR_SCHEMA { "DML_OPERATOR_REVERSE_SUBSEQUENCES", DML_OPERATOR_REVERSE_SUBSEQUENCES, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, 4, DML_REVERSE_SUBSEQUENCES_OPERATOR_SCHEMA_FIELDS, }; @@ -1809,17 +1995,23 @@ constexpr DML_OPERATOR_SCHEMA DML_AVERAGE_POOLING_GRAD_OPERATOR_SCHEMA { DML_AVERAGE_POOLING_GRAD_OPERATOR_SCHEMA_FIELDS, }; -constexpr DML_SCHEMA_FIELD DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA_FIELDS[3] { +constexpr DML_SCHEMA_FIELD DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA_FIELDS[9] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputGradientTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputGradientTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Dilations", false }, }; constexpr DML_OPERATOR_SCHEMA DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA { "DML_OPERATOR_MAX_POOLING_GRAD", DML_OPERATOR_MAX_POOLING_GRAD, DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 3, + 9, DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA_FIELDS, }; @@ -1932,6 +2124,31 @@ constexpr DML_OPERATOR_SCHEMA DML_ROI_ALIGN_OPERATOR_SCHEMA { DML_ROI_ALIGN_OPERATOR_SCHEMA_FIELDS, }; +constexpr DML_SCHEMA_FIELD DML_ROI_ALIGN1_OPERATOR_SCHEMA_FIELDS[14] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ROITensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BatchIndicesTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "ReductionFunction", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "InterpolationMode", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "SpatialScaleX", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "SpatialScaleY", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "InputPixelOffset", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "OutputPixelOffset", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "OutOfBoundsInputValue", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "MinimumSamplesPerOutput", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "MaximumSamplesPerOutput", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "AlignRegionsToCorners", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ROI_ALIGN1_OPERATOR_SCHEMA { + "DML_OPERATOR_ROI_ALIGN1", + DML_OPERATOR_ROI_ALIGN1, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 14, + DML_ROI_ALIGN1_OPERATOR_SCHEMA_FIELDS, +}; + constexpr DML_SCHEMA_FIELD DML_GATHER_ND1_OPERATOR_SCHEMA_FIELDS[6] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "IndicesTensor", false }, @@ -1949,6 +2166,87 @@ constexpr DML_OPERATOR_SCHEMA DML_GATHER_ND1_OPERATOR_SCHEMA { DML_GATHER_ND1_OPERATOR_SCHEMA_FIELDS, }; +constexpr DML_SCHEMA_FIELD DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA_FIELDS[4] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputZeroPointTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA { + "DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR", + DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 4, + DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA_FIELDS[9] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AZeroPointTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BZeroPointTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputZeroPointTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA { + "DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD", + DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 9, + DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA_FIELDS[15] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputGradientTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ROITensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BatchIndicesTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputGradientTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputROIGradientTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "ReductionFunction", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "InterpolationMode", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "SpatialScaleX", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "SpatialScaleY", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "InputPixelOffset", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "OutputPixelOffset", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "MinimumSamplesPerOutput", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "MaximumSamplesPerOutput", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "AlignRegionsToCorners", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA { + "DML_OPERATOR_ROI_ALIGN_GRAD", + DML_OPERATOR_ROI_ALIGN_GRAD, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 15, + DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA_FIELDS[9] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "FusedAddTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputMeanTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputVarianceTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Epsilon", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC, "FusedActivation", true }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA { + "DML_OPERATOR_BATCH_NORMALIZATION_TRAINING", + DML_OPERATOR_BATCH_NORMALIZATION_TRAINING, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 9, + DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA_FIELDS, +}; + constexpr DML_SCHEMA_FIELD DML_ACTIVATION_ELU_OPERATOR_SCHEMA_FIELDS[3] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h index f9a12b92c5..227c6aa46c 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h @@ -73,6 +73,38 @@ inline std::vector GetFields(const DML_ELEMENT_WISE_CLIP_OPERATOR OperatorField(&DML_ELEMENT_WISE_CLIP_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.Max))), }; } +inline std::vector GetFields(const DML_ELEMENT_WISE_CLIP1_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_CLIP1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ELEMENT_WISE_CLIP1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ELEMENT_WISE_CLIP1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), + OperatorField(&DML_ELEMENT_WISE_CLIP1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.MinMaxDataType))), + OperatorField(&DML_ELEMENT_WISE_CLIP1_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.Min))), + OperatorField(&DML_ELEMENT_WISE_CLIP1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.Max))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.InputGradientTensor))), + OperatorField(&DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputGradientTensor))), + OperatorField(&DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Min))), + OperatorField(&DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.Max))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.InputGradientTensor))), + OperatorField(&DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputGradientTensor))), + OperatorField(&DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.MinMaxDataType))), + OperatorField(&DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.Min))), + OperatorField(&DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.Max))), + }; +} inline std::vector GetFields(const DML_ELEMENT_WISE_COS_OPERATOR_DESC& desc) { return { @@ -258,6 +290,22 @@ inline std::vector GetFields(const DML_ELEMENT_WISE_SQRT_OPERATOR OperatorField(&DML_ELEMENT_WISE_SQRT_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ScaleBias))), }; } +inline std::vector GetFields(const DML_ELEMENT_WISE_DIFFERENCE_SQUARE_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_DIFFERENCE_SQUARE_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), + OperatorField(&DML_ELEMENT_WISE_DIFFERENCE_SQUARE_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.BTensor))), + OperatorField(&DML_ELEMENT_WISE_DIFFERENCE_SQUARE_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_ATAN_YX_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_ATAN_YX_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), + OperatorField(&DML_ELEMENT_WISE_ATAN_YX_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.BTensor))), + OperatorField(&DML_ELEMENT_WISE_ATAN_YX_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} inline std::vector GetFields(const DML_ELEMENT_WISE_SUBTRACT_OPERATOR_DESC& desc) { return { @@ -473,6 +521,19 @@ inline std::vector GetFields(const DML_PADDING_OPERATOR_DESC& des OperatorField(&DML_PADDING_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.EndPadding), desc.DimensionCount)), }; } +inline std::vector GetFields(const DML_PADDING1_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_PADDING1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_PADDING1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_PADDING1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.PaddingMode))), + OperatorField(&DML_PADDING1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.PaddingValueDataType))), + OperatorField(&DML_PADDING1_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.PaddingValue))), + OperatorField(&DML_PADDING1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.DimensionCount))), + OperatorField(&DML_PADDING1_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.StartPadding), desc.DimensionCount)), + OperatorField(&DML_PADDING1_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.EndPadding), desc.DimensionCount)), + }; +} inline std::vector GetFields(const DML_VALUE_SCALE_2D_OPERATOR_DESC& desc) { return { @@ -551,6 +612,34 @@ inline std::vector GetFields(const DML_BATCH_NORMALIZATION_OPERAT OperatorField(&DML_BATCH_NORMALIZATION_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.FusedActivation))), }; } +inline std::vector GetFields(const DML_BATCH_NORMALIZATION_GRAD_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_BATCH_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_BATCH_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.InputGradientTensor))), + OperatorField(&DML_BATCH_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.MeanTensor))), + OperatorField(&DML_BATCH_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.VarianceTensor))), + OperatorField(&DML_BATCH_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.ScaleTensor))), + OperatorField(&DML_BATCH_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.OutputGradientTensor))), + OperatorField(&DML_BATCH_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.OutputScaleGradientTensor))), + OperatorField(&DML_BATCH_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.OutputBiasGradientTensor))), + OperatorField(&DML_BATCH_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.Epsilon))), + }; +} +inline std::vector GetFields(const DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.InputGradientTensor))), + OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.MeanTensor))), + OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.VarianceTensor))), + OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.ScaleTensor))), + OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.OutputGradientTensor))), + OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.OutputScaleGradientTensor))), + OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.OutputBiasGradientTensor))), + OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.Epsilon))), + }; +} inline std::vector GetFields(const DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_DESC& desc) { return { @@ -576,6 +665,19 @@ inline std::vector GetFields(const DML_LOCAL_RESPONSE_NORMALIZATI OperatorField(&DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.Bias))), }; } +inline std::vector GetFields(const DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.InputGradientTensor))), + OperatorField(&DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputGradientTensor))), + OperatorField(&DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.CrossChannel))), + OperatorField(&DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.LocalSize))), + OperatorField(&DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.Alpha))), + OperatorField(&DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.Beta))), + OperatorField(&DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.Bias))), + }; +} inline std::vector GetFields(const DML_LP_NORMALIZATION_OPERATOR_DESC& desc) { return { @@ -655,6 +757,13 @@ inline std::vector GetFields(const DML_ELEMENT_WISE_IS_NAN_OPERAT OperatorField(&DML_ELEMENT_WISE_IS_NAN_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), }; } +inline std::vector GetFields(const DML_ELEMENT_WISE_NEGATE_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_NEGATE_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ELEMENT_WISE_NEGATE_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} inline std::vector GetFields(const DML_ELEMENT_WISE_ERF_OPERATOR_DESC& desc) { return { @@ -845,8 +954,18 @@ inline std::vector GetFields(const DML_CUMULATIVE_SUMMATION_OPERA OperatorField(&DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), OperatorField(&DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), OperatorField(&DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.Axis))), - OperatorField(&DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.HasExclusiveSum))), - OperatorField(&DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.AxisDirection))), + OperatorField(&DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.AxisDirection))), + OperatorField(&DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.HasExclusiveSum))), + }; +} +inline std::vector GetFields(const DML_CUMULATIVE_PRODUCT_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_CUMULATIVE_PRODUCT_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_CUMULATIVE_PRODUCT_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_CUMULATIVE_PRODUCT_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.Axis))), + OperatorField(&DML_CUMULATIVE_PRODUCT_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.AxisDirection))), + OperatorField(&DML_CUMULATIVE_PRODUCT_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.HasExclusiveProduct))), }; } inline std::vector GetFields(const DML_REVERSE_SUBSEQUENCES_OPERATOR_DESC& desc) @@ -1094,6 +1213,12 @@ inline std::vector GetFields(const DML_MAX_POOLING_GRAD_OPERATOR_ OperatorField(&DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), OperatorField(&DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.InputGradientTensor))), OperatorField(&DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputGradientTensor))), + OperatorField(&DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.DimensionCount))), + OperatorField(&DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.Strides), desc.DimensionCount)), + OperatorField(&DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.WindowSize), desc.DimensionCount)), + OperatorField(&DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.StartPadding), desc.DimensionCount)), + OperatorField(&DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.EndPadding), desc.DimensionCount)), + OperatorField(&DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.Dilations), desc.DimensionCount)), }; } inline std::vector GetFields(const DML_RANDOM_GENERATOR_OPERATOR_DESC& desc) @@ -1169,6 +1294,25 @@ inline std::vector GetFields(const DML_ROI_ALIGN_OPERATOR_DESC& d OperatorField(&DML_ROI_ALIGN_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast(desc.MaximumSamplesPerOutput))), }; } +inline std::vector GetFields(const DML_ROI_ALIGN1_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ROI_ALIGN1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ROI_ALIGN1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.ROITensor))), + OperatorField(&DML_ROI_ALIGN1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.BatchIndicesTensor))), + OperatorField(&DML_ROI_ALIGN1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ROI_ALIGN1_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.ReductionFunction))), + OperatorField(&DML_ROI_ALIGN1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.InterpolationMode))), + OperatorField(&DML_ROI_ALIGN1_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.SpatialScaleX))), + OperatorField(&DML_ROI_ALIGN1_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.SpatialScaleY))), + OperatorField(&DML_ROI_ALIGN1_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.InputPixelOffset))), + OperatorField(&DML_ROI_ALIGN1_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast(desc.OutputPixelOffset))), + OperatorField(&DML_ROI_ALIGN1_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast(desc.OutOfBoundsInputValue))), + OperatorField(&DML_ROI_ALIGN1_OPERATOR_SCHEMA.Fields[11], ToOperatorFieldType(static_cast(desc.MinimumSamplesPerOutput))), + OperatorField(&DML_ROI_ALIGN1_OPERATOR_SCHEMA.Fields[12], ToOperatorFieldType(static_cast(desc.MaximumSamplesPerOutput))), + OperatorField(&DML_ROI_ALIGN1_OPERATOR_SCHEMA.Fields[13], ToOperatorFieldType(static_cast(desc.AlignRegionsToCorners))), + }; +} inline std::vector GetFields(const DML_GATHER_ND1_OPERATOR_DESC& desc) { return { @@ -1180,6 +1324,63 @@ inline std::vector GetFields(const DML_GATHER_ND1_OPERATOR_DESC& OperatorField(&DML_GATHER_ND1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.BatchDimensionCount))), }; } +inline std::vector GetFields(const DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.OutputScaleTensor))), + OperatorField(&DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.OutputZeroPointTensor))), + }; +} +inline std::vector GetFields(const DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), + OperatorField(&DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.AScaleTensor))), + OperatorField(&DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.AZeroPointTensor))), + OperatorField(&DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.BTensor))), + OperatorField(&DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.BScaleTensor))), + OperatorField(&DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.BZeroPointTensor))), + OperatorField(&DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.OutputScaleTensor))), + OperatorField(&DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.OutputZeroPointTensor))), + OperatorField(&DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} +inline std::vector GetFields(const DML_ROI_ALIGN_GRAD_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.InputGradientTensor))), + OperatorField(&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.ROITensor))), + OperatorField(&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.BatchIndicesTensor))), + OperatorField(&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.OutputGradientTensor))), + OperatorField(&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.OutputROIGradientTensor))), + OperatorField(&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.ReductionFunction))), + OperatorField(&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.InterpolationMode))), + OperatorField(&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.SpatialScaleX))), + OperatorField(&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast(desc.SpatialScaleY))), + OperatorField(&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast(desc.InputPixelOffset))), + OperatorField(&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA.Fields[11], ToOperatorFieldType(static_cast(desc.OutputPixelOffset))), + OperatorField(&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA.Fields[12], ToOperatorFieldType(static_cast(desc.MinimumSamplesPerOutput))), + OperatorField(&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA.Fields[13], ToOperatorFieldType(static_cast(desc.MaximumSamplesPerOutput))), + OperatorField(&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA.Fields[14], ToOperatorFieldType(static_cast(desc.AlignRegionsToCorners))), + }; +} +inline std::vector GetFields(const DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.ScaleTensor))), + OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.BiasTensor))), + OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.FusedAddTensor))), + OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.OutputMeanTensor))), + OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.OutputVarianceTensor))), + OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.Epsilon))), + OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.FusedActivation))), + }; +} inline std::vector GetFields(const DML_ACTIVATION_ELU_OPERATOR_DESC& desc) { return { @@ -1351,6 +1552,9 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType) case DML_OPERATOR_ELEMENT_WISE_ATAN: return DML_ELEMENT_WISE_ATAN_OPERATOR_SCHEMA; case DML_OPERATOR_ELEMENT_WISE_CEIL: return DML_ELEMENT_WISE_CEIL_OPERATOR_SCHEMA; case DML_OPERATOR_ELEMENT_WISE_CLIP: return DML_ELEMENT_WISE_CLIP_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_CLIP1: return DML_ELEMENT_WISE_CLIP1_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD: return DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1: return DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_SCHEMA; case DML_OPERATOR_ELEMENT_WISE_COS: return DML_ELEMENT_WISE_COS_OPERATOR_SCHEMA; case DML_OPERATOR_ELEMENT_WISE_DIVIDE: return DML_ELEMENT_WISE_DIVIDE_OPERATOR_SCHEMA; case DML_OPERATOR_ELEMENT_WISE_EXP: return DML_ELEMENT_WISE_EXP_OPERATOR_SCHEMA; @@ -1374,6 +1578,8 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType) case DML_OPERATOR_ELEMENT_WISE_RECIP: return DML_ELEMENT_WISE_RECIP_OPERATOR_SCHEMA; case DML_OPERATOR_ELEMENT_WISE_SIN: return DML_ELEMENT_WISE_SIN_OPERATOR_SCHEMA; case DML_OPERATOR_ELEMENT_WISE_SQRT: return DML_ELEMENT_WISE_SQRT_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE: return DML_ELEMENT_WISE_DIFFERENCE_SQUARE_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_ATAN_YX: return DML_ELEMENT_WISE_ATAN_YX_OPERATOR_SCHEMA; case DML_OPERATOR_ELEMENT_WISE_SUBTRACT: return DML_ELEMENT_WISE_SUBTRACT_OPERATOR_SCHEMA; case DML_OPERATOR_ELEMENT_WISE_TAN: return DML_ELEMENT_WISE_TAN_OPERATOR_SCHEMA; case DML_OPERATOR_ELEMENT_WISE_THRESHOLD: return DML_ELEMENT_WISE_THRESHOLD_OPERATOR_SCHEMA; @@ -1394,6 +1600,7 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType) case DML_OPERATOR_SPLIT: return DML_SPLIT_OPERATOR_SCHEMA; case DML_OPERATOR_JOIN: return DML_JOIN_OPERATOR_SCHEMA; case DML_OPERATOR_PADDING: return DML_PADDING_OPERATOR_SCHEMA; + case DML_OPERATOR_PADDING1: return DML_PADDING1_OPERATOR_SCHEMA; case DML_OPERATOR_VALUE_SCALE_2D: return DML_VALUE_SCALE_2D_OPERATOR_SCHEMA; case DML_OPERATOR_UPSAMPLE_2D: return DML_UPSAMPLE_2D_OPERATOR_SCHEMA; case DML_OPERATOR_GATHER: return DML_GATHER_OPERATOR_SCHEMA; @@ -1402,14 +1609,18 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType) case DML_OPERATOR_TILE: return DML_TILE_OPERATOR_SCHEMA; case DML_OPERATOR_TOP_K: return DML_TOP_K_OPERATOR_SCHEMA; case DML_OPERATOR_BATCH_NORMALIZATION: return DML_BATCH_NORMALIZATION_OPERATOR_SCHEMA; + case DML_OPERATOR_BATCH_NORMALIZATION_GRAD: return DML_BATCH_NORMALIZATION_GRAD_OPERATOR_SCHEMA; + case DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD: return DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA; case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION: return DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_SCHEMA; case DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION: return DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_SCHEMA; + case DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD: return DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_SCHEMA; case DML_OPERATOR_LP_NORMALIZATION: return DML_LP_NORMALIZATION_OPERATOR_SCHEMA; case DML_OPERATOR_RNN: return DML_RNN_OPERATOR_SCHEMA; case DML_OPERATOR_LSTM: return DML_LSTM_OPERATOR_SCHEMA; case DML_OPERATOR_GRU: return DML_GRU_OPERATOR_SCHEMA; case DML_OPERATOR_ELEMENT_WISE_SIGN: return DML_ELEMENT_WISE_SIGN_OPERATOR_SCHEMA; case DML_OPERATOR_ELEMENT_WISE_IS_NAN: return DML_ELEMENT_WISE_IS_NAN_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_NEGATE: return DML_ELEMENT_WISE_NEGATE_OPERATOR_SCHEMA; case DML_OPERATOR_ELEMENT_WISE_ERF: return DML_ELEMENT_WISE_ERF_OPERATOR_SCHEMA; case DML_OPERATOR_ELEMENT_WISE_SINH: return DML_ELEMENT_WISE_SINH_OPERATOR_SCHEMA; case DML_OPERATOR_ELEMENT_WISE_COSH: return DML_ELEMENT_WISE_COSH_OPERATOR_SCHEMA; @@ -1433,6 +1644,7 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType) case DML_OPERATOR_FILL_VALUE_CONSTANT: return DML_FILL_VALUE_CONSTANT_OPERATOR_SCHEMA; case DML_OPERATOR_FILL_VALUE_SEQUENCE: return DML_FILL_VALUE_SEQUENCE_OPERATOR_SCHEMA; case DML_OPERATOR_CUMULATIVE_SUMMATION: return DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA; + case DML_OPERATOR_CUMULATIVE_PRODUCT: return DML_CUMULATIVE_PRODUCT_OPERATOR_SCHEMA; case DML_OPERATOR_REVERSE_SUBSEQUENCES: return DML_REVERSE_SUBSEQUENCES_OPERATOR_SCHEMA; case DML_OPERATOR_GATHER_ELEMENTS: return DML_GATHER_ELEMENTS_OPERATOR_SCHEMA; case DML_OPERATOR_GATHER_ND: return DML_GATHER_ND_OPERATOR_SCHEMA; @@ -1462,7 +1674,12 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType) case DML_OPERATOR_SLICE_GRAD: return DML_SLICE_GRAD_OPERATOR_SCHEMA; case DML_OPERATOR_ADAM_OPTIMIZER: return DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA; case DML_OPERATOR_ROI_ALIGN: return DML_ROI_ALIGN_OPERATOR_SCHEMA; + case DML_OPERATOR_ROI_ALIGN1: return DML_ROI_ALIGN1_OPERATOR_SCHEMA; case DML_OPERATOR_GATHER_ND1: return DML_GATHER_ND1_OPERATOR_SCHEMA; + case DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR: return DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA; + case DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD: return DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA; + case DML_OPERATOR_ROI_ALIGN_GRAD: return DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA; + case DML_OPERATOR_BATCH_NORMALIZATION_TRAINING: return DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_ELU: return DML_ACTIVATION_ELU_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_CELU: return DML_ACTIVATION_CELU_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_HARDMAX: return DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA; @@ -1528,6 +1745,18 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc) return AbstractOperatorDesc( &DML_ELEMENT_WISE_CLIP_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_CLIP1: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_CLIP1_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); case DML_OPERATOR_ELEMENT_WISE_COS: return AbstractOperatorDesc( &DML_ELEMENT_WISE_COS_OPERATOR_SCHEMA, @@ -1620,6 +1849,14 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc) return AbstractOperatorDesc( &DML_ELEMENT_WISE_SQRT_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_DIFFERENCE_SQUARE_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_ATAN_YX: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_ATAN_YX_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); case DML_OPERATOR_ELEMENT_WISE_SUBTRACT: return AbstractOperatorDesc( &DML_ELEMENT_WISE_SUBTRACT_OPERATOR_SCHEMA, @@ -1700,6 +1937,10 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc) return AbstractOperatorDesc( &DML_PADDING_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_PADDING1: + return AbstractOperatorDesc( + &DML_PADDING1_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); case DML_OPERATOR_VALUE_SCALE_2D: return AbstractOperatorDesc( &DML_VALUE_SCALE_2D_OPERATOR_SCHEMA, @@ -1732,6 +1973,14 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc) return AbstractOperatorDesc( &DML_BATCH_NORMALIZATION_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_BATCH_NORMALIZATION_GRAD: + return AbstractOperatorDesc( + &DML_BATCH_NORMALIZATION_GRAD_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD: + return AbstractOperatorDesc( + &DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION: return AbstractOperatorDesc( &DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_SCHEMA, @@ -1740,6 +1989,10 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc) return AbstractOperatorDesc( &DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD: + return AbstractOperatorDesc( + &DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); case DML_OPERATOR_LP_NORMALIZATION: return AbstractOperatorDesc( &DML_LP_NORMALIZATION_OPERATOR_SCHEMA, @@ -1764,6 +2017,10 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc) return AbstractOperatorDesc( &DML_ELEMENT_WISE_IS_NAN_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_NEGATE: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_NEGATE_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); case DML_OPERATOR_ELEMENT_WISE_ERF: return AbstractOperatorDesc( &DML_ELEMENT_WISE_ERF_OPERATOR_SCHEMA, @@ -1856,6 +2113,10 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc) return AbstractOperatorDesc( &DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_CUMULATIVE_PRODUCT: + return AbstractOperatorDesc( + &DML_CUMULATIVE_PRODUCT_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); case DML_OPERATOR_REVERSE_SUBSEQUENCES: return AbstractOperatorDesc( &DML_REVERSE_SUBSEQUENCES_OPERATOR_SCHEMA, @@ -1972,10 +2233,30 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc) return AbstractOperatorDesc( &DML_ROI_ALIGN_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ROI_ALIGN1: + return AbstractOperatorDesc( + &DML_ROI_ALIGN1_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); case DML_OPERATOR_GATHER_ND1: return AbstractOperatorDesc( &DML_GATHER_ND1_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR: + return AbstractOperatorDesc( + &DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD: + return AbstractOperatorDesc( + &DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ROI_ALIGN_GRAD: + return AbstractOperatorDesc( + &DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_BATCH_NORMALIZATION_TRAINING: + return AbstractOperatorDesc( + &DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); case DML_OPERATOR_ACTIVATION_ELU: return AbstractOperatorDesc( &DML_ACTIVATION_ELU_OPERATOR_SCHEMA, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp index ed3df5c784..3cb05beba5 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp @@ -139,29 +139,6 @@ namespace Dml::GraphDescBuilder &graphNodeInfo ); - // Determine the number of valid inputs and outputs of this node. The graph currently supports opererators - // with unused inputs and outputs only at the end of each list. - uint32_t validOpInputCount = 0; - uint32_t validOpOutputCount = 0; - - for (uint32_t i = 0; i < graphNodeInfo.kernelInputIndices.size(); ++i) - { - if (graphNodeInfo.kernelInputIndices[i] != std::numeric_limits::max()) - { - assert(i - validOpInputCount == 0); - ++validOpInputCount; - } - } - - for (uint32_t i = 0; i < graphNodeInfo.kernelOutputIndices.size(); ++i) - { - if (graphNodeInfo.kernelOutputIndices[i] != std::numeric_limits::max()) - { - assert(i - validOpOutputCount == 0); - ++validOpOutputCount; - } - } - uint32_t nodeIndex = gsl::narrow_cast(graphNodes.size()); AbstractOperatorDesc opDesc = *graphNodeInfo.desc; // Make a copy @@ -171,8 +148,13 @@ namespace Dml::GraphDescBuilder std::vector outputTensorDescs = opDesc.GetOutputTensors(); // Set connections of the new node - for (uint32_t inputIndex = 0; inputIndex < validOpInputCount; ++inputIndex) + for (uint32_t inputIndex = 0; inputIndex < graphNodeInfo.kernelInputIndices.size(); ++inputIndex) { + if (graphNodeInfo.kernelInputIndices[inputIndex] == std::numeric_limits::max()) + { + continue; + } + uint32_t kernelInputIndex = graphNodeInfo.kernelInputIndices[inputIndex]; const onnxruntime::NodeArg* arg = node.InputDefs()[kernelInputIndex]; @@ -224,8 +206,13 @@ namespace Dml::GraphDescBuilder // Store the new node for lookup when downstream nodes consume it. - for (uint32_t outputIndex = 0; outputIndex < validOpOutputCount; ++outputIndex) + for (uint32_t outputIndex = 0; outputIndex < graphNodeInfo.kernelOutputIndices.size(); ++outputIndex) { + if (graphNodeInfo.kernelOutputIndices[outputIndex] == std::numeric_limits::max()) + { + continue; + } + uint32_t kernelOutputIndex = graphNodeInfo.kernelOutputIndices[outputIndex]; const onnxruntime::NodeArg* arg = node.OutputDefs()[kernelOutputIndex]; if (arg->Exists()) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp index 8283bdf509..cac81590ed 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp @@ -139,7 +139,11 @@ namespace Dml } }; - bool NodeArgSupportedInGraph(const onnxruntime::NodeArg* arg, uint32_t supportedDeviceDataTypeMask) + bool NodeArgSupportedInGraph( + const onnxruntime::NodeArg* arg, + bool supports64BitTensorsViaEmulation, + uint32_t supportedDeviceDataTypeMask + ) { if (arg->Exists()) { @@ -154,8 +158,11 @@ namespace Dml MLOperatorTensorDataType mlDataType = ToMLTensorDataType(static_cast(tensorType.elem_type())); - if (mlDataType == MLOperatorTensorDataType::UInt64 || - mlDataType == MLOperatorTensorDataType::Int64) + // Do not include operators in the graph if tensor types are unsupported, + // except cases that are always supported via emulation. + if ((mlDataType == MLOperatorTensorDataType::UInt64 || + mlDataType == MLOperatorTensorDataType::Int64) && + !supports64BitTensorsViaEmulation) { constexpr uint32_t deviceDataTypeMask64bit = (1 << DML_TENSOR_DATA_TYPE_UINT64) | (1 << DML_TENSOR_DATA_TYPE_INT64); if ((supportedDeviceDataTypeMask & deviceDataTypeMask64bit) != deviceDataTypeMask64bit) @@ -181,6 +188,7 @@ namespace Dml if (!isConstantCpuInput && !NodeArgSupportedInGraph( node.InputDefs()[i], + registration.support64BitTensorsViaEmulation, supportedDeviceDataTypeMask )) { @@ -192,6 +200,7 @@ namespace Dml { if (!NodeArgSupportedInGraph( arg, + registration.support64BitTensorsViaEmulation, supportedDeviceDataTypeMask )) { @@ -234,6 +243,7 @@ namespace Dml ORT_THROW_HR_IF(E_INVALIDARG, allow64BitInputThroughStrides && !nodeNameToPartitionMap); bool prefer64BitTensorsDirectly = false; + bool support64BitTensorsViaEmulation = false; bool supportedWith64BitTensorsVia32BitStrides = false; bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = false; std::vector constantCpuInputs; @@ -244,9 +254,10 @@ namespace Dml // to 32-bit tensors via strides. If the caller passes allow64BitInputThroughStrides = false // in this particular call, then the operator-specific flags do not matter as the caller has // disabled 64-bit support. + prefer64BitTensorsDirectly = regInfo->prefer64BitTensorsDirectly; + support64BitTensorsViaEmulation = regInfo->support64BitTensorsViaEmulation; if (allow64BitInputThroughStrides) { - prefer64BitTensorsDirectly = regInfo->prefer64BitTensorsDirectly; supportedWith64BitTensorsVia32BitStridesFromAnyEp = regInfo->supportedWith64BitTensorsVia32BitStridesFromAnyEp; supportedWith64BitTensorsVia32BitStrides = regInfo->supportedWith64BitTensorsVia32BitStrides | supportedWith64BitTensorsVia32BitStridesFromAnyEp; } @@ -320,31 +331,44 @@ namespace Dml // operator, graph input or initializer, it's not safe to assume the input // can be represented with 32 bits. // + bool isDataTypeSupported = (1 << dmlElementType) & supportedDeviceDataTypeMask; bool is64BitIntType = (dmlElementType == DML_TENSOR_DATA_TYPE_UINT64 || dmlElementType == DML_TENSOR_DATA_TYPE_INT64); - bool needsFallbackTo32Bit = !prefer64BitTensorsDirectly || !((1 << dmlElementType) & supportedDeviceDataTypeMask); - if (is64BitIntType && supportedWith64BitTensorsVia32BitStrides && needsFallbackTo32Bit) + if (is64BitIntType) { - dmlElementType = Remap64bitDmlDataTypeTo32bit(dmlElementType); - - if (isInput && !supportedWith64BitTensorsVia32BitStridesFromAnyEp) + if (support64BitTensorsViaEmulation) { - // Look up the input partition. If it's a graph input or initializer it will be missing - // from the partition map. - const std::string& argName = nodeArg.Name(); + // Consider it supported regardless of hardware support. + isDataTypeSupported = true; + } + else if (prefer64BitTensorsDirectly && isDataTypeSupported) + { + // Operator supports native int64/uint64 tensors. + } + else if (supportedWith64BitTensorsVia32BitStrides || supportedWith64BitTensorsVia32BitStridesFromAnyEp) + { + dmlElementType = Remap64bitDmlDataTypeTo32bit(dmlElementType); + isDataTypeSupported = (1 << dmlElementType) & supportedDeviceDataTypeMask; - // If input tensor's data comes from the output of a different execution provider, - // consider it unsafe to apply fallback to. - auto partitionIter = nodeNameToPartitionMap->find(argName); - if (partitionIter == nodeNameToPartitionMap->end() || !partitionIter->second->IsDmlPartition()) + if (isInput && !supportedWith64BitTensorsVia32BitStridesFromAnyEp) { - nodeContainsSupportedDataTypes = false; - return; + // Look up the input partition. If it's a graph input or initializer it will be missing + // from the partition map. + const std::string& argName = nodeArg.Name(); + + // If input tensor's data comes from the output of a different execution provider, + // consider it unsafe to apply fallback to. + auto partitionIter = nodeNameToPartitionMap->find(argName); + if (partitionIter == nodeNameToPartitionMap->end() || !partitionIter->second->IsDmlPartition()) + { + nodeContainsSupportedDataTypes = false; + return; + } } } } // Reject node if the data type is unsupported by the device. - if (!((1 << dmlElementType) & supportedDeviceDataTypeMask)) + if (!isDataTypeSupported) { nodeContainsSupportedDataTypes = false; return; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorConstantOfShape.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorConstantOfShape.cpp index ddef82ac05..c0098c47e9 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorConstantOfShape.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorConstantOfShape.cpp @@ -22,6 +22,13 @@ public: std::vector> outputIndices = { 0 }; Initialize(kernelCreationContext, inputIndices, outputIndices); + std::vector outputDescs = GetDmlOutputDescs(); + + DML_FILL_VALUE_CONSTANT_OPERATOR_DESC operatorDesc = {}; + operatorDesc.OutputTensor = outputDescs.data(); + operatorDesc.ValueDataType = this->m_outputTensorDescs.front().GetDmlDataType(); + // operatorDesc.Value already zeroed. + // Read the tensor attribute for the output fill pattern. if (kernelCreationContext.HasAttribute(AttrName::Value, MLOperatorAttributeTypeTensor)) { @@ -40,15 +47,14 @@ public: ML_CHECK_VALID_ARGUMENT(elementCount == 1); // Expect exactly one element. const size_t rawDataByteSize = GetByteSizeFromMlDataType(wrappedValueTensor.GetTensorDataType()); const std::byte* rawData = static_cast(valueTensor->GetData()); - valueBytes.assign(rawData, rawData + rawDataByteSize); + + memcpy(operatorDesc.Value.Bytes, rawData, std::min(rawDataByteSize, sizeof(operatorDesc.Value.Bytes))); } // Else valueBytes is empty, and the default fill pattern is 0. - } - void Compute(const MLOperatorKernelContext& kernelContext) override - { - std::vector outputTensors = GetOutputTensorsForExecute(kernelContext); - ORT_THROW_IF_FAILED(m_executionProvider->FillTensorWithPattern(outputTensors.front(), valueBytes)); + + DML_OPERATOR_DESC opDesc = { DML_OPERATOR_FILL_VALUE_CONSTANT, &operatorDesc }; + SetDmlOperatorDesc(opDesc, kernelCreationContext); } private: diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorDynamicQuantizeLinear.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorDynamicQuantizeLinear.cpp index b727817e98..75a19968e8 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorDynamicQuantizeLinear.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorDynamicQuantizeLinear.cpp @@ -7,30 +7,45 @@ namespace Dml { class DmlOperatorDynamicQuantizeLinear : public DmlOperator -{ +{ + enum DmlInputTensors { + IN_A, + }; + + enum DmlOutputTensors { + OUT_Y, + OUT_Y_SCALE, + OUT_Y_ZERO_POINT + }; + public: using Self = DmlOperatorDynamicQuantizeLinear; DmlOperatorDynamicQuantizeLinear(const MLOperatorKernelCreationContext& kernelCreationContext) : DmlOperator(kernelCreationContext) { -#if 0 // TODO:NickFe - https://github.com/onnx/onnx/blob/master/docs/Operators.md#dynamicquantizelinear ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 1); ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 3); - DmlOperator::Initialize(kernelCreationContext, inputIndices, outputIndices); + + DmlOperator::Initialize(kernelCreationContext); + + m_inputTensorDescs[IN_A] = CreateTensorDescFromInput(kernelCreationContext, 0/*A OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned); + + m_outputTensorDescs[OUT_Y] = CreateTensorDescFromOutput(kernelCreationContext, 0/*Y OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned); + m_outputTensorDescs[OUT_Y_SCALE] = CreateTensorDescFromOutput(kernelCreationContext, 1/*Y Scale OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned); + m_outputTensorDescs[OUT_Y_ZERO_POINT] = CreateTensorDescFromOutput(kernelCreationContext, 2/*Y Zero point OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned); std::vector inputDescs = GetDmlInputDescs(); std::vector outputDescs = GetDmlOutputDescs(); - DML_PLACEHOLDER_OPERATOR_DESC operatorDesc = {}; - operatorDesc.IndicesTensor = &inputDescs[0]; - operatorDesc.ValuesTensor = &inputDescs[1]; - operatorDesc.OutputTensor = outputDescs.data(); - operatorDesc.Axis = dmlAxis; + DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_DESC operatorDesc = {}; + operatorDesc.InputTensor = &inputDescs[IN_A]; + operatorDesc.OutputTensor = &outputDescs[OUT_Y]; + operatorDesc.OutputScaleTensor = &outputDescs[OUT_Y_SCALE]; + operatorDesc.OutputZeroPointTensor = &outputDescs[OUT_Y_ZERO_POINT]; - DML_OPERATOR_DESC opDesc = { DML_OPERATOR_PLACEHOLDER, &operatorDesc }; + DML_OPERATOR_DESC opDesc = { DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR, &operatorDesc }; SetDmlOperatorDesc(opDesc, kernelCreationContext); -#endif } }; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp index c8d12b2b22..08398e4b8e 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp @@ -442,28 +442,34 @@ public: std::vector inputDescs = GetDmlInputDescs(); std::vector outputDescs = GetDmlOutputDescs(); - float minValue = -FLT_MAX; - float maxValue = FLT_MAX; - if (kernelInfo.IsInputValid(1)) - { - minValue = static_cast(ReadScalarTensorCastToFloat64(kernelInfo.GetConstantInputTensor(1))); - } - if (kernelInfo.IsInputValid(2)) { - maxValue = static_cast(ReadScalarTensorCastToFloat64(kernelInfo.GetConstantInputTensor(2))); - } - - DML_ELEMENT_WISE_CLIP_OPERATOR_DESC opDesc = {}; + DML_ELEMENT_WISE_CLIP1_OPERATOR_DESC opDesc = {}; opDesc.InputTensor = inputDescs.data(); opDesc.OutputTensor = outputDescs.data(); - opDesc.Min = minValue; - opDesc.Max = maxValue; + // MinMaxDataType will always be equal to inputDataTensorDataType + // Assigning minMaxDataType to inputDataTensorDataType because this field + // has to be assigned even if program does not go through below conditional + // logic for some corner test case + // Same applies to min and max value. + opDesc.MinMaxDataType = this->m_inputTensorDescs[0].GetDmlDataType(); + CastToClampedScalarUnion(opDesc.MinMaxDataType, -DBL_MAX, /*out*/&opDesc.Min); + CastToClampedScalarUnion(opDesc.MinMaxDataType, DBL_MAX, /*out*/&opDesc.Max); - SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_CLIP, &opDesc}, kernelInfo); + if (kernelInfo.IsInputValid(1)) + { + ReadScalarTensorData(kernelInfo.GetConstantInputTensor(1), /*out*/ &opDesc.Min.Bytes, sizeof(opDesc.Min.Bytes)); + } + if (kernelInfo.IsInputValid(2)) + { + ReadScalarTensorData(kernelInfo.GetConstantInputTensor(2), /*out*/ &opDesc.Max.Bytes, sizeof(opDesc.Max.Bytes)); + } + + SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_CLIP1, &opDesc}, kernelInfo); } }; // Same operator signature as 11. Only difference is new type support using DmlOperatorElementwiseClip12 = DmlOperatorElementwiseClip11; +using DmlOperatorElementwiseClip13 = DmlOperatorElementwiseClip11; class DmlOperatorElementwisePow : public DmlOperator { @@ -692,7 +698,7 @@ DML_OP_DEFINE_CREATION_FUNCTION(Ceil, DmlOperatorElementwiseUnary); DML_OP_DEFINE_CREATION_FUNCTION(Not, DmlOperatorElementwiseUnary); DML_OP_DEFINE_CREATION_FUNCTION(Sign, DmlOperatorElementwiseUnary); -DML_OP_DEFINE_CREATION_FUNCTION(IsNan, DmlOperatorElementwiseUnary); +DML_OP_DEFINE_CREATION_FUNCTION(IsNaN, DmlOperatorElementwiseUnary); DML_OP_DEFINE_CREATION_FUNCTION(Sinh, DmlOperatorElementwiseUnary); DML_OP_DEFINE_CREATION_FUNCTION(Cosh, DmlOperatorElementwiseUnary); DML_OP_DEFINE_CREATION_FUNCTION(Asinh, DmlOperatorElementwiseUnary); @@ -724,6 +730,7 @@ DML_OP_DEFINE_CREATION_FUNCTION(Mean, DmlOperatorElementwiseMean); DML_OP_DEFINE_CREATION_FUNCTION(Clip7, DmlOperatorElementwiseClip7); DML_OP_DEFINE_CREATION_FUNCTION(Clip11, DmlOperatorElementwiseClip11); DML_OP_DEFINE_CREATION_FUNCTION(Clip12, DmlOperatorElementwiseClip12); +DML_OP_DEFINE_CREATION_FUNCTION(Clip13, DmlOperatorElementwiseClip13); DML_OP_DEFINE_CREATION_FUNCTION(Pow, DmlOperatorElementwisePow); DML_OP_DEFINE_CREATION_FUNCTION(QuantizeLinear, DmlOperatorElementwiseQLinear); DML_OP_DEFINE_CREATION_FUNCTION(DequantizeLinear, DmlOperatorElementwiseQLinear); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorNeg.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorNeg.cpp index 14de885f71..df1d782dbd 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorNeg.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorNeg.cpp @@ -22,16 +22,11 @@ public: std::vector inputDescs = GetDmlInputDescs(); std::vector outputDescs = GetDmlOutputDescs(); - DML_SCALE_BIAS scaleBias = {}; - scaleBias.Scale = -1.0f; - scaleBias.Bias = 0.0f; - - DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC opDesc = {}; + DML_ELEMENT_WISE_NEGATE_OPERATOR_DESC opDesc = {}; opDesc.InputTensor = inputDescs.data(); opDesc.OutputTensor = outputDescs.data(); - opDesc.ScaleBias = &scaleBias; - SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_IDENTITY, &opDesc}, kernelInfo); + SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_NEGATE, &opDesc}, kernelInfo); } }; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPadding.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPadding.cpp index 2598c8c015..b0abb3baef 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPadding.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPadding.cpp @@ -56,34 +56,40 @@ public: ML_INVALID_ARGUMENT("Unknown Pad mode attribute."); } + std::vector inputDescs = GetDmlInputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); + + DML_PADDING1_OPERATOR_DESC paddingDesc = {}; + paddingDesc.InputTensor = inputDescs.data(); + paddingDesc.OutputTensor = outputDescs.data(); + paddingDesc.PaddingMode = mode; + paddingDesc.DimensionCount = gsl::narrow_cast(m_startPadding.size()); + paddingDesc.StartPadding = m_startPadding.data(); + paddingDesc.EndPadding = m_endPadding.data(); + // PaddingValueDataType will always be equal to inputDataTensorDataType + // Assigning paddingValueDataType to inputDataTensorDataType because this field + // has to be assigned even if program does not go through below conditional + // logic for some corner test case (like opsetVersion >= 11, but no validInput at index 2) + // Same applies to paddingValue. + paddingDesc.PaddingValueDataType = this->m_inputTensorDescs[0].GetDmlDataType(); + CastToClampedScalarUnion(paddingDesc.PaddingValueDataType, 0.0f, /*out*/&paddingDesc.PaddingValue); + // Read the constant value which can come from an attribute or tensor. - float value = 0.0f; if (opsetVersion >= 11) { if (kernelInfo.IsInputValid(2)) { - auto valueTensor = kernelInfo.GetConstantInputTensor(2); - value = static_cast(ReadScalarTensorCastToFloat64(valueTensor)); + MLOperatorTensor constantPaddingValueTensor = kernelInfo.GetConstantInputTensor(2); + ReadScalarTensorData(constantPaddingValueTensor, /*out*/ &paddingDesc.PaddingValue.Bytes, sizeof(paddingDesc.PaddingValue.Bytes)); } } else { - value = kernelInfo.GetOptionalAttribute(AttrName::Value, 0.0f); + auto value = kernelInfo.GetOptionalAttribute(AttrName::Value, 0.0f); + CastToClampedScalarUnion(paddingDesc.PaddingValueDataType, value, /*out*/&paddingDesc.PaddingValue); } - std::vector inputDescs = GetDmlInputDescs(); - std::vector outputDescs = GetDmlOutputDescs(); - - DML_PADDING_OPERATOR_DESC paddingDesc = {}; - paddingDesc.InputTensor = inputDescs.data(); - paddingDesc.OutputTensor = outputDescs.data(); - paddingDesc.PaddingMode = mode; - paddingDesc.PaddingValue = value; - paddingDesc.DimensionCount = gsl::narrow_cast(m_startPadding.size()); - paddingDesc.StartPadding = m_startPadding.data(); - paddingDesc.EndPadding = m_endPadding.data(); - - DML_OPERATOR_DESC opDesc = { DML_OPERATOR_PADDING, &paddingDesc }; + DML_OPERATOR_DESC opDesc = { DML_OPERATOR_PADDING1, &paddingDesc }; SetDmlOperatorDesc(opDesc, kernelInfo); } @@ -91,5 +97,6 @@ public: DML_OP_DEFINE_CREATION_FUNCTION(Pad7, VersionedKernel); DML_OP_DEFINE_CREATION_FUNCTION(Pad11, VersionedKernel); +DML_OP_DEFINE_CREATION_FUNCTION(Pad13, VersionedKernel); } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAdd.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAdd.cpp new file mode 100644 index 0000000000..7b50dfb9ff --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAdd.cpp @@ -0,0 +1,69 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "precomp.h" + +namespace Dml +{ + +class DmlOperatorQLinearAdd : public DmlOperator +{ + enum InputTensors { + IN_A, + IN_A_SCALE, + IN_A_ZERO_POINT, + IN_B, + IN_B_SCALE, + IN_B_ZERO_POINT, + IN_C_SCALE, + IN_C_ZERO_POINT + }; + +public: + DmlOperatorQLinearAdd(const MLOperatorKernelCreationContext& kernelInfo) + : DmlOperator(kernelInfo) + { + DmlOperator::Initialize(kernelInfo); + + std::vector outputShape = kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0); + + uint32_t dmlDimSize = m_inputTensorDescs[0].GetDimensionCount(); + + // Initialize the input descriptions with broadcasting + m_inputTensorDescs[IN_A] = CreateTensorDescFromInput(kernelInfo, 0/*A OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, outputShape); + m_inputTensorDescs[IN_B] = CreateTensorDescFromInput(kernelInfo, 3/*B OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, outputShape); + + m_inputTensorDescs[IN_A_SCALE] = CreateTensorDescFromInput(kernelInfo, 1/*A Scale OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, std::nullopt, dmlDimSize); + m_inputTensorDescs[IN_A_ZERO_POINT] = CreateTensorDescFromInput(kernelInfo, 2/*A Zero point OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, std::nullopt, dmlDimSize); + + m_inputTensorDescs[IN_B_SCALE] = CreateTensorDescFromInput(kernelInfo, 4/*B Scale OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, std::nullopt, dmlDimSize); + m_inputTensorDescs[IN_B_ZERO_POINT] = CreateTensorDescFromInput(kernelInfo, 5/*B Zero point OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, std::nullopt, dmlDimSize); + + m_inputTensorDescs[IN_C_SCALE] = CreateTensorDescFromInput(kernelInfo, 6/*C Zero point OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, std::nullopt, dmlDimSize); + m_inputTensorDescs[IN_C_ZERO_POINT] = CreateTensorDescFromInput(kernelInfo, 7/*C Zero point OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, std::nullopt, dmlDimSize); + + // Initialize the output description while overriding the shape + m_outputTensorDescs[0] = CreateTensorDescFromOutput(kernelInfo, 0, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, outputShape); + + std::vector inputDescs = GetDmlInputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); + + DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_DESC AddDesc = {}; + AddDesc.ATensor = &inputDescs[IN_A]; + AddDesc.AScaleTensor = &inputDescs[IN_A_SCALE]; + AddDesc.AZeroPointTensor = inputDescs[IN_A_ZERO_POINT].Desc != nullptr ? &inputDescs[IN_A_ZERO_POINT] : nullptr; + AddDesc.BTensor = &inputDescs[IN_B]; + AddDesc.BScaleTensor = &inputDescs[IN_B_SCALE]; + AddDesc.BZeroPointTensor = inputDescs[IN_B_ZERO_POINT].Desc != nullptr ? &inputDescs[IN_B_ZERO_POINT] : nullptr; + AddDesc.OutputScaleTensor = &inputDescs[IN_C_SCALE]; + AddDesc.OutputZeroPointTensor = inputDescs[IN_C_ZERO_POINT].Desc != nullptr ? &inputDescs[IN_C_ZERO_POINT] : nullptr; + AddDesc.OutputTensor = &outputDescs[0]; + + DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD, &AddDesc }; + SetDmlOperatorDesc(opDesc, kernelInfo); + } +}; + +DML_OP_DEFINE_CREATION_FUNCTION(QLinearAdd, DmlOperatorQLinearAdd); + +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorResize.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorResize.cpp index 63cc666aa7..f1c6531a1d 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorResize.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorResize.cpp @@ -341,5 +341,6 @@ DML_OP_DEFINE_CREATION_FUNCTION(Resize11, VersionedKernel DML_OP_DEFINE_CREATION_FUNCTION(Upsample7, VersionedKernel); DML_OP_DEFINE_CREATION_FUNCTION(Upsample9, VersionedKernel); DML_OP_DEFINE_CREATION_FUNCTION(Upsample10, VersionedKernel); +DML_OP_DEFINE_CREATION_FUNCTION(Upsample13, VersionedKernel); } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorScatter.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorScatter.cpp index 9a7f7de526..a58098a81e 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorScatter.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorScatter.cpp @@ -113,6 +113,7 @@ public: DML_OP_DEFINE_CREATION_FUNCTION(Scatter9, DmlOperatorScatter); DML_OP_DEFINE_CREATION_FUNCTION(Scatter11, DmlOperatorScatter); +DML_OP_DEFINE_CREATION_FUNCTION(Scatter13, DmlOperatorScatter); DML_OP_DEFINE_CREATION_FUNCTION(ScatterElements, DmlOperatorScatter); DML_OP_DEFINE_CREATION_FUNCTION(ScatterND, DmlOperatorScatterNd); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSlice.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSlice.cpp index fd6a12ecae..04aea8562e 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSlice.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSlice.cpp @@ -55,4 +55,5 @@ void CALLBACK QuerySlice(IMLOperatorSupportQueryContextPrivate* context, bool* i DML_OP_DEFINE_CREATION_FUNCTION(Slice7, VersionedKernel ); DML_OP_DEFINE_CREATION_FUNCTION(Slice10, VersionedKernel); DML_OP_DEFINE_CREATION_FUNCTION(Slice11, VersionedKernel); +DML_OP_DEFINE_CREATION_FUNCTION(Slice13, VersionedKernel); } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 03bb77bf0d..ff33a8d55a 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -35,14 +35,18 @@ enum class SupportedTensorDataTypes : uint32_t Complex64 = 1<<14, Complex128 = 1<<15, Ints8to32 = UInt8|Int8|UInt16|Int16|UInt32|Int32, - Int32to64 = UInt32|Int32|UInt64|Int64, - Float16to32 = Float16|Float32, // Float64 is not supported by DirectML. - NumericDefault = Ints8to32|Float16to32, + Ints32to64 = UInt32|Int32|UInt64|Int64, + Ints8to64 = UInt8|Int8|UInt16|Int16|UInt32|Int32|UInt64|Int64, + UInt8to64 = UInt8|UInt16|UInt32|UInt64, + Float16to32 = Float16|Float32, + Float16to64 = Float16|Float32|Float64, + NumericDefault = Ints8to32|Float16to32, // Only simple numbers, not bool, complex, or string. Scalars8to32 = UInt8|Int8|UInt16|Int16|UInt32|Int32|Float16to32|Bool, - AllScalars = UInt8|Int8|UInt16|Int16|UInt32|Int32|UInt64|Int64|Float16to32|Bool, + AllScalars = UInt8|Int8|UInt16|Int16|UInt32|Int32|UInt64|Int64|Float16|Float32|Float64|Bool, Ints8Bit = UInt8|Int8, Ints16Bit = UInt16|Int16, Ints32Bit = UInt32|Int32, + Ints64Bit = UInt64|Int64, All = static_cast(-1), }; DEFINE_ENUM_FLAG_OPERATORS(Dml::SupportedTensorDataTypes); @@ -54,6 +58,7 @@ enum class DmlGraphSupport : uint32_t SupportedWith64BitTensorsVia32BitStrides = 2, // Supports them via 32-bit tensors and doubled strides. SupportedWith64BitTensorsVia32BitStridesFromAnyEp = 4, // Supports input from any execution provider (otherwise only inputs from other DML nodes) Prefer64BitTensorsDirectly = 8, // Natively supports 64-bit tensors. So avoid strided 32-bit unless the device lacks support. + Support64BitTensorsViaEmulation = 16,// supports int/uint64 tensors via emulation of 32-bit types. }; DEFINE_ENUM_FLAG_OPERATORS(DmlGraphSupport); @@ -109,8 +114,10 @@ DML_OP_EXTERN_CREATION_FUNCTION(Concat); DML_OP_EXTERN_CREATION_FUNCTION(Slice7); DML_OP_EXTERN_CREATION_FUNCTION(Slice10); DML_OP_EXTERN_CREATION_FUNCTION(Slice11); +DML_OP_EXTERN_CREATION_FUNCTION(Slice13); DML_OP_EXTERN_CREATION_FUNCTION(Pad7); DML_OP_EXTERN_CREATION_FUNCTION(Pad11); +DML_OP_EXTERN_CREATION_FUNCTION(Pad13); DML_OP_EXTERN_CREATION_FUNCTION(SpaceToDepth); DML_OP_EXTERN_CREATION_FUNCTION(DepthToSpace); DML_OP_EXTERN_CREATION_FUNCTION(Sqrt); @@ -124,6 +131,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(Floor); DML_OP_EXTERN_CREATION_FUNCTION(Clip7); DML_OP_EXTERN_CREATION_FUNCTION(Clip11); DML_OP_EXTERN_CREATION_FUNCTION(Clip12); +DML_OP_EXTERN_CREATION_FUNCTION(Clip13); DML_OP_EXTERN_CREATION_FUNCTION(Greater); DML_OP_EXTERN_CREATION_FUNCTION(Less); DML_OP_EXTERN_CREATION_FUNCTION(GreaterOrEqual); @@ -161,6 +169,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(ImageScaler); DML_OP_EXTERN_CREATION_FUNCTION(Upsample7); DML_OP_EXTERN_CREATION_FUNCTION(Upsample9); DML_OP_EXTERN_CREATION_FUNCTION(Upsample10); +DML_OP_EXTERN_CREATION_FUNCTION(Upsample13); DML_OP_EXTERN_CREATION_FUNCTION(Sigmoid); DML_OP_EXTERN_CREATION_FUNCTION(HardSigmoid); DML_OP_EXTERN_CREATION_FUNCTION(Tanh); @@ -206,7 +215,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(FusedSum); DML_OP_EXTERN_CREATION_FUNCTION(QuantizeLinear); DML_OP_EXTERN_CREATION_FUNCTION(DequantizeLinear); DML_OP_EXTERN_CREATION_FUNCTION(Sign); -DML_OP_EXTERN_CREATION_FUNCTION(IsNan); +DML_OP_EXTERN_CREATION_FUNCTION(IsNaN); DML_OP_EXTERN_CREATION_FUNCTION(Sinh); DML_OP_EXTERN_CREATION_FUNCTION(Cosh); DML_OP_EXTERN_CREATION_FUNCTION(Tanh); @@ -221,6 +230,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(EyeLike); DML_OP_EXTERN_CREATION_FUNCTION(MaxUnpool); DML_OP_EXTERN_CREATION_FUNCTION(Scatter9); DML_OP_EXTERN_CREATION_FUNCTION(Scatter11); +DML_OP_EXTERN_CREATION_FUNCTION(Scatter13); DML_OP_EXTERN_CREATION_FUNCTION(Resize10); DML_OP_EXTERN_CREATION_FUNCTION(Resize11); DML_OP_EXTERN_CREATION_FUNCTION(ConstantOfShape); @@ -235,6 +245,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(ReverseSequence); DML_OP_EXTERN_CREATION_FUNCTION(Round); DML_OP_EXTERN_CREATION_FUNCTION(ScatterElements); DML_OP_EXTERN_CREATION_FUNCTION(ScatterND); +DML_OP_EXTERN_CREATION_FUNCTION(QLinearAdd); DML_OP_EXTERN_CREATION_FUNCTION(QLinearConv); DML_OP_EXTERN_CREATION_FUNCTION(QLinearMatMul); DML_OP_EXTERN_CREATION_FUNCTION(DynamicQuantizeLinear); @@ -251,6 +262,7 @@ constexpr static std::array typeNameListTwo = { "T1", "T2" }; constexpr static std::array typeNameListThree = { "T1", "T2", "T3" }; constexpr static std::array typeNameListFour = { "T1", "T2", "T3", "T4" }; constexpr static std::array typeNameListTopK = { "T", "I" }; +constexpr static std::array typeNameListMaxPool = { "T", "I" }; constexpr static std::array typeNameListLogicalComparison = { "T", "T1" }; constexpr static std::array typeNameListPow12 = {"T", "T1"}; constexpr static std::array typeNameListConstantOfShape = { "T1", "T2" }; @@ -259,51 +271,65 @@ constexpr static std::array typeNameListScatterGatherND = { "T" constexpr static std::array typeNameListSlice10 = { "T", "Tind" }; constexpr static std::array typeNameListWhere = { "B", "T" }; constexpr static std::array typeNameListEyeLike = { "T2" }; + constexpr static std::array supportedTypeListAll = {SupportedTensorDataTypes::All}; constexpr static std::array supportedTypeListFloat32 = {SupportedTensorDataTypes::Float32}; constexpr static std::array supportedTypeListFloat16to32 = {SupportedTensorDataTypes::Float16to32}; -constexpr static std::array supportedTypeListFloat16to32Int8 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit }; -constexpr static std::array supportedTypeListFloat16to32Int32 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::UInt32}; -constexpr static std::array supportedTypeListFloat16to32Int8to32 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit | SupportedTensorDataTypes::Ints16Bit | SupportedTensorDataTypes::Ints32Bit}; -constexpr static std::array supportedTypeListInt8to32 = {SupportedTensorDataTypes::Ints8to32}; -constexpr static std::array supportedTypeListInt32to64AndFloat16to32 = {SupportedTensorDataTypes::Int32to64|SupportedTensorDataTypes::Float16to32}; +constexpr static std::array supportedTypeListFloat16to32Ints8 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit }; +constexpr static std::array supportedTypeListFloat16to32Ints32 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::UInt32}; +constexpr static std::array supportedTypeListFloat16to32Ints8to32 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit | SupportedTensorDataTypes::Ints16Bit | SupportedTensorDataTypes::Ints32Bit}; +constexpr static std::array supportedTypeListFloat16to32Ints8to64 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit | SupportedTensorDataTypes::Ints16Bit | SupportedTensorDataTypes::Ints32Bit | SupportedTensorDataTypes::Ints64Bit}; +constexpr static std::array supportedTypeListFloat16to32Ints32to64 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints32Bit | SupportedTensorDataTypes::Ints64Bit}; +constexpr static std::array supportedTypeListUInt8to64 = {SupportedTensorDataTypes::UInt8to64}; constexpr static std::array supportedTypeListNumericDefault = { SupportedTensorDataTypes::NumericDefault }; constexpr static std::array supportedTypeListAllScalars = { SupportedTensorDataTypes::AllScalars }; constexpr static std::array supportedTypeListBool = {SupportedTensorDataTypes::Bool}; constexpr static std::array supportedTypeListPow12 = {SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::NumericDefault}; -constexpr static std::array supportedTypeListTopK = {SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Int64}; +constexpr static std::array supportedTypeListTopK = {SupportedTensorDataTypes::NumericDefault | SupportedTensorDataTypes::Ints64Bit, SupportedTensorDataTypes::Int64}; +constexpr static std::array supportedTypeListMaxPool = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit, SupportedTensorDataTypes::Int64}; +constexpr static std::array supportedTypeListMaxUnpool = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int64}; constexpr static std::array supportedTypeListIndices = { SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64 }; constexpr static std::array supportedTypeListCast = { SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::AllScalars }; constexpr static std::array supportedTypeListScalars8to32 = { SupportedTensorDataTypes::Scalars8to32 }; -constexpr static std::array supportedTypeListScatterGather = { SupportedTensorDataTypes::Scalars8to32, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 }; -constexpr static std::array supportedTypeListScatterGatherND = { SupportedTensorDataTypes::Scalars8to32 }; -constexpr static std::array supportedTypeListSlice10 = { SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 }; +constexpr static std::array supportedTypeListScatterGather = { SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 }; +constexpr static std::array supportedTypeListScatterGatherND = { SupportedTensorDataTypes::AllScalars }; +constexpr static std::array supportedTypeListSlice10 = { SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 }; constexpr static std::array supportedTypeListQuantizeLinear = { SupportedTensorDataTypes::Float32 | SupportedTensorDataTypes::Int32, SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 }; constexpr static std::array supportedTypeListDequantizeLinear = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 | SupportedTensorDataTypes::Int32 }; constexpr static std::array supportedTypeListQuantize = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::UInt8 }; constexpr static std::array supportedTypeListIsNan = { SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Bool }; constexpr static std::array supportedTypeListIsInf = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::Bool }; -constexpr static std::array supportedTypeListConstantOfShape = { SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64, SupportedTensorDataTypes::Float16to32 }; -constexpr static std::array supportedTypeListWhere = { SupportedTensorDataTypes::Bool, SupportedTensorDataTypes::AllScalars }; -constexpr static std::array supportedTypeListOneHot = /* indices, depth, values */ { SupportedTensorDataTypes::Int32to64, SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Scalars8to32 }; +constexpr static std::array supportedTypeListConstantOfShape = { SupportedTensorDataTypes::Int64, SupportedTensorDataTypes::AllScalars }; +constexpr static std::array supportedTypeListWhere = { SupportedTensorDataTypes::Bool, SupportedTensorDataTypes::Scalars8to32 }; +constexpr static std::array supportedTypeListOneHot = /* indices, depth, values */ { SupportedTensorDataTypes::Ints32to64, SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::AllScalars }; constexpr static std::array supportedTypeListLogicalComparison7 = /* A&B,C */ { SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Bool }; -constexpr static std::array supportedTypeListLogicalComparison9 = /* A&B,C */ { SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Bool }; -constexpr static std::array supportedTypeListSigned = { SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int16 | SupportedTensorDataTypes::Int8 }; -constexpr static std::array supportedTypeListRange = {SupportedTensorDataTypes::Int16|SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Float32}; +constexpr static std::array supportedTypeListLogicalComparison9 = /* A&B,C */ { SupportedTensorDataTypes::Float16to32|SupportedTensorDataTypes::Ints8to64, SupportedTensorDataTypes::Bool }; +constexpr static std::array supportedTypeListSigned = { SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int64 | SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int16 | SupportedTensorDataTypes::Int8 }; +constexpr static std::array supportedTypeListRange = {SupportedTensorDataTypes::Int16|SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64|SupportedTensorDataTypes::Float32}; +constexpr static std::array supportedTypeListResize11 = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Float16to32 /* ROI read by CPU */}; constexpr static std::array supportedTypeListInteger = {SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int32 }; +constexpr static std::array supportedTypeListInteger8 = {SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8 }; constexpr static std::array supportedTypeListRoiAlign = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64 }; +constexpr static std::array supportedTypeListArgMinMax = {SupportedTensorDataTypes::Float16to32|SupportedTensorDataTypes::Ints8to64}; + constexpr static std::array supportedTypeListQLinearMatMul = { - SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, - SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, - SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8 + SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, + SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, + SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8 }; constexpr static std::array supportedTypeListQLinearConv = { - SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, - SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, - SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, + SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, + SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, + SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int32 }; + +constexpr static std::array supportedTypeListDynamicQuantizeLinear = { + SupportedTensorDataTypes::Float32, + SupportedTensorDataTypes::UInt8, +}; + template constexpr auto requiredConstantCpuInputs(Args... args) { @@ -336,249 +362,313 @@ constexpr auto requiredConstantCpuInputs(Args... args) constexpr static OperatorRegistrationInformation operatorRegistrationInformationTable[] = { -/// Domain/Type, Ver, Name, TypeNames, Types, Graph Support, Required const CPU inputs, -/// Input count required for graph support, -/// Support query function +/// Domain/Type, Ver, Name, TypeNames, Types, Graph Support, Required const CPU inputs, +/// Input count required for graph support, +/// Support query function // Deep Learning Standard Layers - {REG_INFO( 7, Conv, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 11, Conv, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, ConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 11, ConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 10, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 11, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, GlobalAveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO( 8, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, - {REG_INFO( 10, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, - {REG_INFO( 11, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, - {REG_INFO( 12, MaxPool, typeNameListDefault, supportedTypeListFloat16to32Int8, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, + {REG_INFO( 7, Conv, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, Conv, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, ConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, ConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 10, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, GlobalAveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, + {REG_INFO( 8, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, + {REG_INFO( 10, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, + {REG_INFO( 11, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, + {REG_INFO( 12, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, - {REG_INFO( 7, GlobalMaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 11, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, GlobalLpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, MaxRoiPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO_VER( 10, RoiAlign, typeNameListTwo, supportedTypeListRoiAlign, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, - {REG_INFO( 7, InstanceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 9, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, // v9 just removes 'spatial' attribute. - {REG_INFO( 7, LRN, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, MeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 9, MeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, LpNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, RNN, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::NotSupported)}, - {REG_INFO( 7, GRU, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::NotSupported)}, - {REG_INFO( 7, LSTM, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::NotSupported)}, - {REG_INFO_MS( 1, ConvTransposeWithDynamicPads, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(2))}, + {REG_INFO( 7, GlobalMaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, GlobalLpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, MaxRoiPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO_VER( 10, RoiAlign, typeNameListTwo, supportedTypeListRoiAlign, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 7, InstanceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 9, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, // v9 just removes 'spatial' attribute. + {REG_INFO( 7, LRN, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 13, LRN, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, MeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 9, MeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 13, MeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, LpNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, RNN, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::NotSupported)}, + {REG_INFO( 7, GRU, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::NotSupported)}, + {REG_INFO( 7, LSTM, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::NotSupported)}, + {REG_INFO_MS( 1, ConvTransposeWithDynamicPads, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(2))}, // Data Reorganization Layers - {REG_INFO( 7, Split, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO( 11, Split, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, // Adds negative axis. - {REG_INFO( 7, Transpose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO( 7, Concat, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO( 11, Concat, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, // Adds negative axis. - {REG_INFO_VER( 7, Slice, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO_VER( 10, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)}, // Adds negative axes. - {REG_INFO_VER( 11, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)}, - {REG_INFO_VER( 7, Pad, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO_VER( 11, Pad, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1, 2) /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728 - {REG_INFO( 7, SpaceToDepth, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO( 7, DepthToSpace, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO( 11, DepthToSpace, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO( 7, Tile, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))}, - {REG_INFO( 8, Expand, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))}, - {REG_INFO( 9, ConstantOfShape, typeNameListConstantOfShape, supportedTypeListConstantOfShape, DmlGraphSupport::NotSupported, requiredConstantCpuInputs(0))}, - {REG_INFO( 7, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, - {REG_INFO( 11, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, - {REG_INFO( 11, GatherElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, - {REG_INFO( 11, GatherND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, - {REG_INFO( 12, GatherND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, - {REG_INFO_VER( 9, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, - {REG_INFO_VER( 11, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, - {REG_INFO( 11, ScatterElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, - {REG_INFO( 11, ScatterND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, - {REG_INFO( 9, EyeLike, typeNameListEyeLike, supportedTypeListScalars8to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Split, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 11, Split, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, // Adds negative axis. + {REG_INFO( 7, Transpose, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 13, Transpose, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 7, Concat, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 11, Concat, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, // Adds negative axis. + {REG_INFO( 13, Concat, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, // Adds negative axis. + {REG_INFO_VER( 7, Slice, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO_VER( 10, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)}, // Adds negative axes. + {REG_INFO_VER( 11, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)}, + {REG_INFO_VER( 13, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)}, + {REG_INFO_VER( 7, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO_VER( 11, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1, 2) /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728 + {REG_INFO_VER( 13, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1, 2) /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728 + {REG_INFO( 7, SpaceToDepth, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 13, SpaceToDepth, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 7, DepthToSpace, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 11, DepthToSpace, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 13, DepthToSpace, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 7, Tile, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1))}, + {REG_INFO( 13, Tile, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1))}, + {REG_INFO( 8, Expand, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1))}, + {REG_INFO( 13, Expand, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1))}, + {REG_INFO( 9, ConstantOfShape, typeNameListConstantOfShape, supportedTypeListConstantOfShape, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(0))}, + {REG_INFO( 7, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 11, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 13, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 11, GatherElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 13, GatherElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 11, GatherND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 12, GatherND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 13, GatherND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO_VER( 9, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO_VER( 11, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO_VER( 13, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 11, ScatterElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 13, ScatterElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 11, ScatterND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 13, ScatterND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 9, EyeLike, typeNameListEyeLike, supportedTypeListScalars8to32, DmlGraphSupport::Supported)}, // Data reorganization that merely changes the dimensions while keeping the data identical. - {REG_INFO_ID( 7, Identity, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO_ID( 7, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO_ID( 9, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO_ID( 11, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO_ID( 7, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO_ID( 11, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO_ID( 7, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO_ID( 11, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO_ID( 7, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))}, + {REG_INFO_ID( 7, Identity, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO_ID( 13, Identity, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO_ID( 7, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO_ID( 9, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO_ID( 11, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO_ID( 13, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO_ID( 7, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO_ID( 11, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO_ID( 7, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO_ID( 11, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO_ID( 7, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1))}, + {REG_INFO_ID( 13, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1))}, // Elementwise - {REG_INFO( 7, Sqrt, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Reciprocal, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Pow, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 12, Pow, typeNameListPow12, supportedTypeListPow12, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Exp, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Log, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Abs, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Ceil, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Floor, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO_VER( 7, Clip, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO_VER( 11, Clip, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1,2))}, - {REG_INFO_VER( 12, Clip, typeNameListDefault, supportedTypeListFloat16to32Int8to32,DmlGraphSupport::Supported, requiredConstantCpuInputs(1,2))}, - {REG_INFO( 7, Add, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Sub, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Mul, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Div, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Sum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, - {REG_INFO( 8, Sum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, - {REG_INFO( 7, Mean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, - {REG_INFO( 8, Mean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, - {REG_INFO( 7, Max, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, - {REG_INFO( 8, Max, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, - {REG_INFO( 12, Max, typeNameListDefault, supportedTypeListFloat16to32Int8to32,DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, - {REG_INFO( 7, Min, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, - {REG_INFO( 8, Min, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, - {REG_INFO( 12, Min, typeNameListDefault, supportedTypeListFloat16to32Int8to32,DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, - {REG_INFO( 7, Cos, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Sin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Tan, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Acos, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Asin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Atan, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Affine, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 10, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear, DmlGraphSupport::Supported)}, - {REG_INFO( 10, DequantizeLinear, typeNameListTwo, supportedTypeListDequantizeLinear, DmlGraphSupport::Supported)}, - {REG_INFO_MS( 1, QuantizeLinear, typeNameListTwo, supportedTypeListQuantize, DmlGraphSupport::Supported)}, - {REG_INFO_MS( 1, DequantizeLinear, typeNameListTwo, supportedTypeListQuantize, DmlGraphSupport::Supported)}, - {REG_INFO( 9, Sign, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)}, - {REG_INFO( 9, IsNan, typeNameListTwo, supportedTypeListIsNan, DmlGraphSupport::Supported)}, - {REG_INFO( 9, Sinh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 9, Cosh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 9, Asinh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 9, Acosh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 9, Atanh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 9, Erf, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 9, Where, typeNameListWhere, supportedTypeListWhere, DmlGraphSupport::Supported)}, - {REG_INFO( 7, ReduceSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 11, ReduceSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO_VER( 12, Einsum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryEinSum )}, - {REG_INFO( 7, ReduceMean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 11, ReduceMean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, ReduceProd, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 11, ReduceProd, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, ReduceLogSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 11, ReduceLogSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, ReduceLogSumExp, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 11, ReduceLogSumExp, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, ReduceSumSquare, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 11, ReduceSumSquare, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, ReduceL1, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 11, ReduceL1, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, ReduceL2, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 11, ReduceL2, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 11, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 12, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32Int8, DmlGraphSupport::Supported)}, - {REG_INFO( 7, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 11, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 12, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32Int8, DmlGraphSupport::Supported)}, - {REG_INFO( 7, ArgMax, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO( 11, ArgMax, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO( 12, ArgMax, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO( 7, ArgMin, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO( 11, ArgMin, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO( 12, ArgMin, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO( 7, Gemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 11, Gemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 9, Gemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Neg, typeNameListDefault, supportedTypeListSigned, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Greater, typeNameListLogicalComparison, supportedTypeListLogicalComparison7,DmlGraphSupport::Supported)}, - {REG_INFO( 9, Greater, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmlGraphSupport::Supported)}, - {REG_INFO( 7, Less, typeNameListLogicalComparison, supportedTypeListLogicalComparison7,DmlGraphSupport::Supported)}, - {REG_INFO( 9, Less, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmlGraphSupport::Supported)}, - {REG_INFO( 12, GreaterOrEqual, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmlGraphSupport::Supported)}, - {REG_INFO( 12, LessOrEqual, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmlGraphSupport::Supported)}, - {REG_INFO( 7, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison7,DmlGraphSupport::Supported)}, - {REG_INFO( 11, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmlGraphSupport::Supported)}, - {REG_INFO( 7, Not, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)}, - {REG_INFO( 7, And, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Or, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Xor, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Sqrt, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 13, Sqrt, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Reciprocal, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 13, Reciprocal, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Pow, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 12, Pow, typeNameListPow12, supportedTypeListPow12, DmlGraphSupport::Supported)}, + {REG_INFO( 13, Pow, typeNameListPow12, supportedTypeListPow12, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Exp, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 13, Exp, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Log, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 13, Log, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Abs, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)}, + {REG_INFO( 13, Abs, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Ceil, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 13, Ceil, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Floor, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 13, Floor, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO_VER( 7, Clip, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO_VER( 11, Clip, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1,2))}, + {REG_INFO_VER( 12, Clip, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1,2))}, + {REG_INFO_VER( 13, Clip, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1,2))}, + {REG_INFO( 7, Add, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 13, Add, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 7, Sub, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 13, Sub, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 7, Mul, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 13, Mul, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 7, Div, typeNameListDefault, supportedTypeListFloat16to32Ints32, DmlGraphSupport::Supported)}, + {REG_INFO( 13, Div, typeNameListDefault, supportedTypeListFloat16to32Ints32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Sum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, + {REG_INFO( 8, Sum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, + {REG_INFO( 13, Sum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, + {REG_INFO( 7, Mean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, + {REG_INFO( 8, Mean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, + {REG_INFO( 13, Mean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, + {REG_INFO( 7, Max, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, + {REG_INFO( 8, Max, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, + {REG_INFO( 12, Max, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(), 2)}, + {REG_INFO( 13, Max, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(), 2)}, + {REG_INFO( 7, Min, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, + {REG_INFO( 8, Min, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, + {REG_INFO( 12, Min, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(), 2)}, + {REG_INFO( 13, Min, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(), 2)}, + {REG_INFO( 7, Cos, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Sin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Tan, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Acos, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Asin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Atan, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Affine, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 10, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear, DmlGraphSupport::Supported)}, + {REG_INFO( 10, DequantizeLinear, typeNameListTwo, supportedTypeListDequantizeLinear, DmlGraphSupport::Supported)}, + {REG_INFO_MS( 1, QuantizeLinear, typeNameListTwo, supportedTypeListQuantize, DmlGraphSupport::Supported)}, + {REG_INFO_MS( 1, DequantizeLinear, typeNameListTwo, supportedTypeListQuantize, DmlGraphSupport::Supported)}, + {REG_INFO( 9, Sign, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)}, + {REG_INFO( 13, Sign, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)}, + {REG_INFO( 9, IsNaN, typeNameListTwo, supportedTypeListIsNan, DmlGraphSupport::Supported)}, + {REG_INFO( 13, IsNaN, typeNameListTwo, supportedTypeListIsNan, DmlGraphSupport::Supported)}, + {REG_INFO( 9, Sinh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 9, Cosh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 9, Asinh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 9, Acosh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 9, Atanh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 9, Erf, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 13, Erf, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 9, Where, typeNameListWhere, supportedTypeListWhere, DmlGraphSupport::Supported)}, + {REG_INFO( 7, ReduceSum, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported)}, + {REG_INFO( 11, ReduceSum, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported)}, + {REG_INFO_VER( 12, Einsum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryEinSum )}, + {REG_INFO( 7, ReduceMean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, ReduceMean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 13, ReduceMean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, ReduceProd, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported)}, + {REG_INFO( 11, ReduceProd, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported)}, + {REG_INFO( 13, ReduceProd, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported)}, + {REG_INFO( 7, ReduceLogSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, ReduceLogSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 13, ReduceLogSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, ReduceLogSumExp, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, ReduceLogSumExp, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 13, ReduceLogSumExp, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, ReduceSumSquare, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported)}, + {REG_INFO( 11, ReduceSumSquare, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported)}, + {REG_INFO( 13, ReduceSumSquare, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported)}, + {REG_INFO( 7, ReduceL1, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported)}, + {REG_INFO( 11, ReduceL1, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported)}, + {REG_INFO( 13, ReduceL1, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported)}, + {REG_INFO( 7, ReduceL2, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, ReduceL2, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 13, ReduceL2, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 12, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 13, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 7, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 12, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 13, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 7, ArgMax, typeNameListDefault, supportedTypeListArgMinMax, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 11, ArgMax, typeNameListDefault, supportedTypeListArgMinMax, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 12, ArgMax, typeNameListDefault, supportedTypeListArgMinMax, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 13, ArgMax, typeNameListDefault, supportedTypeListArgMinMax, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 7, ArgMin, typeNameListDefault, supportedTypeListArgMinMax, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 11, ArgMin, typeNameListDefault, supportedTypeListArgMinMax, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 12, ArgMin, typeNameListDefault, supportedTypeListArgMinMax, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 13, ArgMin, typeNameListDefault, supportedTypeListArgMinMax, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 7, Gemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 9, Gemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, Gemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 13, Gemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Neg, typeNameListDefault, supportedTypeListSigned, DmlGraphSupport::Supported)}, + {REG_INFO( 13, Neg, typeNameListDefault, supportedTypeListSigned, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Greater, typeNameListLogicalComparison, supportedTypeListLogicalComparison7, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 9, Greater, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 13, Greater, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 7, Less, typeNameListLogicalComparison, supportedTypeListLogicalComparison7, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 9, Less, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 13, Less, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 12, GreaterOrEqual, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 12, LessOrEqual, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 7, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison7, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 11, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 13, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 7, Not, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)}, + {REG_INFO( 7, And, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Or, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Xor, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)}, // Imaging Operators - {REG_INFO( 7, Crop, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, ImageScaler, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO_VER( 7, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO_VER( 9, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)}, - {REG_INFO_VER( 10, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)}, - {REG_INFO_VER( 10, Resize, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)}, - {REG_INFO_VER( 11, Resize, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*roi, scales, sizes*/, std::nullopt, QueryResize)}, + {REG_INFO( 7, Crop, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, ImageScaler, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO_VER( 7, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO_VER( 9, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)}, + {REG_INFO_VER( 10, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)}, + {REG_INFO_VER( 13, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)}, + {REG_INFO_VER( 10, Resize, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)}, + {REG_INFO_VER( 11, Resize, typeNameListTwo, supportedTypeListResize11, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*roi, scales, sizes*/, std::nullopt, QueryResize)}, // Activation Functions - {REG_INFO( 7, Sigmoid, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, HardSigmoid, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Tanh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, ScaledTanh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Relu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, LeakyRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, PRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 9, PRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, ThresholdedRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 10, ThresholdedRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Elu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 12, Celu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Selu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Softmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 11, Softmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, LogSoftmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 11, LogSoftmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Hardmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 11, Hardmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Softsign, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Softplus, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, ParametricSoftplus, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Dropout, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 9, Shrink, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Sigmoid, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 13, Sigmoid, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, HardSigmoid, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Tanh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 13, Tanh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, ScaledTanh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Relu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 13, Relu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, LeakyRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, PRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 9, PRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, ThresholdedRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 10, ThresholdedRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Elu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 12, Celu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Selu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Softmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, Softmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, LogSoftmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, LogSoftmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Hardmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, Hardmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Softsign, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Softplus, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, ParametricSoftplus, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Dropout, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 9, Shrink, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)}, // Uncategorized - {REG_INFO( 7, MatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 9, MatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)}, - {REG_INFO( 9, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)}, + {REG_INFO( 7, MatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 9, MatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 13, MatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 9, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 13, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, {REG_INFO( 7, MemcpyFromHost, typeNameListDefault, supportedTypeListAll)}, {REG_INFO( 7, MemcpyToHost, typeNameListDefault, supportedTypeListAll)}, - {REG_INFO_VER( 7, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO_VER( 10, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))}, - {REG_INFO_VER( 11, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))}, - {REG_INFO( 9, OneHot, typeNameListThree, supportedTypeListOneHot, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp, requiredConstantCpuInputs(1))}, - {REG_INFO( 11, OneHot, typeNameListThree, supportedTypeListOneHot, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp, requiredConstantCpuInputs(1))}, + {REG_INFO_VER( 7, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO_VER( 10, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1))}, + {REG_INFO_VER( 11, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1))}, + {REG_INFO( 9, OneHot, typeNameListThree, supportedTypeListOneHot, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1))}, + {REG_INFO( 11, OneHot, typeNameListThree, supportedTypeListOneHot, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1))}, // Fused operators - {REG_INFO_MSDML(1, FusedConv, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO_MSDML(1, FusedConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO_MSDML(1, FusedInstanceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO_MSDML(1, FusedBatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO_MSDML(1, FusedMeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO_MSDML(1, FusedGemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO_MSDML(1, FusedMatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO_MSDML(1, FusedAdd, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO_MSDML(1, FusedSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, + {REG_INFO_MSDML(1, FusedConv, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO_MSDML(1, FusedConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO_MSDML(1, FusedInstanceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO_MSDML(1, FusedBatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO_MSDML(1, FusedMeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO_MSDML(1, FusedGemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO_MSDML(1, FusedMatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO_MSDML(1, FusedAdd, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO_MSDML(1, FusedSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, - {REG_INFO( 10, IsInf, typeNameListTwo, supportedTypeListIsInf, DmlGraphSupport::Supported)}, - {REG_INFO( 10, Mod, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)}, - {REG_INFO( 11, BitShift, typeNameListDefault, supportedTypeListInt8to32, DmlGraphSupport::Supported)}, - {REG_INFO( 11, Round, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 10, ReverseSequence, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, - {REG_INFO( 11, CumSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, - {REG_INFO( 11, Range, typeNameListDefault, supportedTypeListRange, DmlGraphSupport::Supported, requiredConstantCpuInputs(0,1,2))}, + {REG_INFO( 10, IsInf, typeNameListTwo, supportedTypeListIsInf, DmlGraphSupport::Supported)}, + {REG_INFO( 10, Mod, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)}, + {REG_INFO( 13, Mod, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)}, + {REG_INFO( 11, BitShift, typeNameListDefault, supportedTypeListUInt8to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 11, Round, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 10, ReverseSequence, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, + {REG_INFO( 11, CumSum, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, + {REG_INFO( 11, Range, typeNameListDefault, supportedTypeListRange, DmlGraphSupport::Supported, requiredConstantCpuInputs(0,1,2))}, - {REG_INFO( 9, MaxUnpool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp, requiredConstantCpuInputs(2))}, - {REG_INFO( 11, MaxUnpool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp, requiredConstantCpuInputs(2))}, // 11 is identical to 9. - - {REG_INFO( 10, QLinearConv, typeNameListFour, supportedTypeListQLinearConv, DmlGraphSupport::NotSupported)}, - {REG_INFO( 10, QLinearMatMul, typeNameListThree, supportedTypeListQLinearMatMul, DmlGraphSupport::NotSupported)}, - {REG_INFO( 10, MatMulInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::NotSupported)}, - {REG_INFO( 10, ConvInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::NotSupported)}, + {REG_INFO( 9, MaxUnpool, typeNameListTwo, supportedTypeListMaxUnpool, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(2))}, + {REG_INFO( 11, MaxUnpool, typeNameListTwo, supportedTypeListMaxUnpool, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(2))}, // 11 is identical to 9. + + {REG_INFO_MS( 1, QLinearAdd, typeNameListDefault, supportedTypeListInteger8, DmlGraphSupport::Supported)}, + {REG_INFO( 10, QLinearConv, typeNameListFour, supportedTypeListQLinearConv, DmlGraphSupport::Supported)}, + {REG_INFO( 10, QLinearMatMul, typeNameListThree, supportedTypeListQLinearMatMul, DmlGraphSupport::Supported)}, + {REG_INFO( 10, MatMulInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)}, + {REG_INFO( 10, ConvInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)}, + {REG_INFO( 11, DynamicQuantizeLinear, typeNameListTwo, supportedTypeListDynamicQuantizeLinear, DmlGraphSupport::Supported)}, }; template @@ -607,6 +697,7 @@ void RegisterDmlOperators(IMLOperatorRegistry* registry) // The graph must be configured with operators from only the legacy DML API, or only the new DML API bool kernelSupportsGraph = !bool(information.dmlGraphSupport & DmlGraphSupport::NotSupported); bool prefer64BitTensorsDirectly = bool(information.dmlGraphSupport & DmlGraphSupport::Prefer64BitTensorsDirectly); + bool support64BitTensorsViaEmulation = bool(information.dmlGraphSupport & DmlGraphSupport::Support64BitTensorsViaEmulation); bool supportedWith64BitTensorsVia32BitStrides = bool(information.dmlGraphSupport & DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides); bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = bool(information.dmlGraphSupport & DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp); @@ -692,6 +783,7 @@ void RegisterDmlOperators(IMLOperatorRegistry* registry) supportedWith64BitTensorsVia32BitStrides, supportedWith64BitTensorsVia32BitStridesFromAnyEp, prefer64BitTensorsDirectly, + support64BitTensorsViaEmulation, information.requiredConstantCpuInputs.first.data(), static_cast(information.requiredConstantCpuInputs.second) )); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp index ba91d91953..2349f1c9bd 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp @@ -151,15 +151,20 @@ namespace Dml OperatorInfo{ "InstanceNormalization", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_InstanceNormalization }, OperatorInfo{ "MeanVarianceNormalization", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_MeanVarianceNormalization }, OperatorInfo{ "MeanVarianceNormalization", onnxruntime::kOnnxDomain, OnnxOperatorSet9::sc_sinceVer_MeanVarianceNormalization }, + OperatorInfo{ "MeanVarianceNormalization", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_MeanVarianceNormalization }, OperatorInfo{ "Gemm", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_Gemm }, OperatorInfo{ "Gemm", onnxruntime::kOnnxDomain, OnnxOperatorSet9::sc_sinceVer_Gemm }, OperatorInfo{ "Gemm", onnxruntime::kOnnxDomain, OnnxOperatorSet11::sc_sinceVer_Gemm }, + OperatorInfo{ "Gemm", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_Gemm }, OperatorInfo{ "MatMul", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_MatMul }, OperatorInfo{ "MatMul", onnxruntime::kOnnxDomain, OnnxOperatorSet9::sc_sinceVer_MatMul }, + OperatorInfo{ "MatMul", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_MatMul }, // The filter for activation functions maps to what DML's fused op internally fuses at the shader level. OperatorInfo{ "Add", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_Add, {"Relu", "LeakyRelu"} }, + OperatorInfo{ "Add", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_Add, {"Relu", "LeakyRelu"} }, OperatorInfo{ "Sum", onnxruntime::kOnnxDomain, OnnxOperatorSet8::sc_sinceVer_Sum, {"Relu", "LeakyRelu"}, 2 }, + OperatorInfo{ "Sum", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_Sum, {"Relu", "LeakyRelu"}, 2 }, }; // Not all activations can be fused - only simple elementwise activations (i.e. activation functions which @@ -167,10 +172,13 @@ namespace Dml static const OperatorInfo c_activationOps[] = { OperatorInfo{ "Sigmoid", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_Sigmoid }, + OperatorInfo{ "Sigmoid", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_Sigmoid }, OperatorInfo{ "HardSigmoid", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_HardSigmoid }, OperatorInfo{ "Tanh", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_Tanh }, + OperatorInfo{ "Tanh", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_Tanh }, OperatorInfo{ "ScaledTanh", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_ScaledTanh }, OperatorInfo{ "Relu", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_Relu }, + OperatorInfo{ "Relu", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_Relu }, OperatorInfo{ "LeakyRelu", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_LeakyRelu }, OperatorInfo{ "PRelu", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_PRelu }, OperatorInfo{ "PRelu", onnxruntime::kOnnxDomain, OnnxOperatorSet9::sc_sinceVer_PRelu }, diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Common.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Common.h index ee423388c2..0e4ec35da2 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Common.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Common.h @@ -4,27 +4,47 @@ #pragma once #define ML_CHECK_VALID_ARGUMENT(x, ...)\ - {\ - if ((x) == false) {\ - ORT_THROW_HR(E_INVALIDARG);\ - }\ - } + {\ + if ((x) == false)\ + {\ + ORT_THROW_HR(E_INVALIDARG);\ + }\ + } #define ML_INVALID_ARGUMENT(msg)\ - ORT_THROW_HR(E_INVALIDARG);\ + ORT_THROW_HR(E_INVALIDARG);\ #define ML_CHECK_HRESULT(hr, ...)\ - {\ - if (FAILED(hr)) {\ - ORT_THROW_HR(E_INVALIDARG);\ - }\ - } + {\ + if (FAILED(hr))\ + {\ + ORT_THROW_HR(E_INVALIDARG);\ + }\ + } namespace OperatorHelper { - template T clamp_cast(I input) + // Clamp the input value to the maximum range of the output, before casting to the output type. + // + // e.g. int32(300) would yield int8(255) rather than int8(44). + // float32(-42) would yield uint32(0) rather than a huge positive number. + template OutputType clamp_cast(InputType input) { - return static_cast(std::clamp(input, std::numeric_limits::lowest(), std::numeric_limits::max())); + // Determine the larger type to decide which numeric limits to clamp to. + using InputLimits = std::numeric_limits; + using OutputLimits = std::numeric_limits; + constexpr int inputMaxDigits = std::max(InputLimits::max_exponent, InputLimits::digits); + constexpr int outputMaxDigits = std::max(OutputLimits::max_exponent, OutputLimits::digits); + constexpr bool isEitherTypeUnsigned = std::is_unsigned_v || std::is_unsigned_v; + constexpr bool isOutputTypeLarger = outputMaxDigits > inputMaxDigits; + + InputType lowestValue = isEitherTypeUnsigned ? static_cast(0) : + isOutputTypeLarger ? InputLimits::lowest() : + static_cast(OutputLimits::lowest()); + InputType highestValue = isOutputTypeLarger ? InputLimits::max() : + static_cast(OutputLimits::max()); + + return static_cast(std::clamp(input, lowestValue, highestValue)); } enum TensorAxis { N, C, H, W, DoNotCoerce = INT_MAX, LeftAligned = INT_MAX, RightAligned = INT_MIN, NoPlacementAdjustment = 0 }; enum BroadcastMode { NoBroadcast, UnidirectionalBroadcast, MultidirectionalBroadcast }; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h index 4c8730d063..dd6032f81e 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h @@ -114,6 +114,7 @@ IMLOperatorRegistryPrivate : public IUnknown bool supportedWith64BitTensorsVia32BitStrides = false, bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = false, bool prefer64BitTensorsDirectly = false, + bool support64BitTensorsViaEmulation = false, _In_reads_(constantCpuInputCount) const uint32_t* constantCpuInputs = nullptr, uint32_t constantCpuInputCount = 0 ) const noexcept PURE; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp index 72c956356a..f3b9ddceb2 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp @@ -32,87 +32,6 @@ namespace OperatorHelper } } - void ReadCpuLocalTensorIntoInt32( - const MLOperatorTensor& tensor, - std::vector& result - ) - { - result.clear(); - ML_CHECK_VALID_ARGUMENT(tensor.IsCpuData(), "Tensor must be CPU Tensor."); - - const std::vector& tensorDimensions = tensor.GetShape(); - const uint32_t elementCount = ComputeElementCountFromDimensions(tensorDimensions); - - switch (tensor.GetTensorDataType()) - { - case MLOperatorTensorDataType::Int32: - { - const int32_t* data = tensor.GetData(); - result.assign(data, data + elementCount); - } - break; - - case MLOperatorTensorDataType::Int64: - { - const int64_t* data = tensor.GetData(); - result.reserve(elementCount); - - // Use clamped cast rather than static_cast/narrow_cast, - // because it's not uncommon for a model to specify a - // 64-bit INTMAX constant as a sentinel value to mean - // the largest possible value (even though the actual - // dimension values come nowhere close to that, far - // less than 32-bit INTMAX). - for (auto d : gsl::make_span(data, data + elementCount)) - { - result.push_back(clamp_cast(d)); - } - } - break; - - default: - ML_INVALID_ARGUMENT("Expecting CPU local tensor of type int32 or int64."); - break; - } - } - - void ReadCpuLocalTensorIntoFloat32( - const MLOperatorTensor& tensor, - std::vector& result - ) - { - result.clear(); - ML_CHECK_VALID_ARGUMENT(tensor.IsCpuData(), "Tensor must be CPU Tensor."); - - const std::vector& tensorDimensions = tensor.GetShape(); - const uint32_t elementCount = ComputeElementCountFromDimensions(tensorDimensions); - - switch (tensor.GetTensorDataType()) - { - case MLOperatorTensorDataType::Float: - { - const float* data = tensor.GetData(); - result.assign(data, data + elementCount); - } - break; - - default: - ML_INVALID_ARGUMENT("Expecting CPU local tensor of type float32."); - break; - } - } - - void DowncastDimensions(gsl::span inputDimensions, std::vector& outputDimensions) - { - outputDimensions.reserve(inputDimensions.size()); - outputDimensions.clear(); - - for (int64_t dim : inputDimensions) - { - outputDimensions.push_back(gsl::narrow_cast(std::clamp(dim, INT32_MIN, INT32_MAX))); - } - } - float CastFloat16ToFloat32(uint16_t input) { // Promote float16m10e5s1 to float32m23e8s1. @@ -201,6 +120,130 @@ namespace OperatorHelper } #pragma warning(pop) + void ReadCpuLocalTensorIntoInt32( + const MLOperatorTensor& tensor, + std::vector& result + ) + { + result.clear(); + ML_CHECK_VALID_ARGUMENT(tensor.IsCpuData(), "Tensor must be CPU Tensor."); + + const std::vector& tensorDimensions = tensor.GetShape(); + const uint32_t elementCount = ComputeElementCountFromDimensions(tensorDimensions); + + switch (tensor.GetTensorDataType()) + { + case MLOperatorTensorDataType::Int32: + { + const int32_t* data = tensor.GetData(); + result.assign(data, data + elementCount); + } + break; + + case MLOperatorTensorDataType::Int64: + { + const int64_t* data = tensor.GetData(); + result.reserve(elementCount); + + // Use clamped cast rather than static_cast/narrow_cast, + // because it's not uncommon for a model to specify a + // 64-bit INTMAX constant as a sentinel value to mean + // the largest possible value (even though the actual + // dimension values come nowhere close to that, far + // less than 32-bit INTMAX). + for (auto d : gsl::make_span(data, data + elementCount)) + { + result.push_back(clamp_cast(d)); + } + } + break; + + default: + ML_INVALID_ARGUMENT("Expecting CPU local tensor of type int32 or int64."); + break; + } + } + + void ReadCpuLocalTensorIntoFloat32( + const MLOperatorTensor& tensor, + std::vector& result + ) + { + result.clear(); + ML_CHECK_VALID_ARGUMENT(tensor.IsCpuData(), "Tensor must be CPU Tensor."); + + const std::vector& tensorDimensions = tensor.GetShape(); + const uint32_t elementCount = ComputeElementCountFromDimensions(tensorDimensions); + result.resize(elementCount); + + switch (tensor.GetTensorDataType()) + { + case MLOperatorTensorDataType::Float16: + { + const uint16_t* data = reinterpret_cast(tensor.GetByteData()); + std::transform(data, data + elementCount, result.begin(), CastFloat16ToFloat32); + } + break; + + case MLOperatorTensorDataType::/*Float32*/Float: + { + const float* data = tensor.GetData(); + result.assign(data, data + elementCount); + } + break; + + case MLOperatorTensorDataType::/*Float64*/Double: + { + const double* data = tensor.GetData(); + std::transform(data, data + elementCount, result.begin(), [](auto v) {return static_cast(v); }); + } + break; + + case MLOperatorTensorDataType::Int32: + { + const int32_t* data = tensor.GetData(); + std::transform(data, data + elementCount, result.begin(), [](auto v) {return static_cast(v); }); + } + break; + + case MLOperatorTensorDataType::UInt32: + { + const uint32_t* data = tensor.GetData(); + std::transform(data, data + elementCount, result.begin(), [](auto v) {return static_cast(v); }); + } + break; + + case MLOperatorTensorDataType::Int64: + { + const int64_t* data = tensor.GetData(); + std::transform(data, data + elementCount, result.begin(), [](auto v) {return static_cast(v); }); + } + break; + + case MLOperatorTensorDataType::UInt64: + { + const uint64_t* data = tensor.GetData(); + std::transform(data, data + elementCount, result.begin(), [](auto v) {return static_cast(v); }); + } + break; + + default: + ML_INVALID_ARGUMENT("Expecting CPU local tensor of type float32."); + break; + } + } + + void DowncastDimensions(gsl::span inputDimensions, std::vector& outputDimensions) + { + outputDimensions.reserve(inputDimensions.size()); + outputDimensions.clear(); + + for (int64_t dim : inputDimensions) + { + outputDimensions.push_back(gsl::narrow_cast(std::clamp(dim, INT32_MIN, INT32_MAX))); + } + } + int64_t IsFloatDataType(MLOperatorTensorDataType tensorDataType) { switch (tensorDataType) @@ -1089,7 +1132,7 @@ namespace OperatorHelper ML_CHECK_VALID_ARGUMENT(inputCount + 1 == m_components.size(), "Mismatch between input tensor count and string equation component count."); ML_CHECK_VALID_ARGUMENT(outputCount == 1, "EinSum expects exactly 1 output tensor."); - std::vector labelSizes(m_labelIndices.size(), static_cast(INT_MIN)); + std::vector labelSizes(m_labelIndices.size(), UINT_MAX); // Read every input tensor, comparing labels to ensure consistent sizes from the equation parsed earlier. for (uint32_t i = 0; i < inputCount; ++i) @@ -1110,7 +1153,7 @@ namespace OperatorHelper uint32_t labelIndex = labelIndices[j]; assert(labelIndex < labelSizes.size()); - if (labelSizes[labelIndex] == INT_MIN) + if (labelSizes[labelIndex] == UINT_MAX) { labelSizes[labelIndex] = dimensionSize; } diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index de927efbfe..dbaa14d2d5 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -6,7 +6,7 @@ #include "Common.h" #include "Attributes.h" #include "core/common/common.h" -#include "..\DmlExecutionProvider\src\ErrorHandling.h" +#include "../DmlExecutionProvider/src/ErrorHandling.h" #include "MLOperatorAuthorHelper.h" namespace OperatorHelper { @@ -26,14 +26,14 @@ std::vector BroadcastTensorShape( #endif template void FindValueIndices(gsl::span values, T value, /*out*/ std::vector& indices) { - indices.clear(); - for (size_t i = 0, valuesCount = values.size(); i < valuesCount; ++i) { - // Work around compiler bug on x86 release by using data() rather than operator [] directly. - // cl.exe 19.20.27412.4 for x86 - if (values.data()[i] == value) { - indices.push_back(gsl::narrow_cast(i)); + indices.clear(); + for (size_t i = 0, valuesCount = values.size(); i < valuesCount; ++i) { + // Work around compiler bug on x86 release by using data() rather than operator [] directly. + // cl.exe 19.20.27412.4 for x86 + if (values.data()[i] == value) { + indices.push_back(gsl::narrow_cast(i)); + } } - } } #ifndef __clang__ #pragma optimize("", on) @@ -54,48 +54,49 @@ void HandleNegativeAxes(gsl::span onnxAxes, uint32_t dimCount); // output values = {2,3,5} template void RemoveValuesByIndex(gsl::span indices, bool keepOneValue, /*inout*/ std::vector& values) { - assert(std::is_sorted(indices.begin(), indices.end())); + assert(std::is_sorted(indices.begin(), indices.end())); - // Keep the last value at least, if all values would otherwise be removed. - if (keepOneValue && !indices.empty() && indices.size() == values.size()) { - indices = indices.first(indices.size() - 1); - } - - auto indicesIterator = indices.begin(); - auto indicesEnd = indices.end(); - size_t oldValuesCount = values.size(); - size_t newValuesCount = 0; - size_t nextIndex = (indicesIterator == indicesEnd) ? SIZE_MAX : *(indicesIterator++); - - // For every value, either skip the entry, or copy it to the output. - for (size_t i = 0; i < oldValuesCount; ++i) { - if (i == nextIndex) // Skip and remove entry. - { - nextIndex = (indicesIterator == indicesEnd) ? SIZE_MAX : *(indicesIterator++); - } else // Keep and copy entry. - { - values[newValuesCount++] = values[i]; + // Keep the last value at least, if all values would otherwise be removed. + if (keepOneValue && !indices.empty() && indices.size() == values.size()) { + indices = indices.first(indices.size() - 1); } - } - values.resize(newValuesCount); + + auto indicesIterator = indices.begin(); + auto indicesEnd = indices.end(); + size_t oldValuesCount = values.size(); + size_t newValuesCount = 0; + size_t nextIndex = (indicesIterator == indicesEnd) ? SIZE_MAX : *(indicesIterator++); + + // For every value, either skip the entry, or copy it to the output. + for (size_t i = 0; i < oldValuesCount; ++i) { + if (i == nextIndex) // Skip and remove entry. + { + nextIndex = (indicesIterator == indicesEnd) ? SIZE_MAX : *(indicesIterator++); + } + else // Keep and copy entry. + { + values[newValuesCount++] = values[i]; + } + } + values.resize(newValuesCount); } template void FillWithLeadingValues(/*inout*/ std::vector& values, uint32_t minimumElementCount, T fillValue) { - // e.g. - // input = [6,7] - // elementCount = 4 - // fillValue = 1 - // output = [1,1,6,7] + // e.g. + // input = [6,7] + // elementCount = 4 + // fillValue = 1 + // output = [1,1,6,7] - const size_t oldElementCount = values.size(); - const size_t newElementCount = std::max(size_t(minimumElementCount), oldElementCount); - const size_t fillCount = newElementCount - oldElementCount; + const size_t oldElementCount = values.size(); + const size_t newElementCount = std::max(size_t(minimumElementCount), oldElementCount); + const size_t fillCount = newElementCount - oldElementCount; - values.resize(newElementCount); - std::copy_backward(values.begin(), values.begin() + oldElementCount, values.end()); - std::fill_n(values.data(), fillCount, fillValue); + values.resize(newElementCount); + std::copy_backward(values.begin(), values.begin() + oldElementCount, values.end()); + std::fill_n(values.data(), fillCount, fillValue); } int64_t CastToInt64(MLOperatorTensorDataType tensorDataType, const void* p); @@ -108,76 +109,76 @@ void ReadCpuLocalTensorIntoInt32(const MLOperatorTensor& tensor, std::vector& result); class EdgeShapes { - public: - EdgeShapes() = default; - EdgeShapes(const std::vector& dim) { m_shapes = dim; } - EdgeShapes(const std::initializer_list& dim) { m_shapes.assign(dim.begin(), dim.end()); } - EdgeShapes(const gsl::span dim) { m_shapes.assign(dim.begin(), dim.end()); } +public: + EdgeShapes() = default; + EdgeShapes(const std::vector& dim) { m_shapes = dim; } + EdgeShapes(const std::initializer_list& dim) { m_shapes.assign(dim.begin(), dim.end()); } + EdgeShapes(const gsl::span dim) { m_shapes.assign(dim.begin(), dim.end()); } - bool IsTensor() { return true; } - bool IsUnused() { return m_shapes.empty(); } + bool IsTensor() { return true; } + bool IsUnused() { return m_shapes.empty(); } - std::vector& GetShape() { return m_shapes; } + std::vector& GetShape() { return m_shapes; } - private: - std::vector m_shapes; +private: + std::vector m_shapes; }; struct KernelArgs { - // Initialize arrays up to NcdhwSpatialDimensionCount to avoid vector allocations, - // but it's important to use .spatialDimensionCount when accessing them because - // values beyond that may be bogus. - uint32_t strides[NcdhwSpatialDimensionCount]; - uint32_t dilations[NcdhwSpatialDimensionCount]; - uint32_t windowSize[NcdhwSpatialDimensionCount]; // The filter kernel dimensions. - uint32_t startPadding[NcdhwSpatialDimensionCount]; - uint32_t endPadding[NcdhwSpatialDimensionCount]; - uint32_t outputPadding[NcdhwSpatialDimensionCount]; + // Initialize arrays up to NcdhwSpatialDimensionCount to avoid vector allocations, + // but it's important to use .spatialDimensionCount when accessing them because + // values beyond that may be bogus. + uint32_t strides[NcdhwSpatialDimensionCount]; + uint32_t dilations[NcdhwSpatialDimensionCount]; + uint32_t windowSize[NcdhwSpatialDimensionCount]; // The filter kernel dimensions. + uint32_t startPadding[NcdhwSpatialDimensionCount]; + uint32_t endPadding[NcdhwSpatialDimensionCount]; + uint32_t outputPadding[NcdhwSpatialDimensionCount]; - // This is true if padding must be automatically computed based on input sizes. - // ResolveAutoPadding must happen during Compute rather than initialization. - // This is temporary until kernel initialization routine once Lotus can provide - // sizes at operator initialization. - bool autoPad = false; - bool autoPadSameUpper = false; - bool useCeilingOutputShape = false; - uint32_t spatialDimensionCount = 0; + // This is true if padding must be automatically computed based on input sizes. + // ResolveAutoPadding must happen during Compute rather than initialization. + // This is temporary until kernel initialization routine once Lotus can provide + // sizes at operator initialization. + bool autoPad = false; + bool autoPadSameUpper = false; + bool useCeilingOutputShape = false; + uint32_t spatialDimensionCount = 0; - KernelArgs(uint32_t spatialDimensionCount) : spatialDimensionCount(spatialDimensionCount) - { - ML_CHECK_VALID_ARGUMENT(spatialDimensionCount <= NcdhwSpatialDimensionCount); - } + KernelArgs(uint32_t spatialDimensionCount) : spatialDimensionCount(spatialDimensionCount) + { + ML_CHECK_VALID_ARGUMENT(spatialDimensionCount <= NcdhwSpatialDimensionCount); + } - void FillWithLeadingValues(gsl::span input, gsl::span output, uint32_t fillCount, uint32_t value) { - // e.g. - // input = [5,6,7,8] - // fillcount = 2 - // value = 1 - // output = [1,1,5,6,7,8] + void FillWithLeadingValues(gsl::span input, gsl::span output, uint32_t fillCount, uint32_t value) { + // e.g. + // input = [5,6,7,8] + // fillcount = 2 + // value = 1 + // output = [1,1,5,6,7,8] - const size_t inputCount = input.size(); - const size_t outputCount = output.size(); - const size_t copyCount = std::min(outputCount - fillCount, inputCount); + const size_t inputCount = input.size(); + const size_t outputCount = output.size(); + const size_t copyCount = std::min(outputCount - fillCount, inputCount); - std::fill_n(output.data(), fillCount, value); - std::copy_n(input.data(), copyCount, output.data() + fillCount); - } + std::fill_n(output.data(), fillCount, value); + std::copy_n(input.data(), copyCount, output.data() + fillCount); + } - // Create a copy of an existing kernel args with a minimum dimension count, - // filling the leading attribute values with 1's or 0's respectively. - KernelArgs(KernelArgs const& kernelArgs, uint32_t minimumDimensionCount) : autoPad(kernelArgs.autoPad), - autoPadSameUpper(kernelArgs.autoPadSameUpper), - spatialDimensionCount(std::max(kernelArgs.spatialDimensionCount, minimumDimensionCount)) { - ML_CHECK_VALID_ARGUMENT(spatialDimensionCount <= NcdhwSpatialDimensionCount); + // Create a copy of an existing kernel args with a minimum dimension count, + // filling the leading attribute values with 1's or 0's respectively. + KernelArgs(KernelArgs const& kernelArgs, uint32_t minimumDimensionCount) : autoPad(kernelArgs.autoPad), + autoPadSameUpper(kernelArgs.autoPadSameUpper), + spatialDimensionCount(std::max(kernelArgs.spatialDimensionCount, minimumDimensionCount)) { + ML_CHECK_VALID_ARGUMENT(spatialDimensionCount <= NcdhwSpatialDimensionCount); - uint32_t fillCount = (minimumDimensionCount > kernelArgs.spatialDimensionCount) ? minimumDimensionCount - kernelArgs.spatialDimensionCount : 0; - FillWithLeadingValues(kernelArgs.strides, this->strides, fillCount, 1u); - FillWithLeadingValues(kernelArgs.dilations, this->dilations, fillCount, 1u); - FillWithLeadingValues(kernelArgs.windowSize, this->windowSize, fillCount, 1u); - FillWithLeadingValues(kernelArgs.startPadding, this->startPadding, fillCount, 0u); - FillWithLeadingValues(kernelArgs.endPadding, this->endPadding, fillCount, 0u); - FillWithLeadingValues(kernelArgs.outputPadding, this->outputPadding, fillCount, 0u); - } + uint32_t fillCount = (minimumDimensionCount > kernelArgs.spatialDimensionCount) ? minimumDimensionCount - kernelArgs.spatialDimensionCount : 0; + FillWithLeadingValues(kernelArgs.strides, this->strides, fillCount, 1u); + FillWithLeadingValues(kernelArgs.dilations, this->dilations, fillCount, 1u); + FillWithLeadingValues(kernelArgs.windowSize, this->windowSize, fillCount, 1u); + FillWithLeadingValues(kernelArgs.startPadding, this->startPadding, fillCount, 0u); + FillWithLeadingValues(kernelArgs.endPadding, this->endPadding, fillCount, 0u); + FillWithLeadingValues(kernelArgs.outputPadding, this->outputPadding, fillCount, 0u); + } }; std::vector InitializeKernelOutputDimensions( @@ -200,130 +201,132 @@ void ResolveAutoPadding( gsl::span inputDimensions); void MatMulShapeMapping( - std::vector& inputShape0, - std::vector& inputShape1, - std::vector& outputShape); + std::vector& inputShape0, + std::vector& inputShape1, + std::vector& outputShape); class GetOutputShapeAsInputShapeHelper { - public: - // Info_t is used to obtain attributes which will be used for calculating the output shape later. - // Shape_t is used to obtain input shape which will be used for adjusting attribute value. - // Default to first input tensor. - template - GetOutputShapeAsInputShapeHelper(const Info_t& info, const Shape_t& shape){ - ORT_UNUSED_PARAMETER(info); - ORT_UNUSED_PARAMETER(shape); - }; +public: + // Info_t is used to obtain attributes which will be used for calculating the output shape later. + // Shape_t is used to obtain input shape which will be used for adjusting attribute value. + // Default to first input tensor. + template + GetOutputShapeAsInputShapeHelper(const Info_t& info, const Shape_t& shape) { + ORT_UNUSED_PARAMETER(info); + ORT_UNUSED_PARAMETER(shape); + }; - // Info_t is used to obtain attributes which will be used for calculating the output shape later. - // Shape_t is used to obtain input shape which will be used for adjusting attribute value. - // Pass specific tensor index. - template - GetOutputShapeAsInputShapeHelper(const Info_t& info, const Shape_t& shape, uint32_t inputTensorIndex) - : m_inputTensorIndex(inputTensorIndex) - { - ORT_UNUSED_PARAMETER(info); - ORT_UNUSED_PARAMETER(shape); - }; + // Info_t is used to obtain attributes which will be used for calculating the output shape later. + // Shape_t is used to obtain input shape which will be used for adjusting attribute value. + // Pass specific tensor index. + template + GetOutputShapeAsInputShapeHelper(const Info_t& info, const Shape_t& shape, uint32_t inputTensorIndex) + : m_inputTensorIndex(inputTensorIndex) + { + ORT_UNUSED_PARAMETER(info); + ORT_UNUSED_PARAMETER(shape); + }; - std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; - uint32_t m_inputTensorIndex = 0; + uint32_t m_inputTensorIndex = 0; }; template class GetOutputShapeAsSpecificInputShapeHelper : public GetOutputShapeAsInputShapeHelper { public: - // Info_t is used to obtain attributes which will be used for calculating the output shape later. - // Shape_t is used to obtain input shape which will be used for adjusting attribute value. - template - GetOutputShapeAsSpecificInputShapeHelper(const Info_t& info, const Shape_t& shape) - : GetOutputShapeAsInputShapeHelper(info, shape, InputTensorIndex) - {} + // Info_t is used to obtain attributes which will be used for calculating the output shape later. + // Shape_t is used to obtain input shape which will be used for adjusting attribute value. + template + GetOutputShapeAsSpecificInputShapeHelper(const Info_t& info, const Shape_t& shape) + : GetOutputShapeAsInputShapeHelper(info, shape, InputTensorIndex) + {} }; class GetBroadcastedOutputShapeHelper { - public: - // Info_t is used to obtain attributes which will be used for calculating the output shape later. - // Shape_t is used to obtain input shape which will be used for adjusting attribute value. - template - GetBroadcastedOutputShapeHelper(const Info_t& info, const Shape_t& shape){}; +public: + // Info_t is used to obtain attributes which will be used for calculating the output shape later. + // Shape_t is used to obtain input shape which will be used for adjusting attribute value. + template + GetBroadcastedOutputShapeHelper(const Info_t& info, const Shape_t& shape) {}; - std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; }; class RandomUniformHelperBase { - public: - // Info_t is used to obtain attributes which will be used for calculating the output shape later. - template - RandomUniformHelperBase(const Info_t& info) { - m_high = info.GetOptionalAttribute(AttrName::High, 1.0f); - m_low = info.GetOptionalAttribute(AttrName::Low, 0.0f); +public: + // Info_t is used to obtain attributes which will be used for calculating the output shape later. + template + RandomUniformHelperBase(const Info_t& info) { + m_high = info.GetOptionalAttribute(AttrName::High, 1.0f); + m_low = info.GetOptionalAttribute(AttrName::Low, 0.0f); - if (info.HasAttribute(AttrName::Seed, MLOperatorAttributeType::Float)) { - m_seed = info.GetAttribute(AttrName::Seed); - } else { - m_seed = static_cast(std::chrono::high_resolution_clock::now().time_since_epoch().count()); + if (info.HasAttribute(AttrName::Seed, MLOperatorAttributeType::Float)) { + m_seed = info.GetAttribute(AttrName::Seed); + } + else { + m_seed = static_cast(std::chrono::high_resolution_clock::now().time_since_epoch().count()); + } } - } - protected: - float m_high; - float m_low; - float m_seed; +protected: + float m_high; + float m_low; + float m_seed; }; class RandomUniformHelper : public RandomUniformHelperBase { - public: - template - RandomUniformHelper(const Info_t& info, const Shape_t& shape) : RandomUniformHelperBase(info) { - auto shapeAttribute = info.GetOptionalAttributeVectorInt32(AttrName::Shape); - ML_CHECK_VALID_ARGUMENT(!shapeAttribute.empty(), "Attribute shape is missing."); - m_tensorShape.assign(shapeAttribute.begin(), shapeAttribute.end()); - } +public: + template + RandomUniformHelper(const Info_t& info, const Shape_t& shape) : RandomUniformHelperBase(info) { + auto shapeAttribute = info.GetOptionalAttributeVectorInt32(AttrName::Shape); + ML_CHECK_VALID_ARGUMENT(!shapeAttribute.empty(), "Attribute shape is missing."); + m_tensorShape.assign(shapeAttribute.begin(), shapeAttribute.end()); + } - std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; - private: - // Returns an empty vector if the optional attribute is missing. - std::vector m_tensorShape; +private: + // Returns an empty vector if the optional attribute is missing. + std::vector m_tensorShape; }; class RandomNormalHelperBase { - public: - // Info_t is used to obtain attributes which will be used for calculating the output shape later. - template - RandomNormalHelperBase(const Info_t& info) { - m_mean = info.GetOptionalAttribute(AttrName::Mean, 0.0f); - m_scale = info.GetOptionalAttribute(AttrName::Scale, 1.0f); +public: + // Info_t is used to obtain attributes which will be used for calculating the output shape later. + template + RandomNormalHelperBase(const Info_t& info) { + m_mean = info.GetOptionalAttribute(AttrName::Mean, 0.0f); + m_scale = info.GetOptionalAttribute(AttrName::Scale, 1.0f); - if (info.HasAttribute(AttrName::Seed, MLOperatorAttributeType::Float)) { - m_seed = info.GetAttribute(AttrName::Seed); - } else { - m_seed = static_cast(std::chrono::high_resolution_clock::now().time_since_epoch().count()); + if (info.HasAttribute(AttrName::Seed, MLOperatorAttributeType::Float)) { + m_seed = info.GetAttribute(AttrName::Seed); + } + else { + m_seed = static_cast(std::chrono::high_resolution_clock::now().time_since_epoch().count()); + } } - } - protected: - float m_mean; - float m_scale; - float m_seed; +protected: + float m_mean; + float m_scale; + float m_seed; }; class RandomNormalHelper : public RandomNormalHelperBase { - public: - template - RandomNormalHelper(const Info_t& info, const Shape_t& shape) : RandomNormalHelperBase(info) { - auto shapeAttribute = info.GetOptionalAttributeVectorInt32(AttrName::Shape); - ML_CHECK_VALID_ARGUMENT(!shapeAttribute.empty(), "Attribute shape is missing."); - m_tensorShape.assign(shapeAttribute.begin(), shapeAttribute.end()); - } +public: + template + RandomNormalHelper(const Info_t& info, const Shape_t& shape) : RandomNormalHelperBase(info) { + auto shapeAttribute = info.GetOptionalAttributeVectorInt32(AttrName::Shape); + ML_CHECK_VALID_ARGUMENT(!shapeAttribute.empty(), "Attribute shape is missing."); + m_tensorShape.assign(shapeAttribute.begin(), shapeAttribute.end()); + } - std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; - private: - // Returns an empty vector if the optional attribute is missing. - std::vector m_tensorShape; +private: + // Returns an empty vector if the optional attribute is missing. + std::vector m_tensorShape; }; class ConvolutionHelperBase @@ -355,8 +358,8 @@ public: void ResolvingPadding(gsl::span inputDimensions); const std::vector& GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const { - ORT_UNUSED_PARAMETER(shapeInfo); - return m_outputShapes; + ORT_UNUSED_PARAMETER(shapeInfo); + return m_outputShapes; } template @@ -460,14 +463,14 @@ public: m_kernel.endPadding[i] = paddings - m_kernel.startPadding[i]; } } - } + } - protected: - uint32_t m_groupCount; - uint32_t m_inputTensorIndex; - uint32_t m_filterTensorIndex; - KernelArgs m_kernel; - std::vector m_outputShapes; +protected: + uint32_t m_groupCount; + uint32_t m_inputTensorIndex; + uint32_t m_filterTensorIndex; + KernelArgs m_kernel; + std::vector m_outputShapes; }; class ConvHelper : public ConvolutionHelperBase @@ -514,57 +517,59 @@ public: m_beta = info.template GetOptionalAttribute(AttrName::Beta, 0.0f); } - std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; - enum InputTensors { IN_A, - IN_B, - IN_C }; + enum InputTensors { + IN_A, + IN_B, + IN_C + }; - protected: - bool m_transA = false; - bool m_transB = false; - bool m_broadcast = false; - float m_alpha = 0.0f; - float m_beta = 0.0f; +protected: + bool m_transA = false; + bool m_transB = false; + bool m_broadcast = false; + float m_alpha = 0.0f; + float m_beta = 0.0f; }; class TransposeHelper { - public: - void Initialize( - const MLOperatorAttributes& operatorAttributes, - gsl::span inputDimensions); +public: + void Initialize( + const MLOperatorAttributes& operatorAttributes, + gsl::span inputDimensions); - // Info_t is used to obtain attributes which will be used for calculating the output shape later. - // Shape_t is used to obtain input shape which will be used for adjusting attribute value. - template - TransposeHelper(const Info_t& info, const Shape_t& shape) { - Initialize(info, shape.GetInputTensorShape(0)); - } + // Info_t is used to obtain attributes which will be used for calculating the output shape later. + // Shape_t is used to obtain input shape which will be used for adjusting attribute value. + template + TransposeHelper(const Info_t& info, const Shape_t& shape) { + Initialize(info, shape.GetInputTensorShape(0)); + } - std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; - protected: - std::vector m_permutations; +protected: + std::vector m_permutations; }; class SplitHelper { - public: - void Initialize( - const MLOperatorAttributes& operatorAttributes, - gsl::span inputDimensions); +public: + void Initialize( + const MLOperatorAttributes& operatorAttributes, + gsl::span inputDimensions); - // Info_t is used to obtain attributes which will be used for calculating the output shape later. - // Shape_t is used to obtain input shape which will be used for adjusting attribute value. - template - SplitHelper(const Info_t& info, const Shape_t& shape) { - Initialize(info, shape.GetInputTensorShape(0)); - } + // Info_t is used to obtain attributes which will be used for calculating the output shape later. + // Shape_t is used to obtain input shape which will be used for adjusting attribute value. + template + SplitHelper(const Info_t& info, const Shape_t& shape) { + Initialize(info, shape.GetInputTensorShape(0)); + } - std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; - protected: - int m_axis = 0; - std::vector m_split; +protected: + int m_axis = 0; + std::vector m_split; }; class SliceHelper @@ -604,18 +609,18 @@ public: std::vector axes; std::vector steps; - if (opsetVersion == 7) + if (opsetVersion >= 10) + { + // Read starts, ends, and axes from tensors. + ReadIndexTensors(operatorInfo, /*out*/ starts, /*out*/ ends, /*out*/ axes, /*out*/ steps); + } + else if (opsetVersion >= 7) { // Read starts, ends, and axes from attributes. starts = operatorInfo.GetOptionalAttributeVectorInt32(AttrName::Starts); ends = operatorInfo.GetOptionalAttributeVectorInt32(AttrName::Ends); axes = operatorInfo.GetOptionalAttributeVectorInt32(AttrName::Axes); } - else if (opsetVersion == 10 || opsetVersion == 11) - { - // Read starts, ends, and axes from tensors. - ReadIndexTensors(operatorInfo, /*out*/ starts, /*out*/ ends, /*out*/ axes, /*out*/ steps); - } const uint32_t inputDimensionCount = gsl::narrow_cast(inputDimensions.size()); HandleNegativeAxes(/*inout*/ axes, inputDimensionCount); @@ -683,13 +688,13 @@ public: Initialize(info, shape.GetInputTensorShape(0), opsetVersion); } - std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; - protected: - std::vector m_outputDimensions; - std::vector m_offsets; - std::vector m_sizes; - std::vector m_strides; +protected: + std::vector m_outputDimensions; + std::vector m_offsets; + std::vector m_sizes; + std::vector m_strides; }; class PaddingHelper @@ -697,29 +702,29 @@ class PaddingHelper public: void Initialize(const MLOperatorAttributes& operatorAttributes, gsl::span padding, uint32_t opsetVersion); - // Info_t is used to obtain attributes which will be used for calculating the output shape later. - // Shape_t is used to obtain input shape which will be used for adjusting attribute value. - template - PaddingHelper(const Info_t& info, const Shape_t& shape, uint32_t opsetVersion) { - std::vector padding; - if (opsetVersion >= 11) - { - MLOperatorTensor padsTensor = info.GetConstantInputTensor(1); - ReadCpuLocalTensorIntoInt32(padsTensor, /*out*/ padding); - } - else - { - padding = info.GetOptionalAttributeVectorInt32(AttrName::Pads); + // Info_t is used to obtain attributes which will be used for calculating the output shape later. + // Shape_t is used to obtain input shape which will be used for adjusting attribute value. + template + PaddingHelper(const Info_t& info, const Shape_t& shape, uint32_t opsetVersion) { + std::vector padding; + if (opsetVersion >= 11) + { + MLOperatorTensor padsTensor = info.GetConstantInputTensor(1); + ReadCpuLocalTensorIntoInt32(padsTensor, /*out*/ padding); + } + else + { + padding = info.GetOptionalAttributeVectorInt32(AttrName::Pads); + } + + Initialize(info, padding, opsetVersion); } - Initialize(info, padding, opsetVersion); - } + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; - std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; - - protected: - std::vector m_startPadding; - std::vector m_endPadding; +protected: + std::vector m_startPadding; + std::vector m_endPadding; }; template @@ -731,48 +736,49 @@ public: }; class ReduceHelperBase { - public: - // Info_t is used to obtain attributes which will be used for calculating the output shape later. - // Shape_t is used to obtain input shape which will be used for adjusting attribute value. - template - ReduceHelperBase(const Info_t& info, const Shape_t& shape, bool usingAxes) { - m_keepDims = info.template GetOptionalAttribute(AttrName::KeepDims, 1); - m_selectLastIndex = info.template GetOptionalAttribute(AttrName::SelectLastIndex, 0); - if (usingAxes) { - m_axes = info.GetOptionalAttributeVectorInt32(AttrName::Axes); - } else { - int axis = info.template GetOptionalAttribute(AttrName::Axis, 0); - m_axes.push_back(axis); +public: + // Info_t is used to obtain attributes which will be used for calculating the output shape later. + // Shape_t is used to obtain input shape which will be used for adjusting attribute value. + template + ReduceHelperBase(const Info_t& info, const Shape_t& shape, bool usingAxes) { + m_keepDims = info.template GetOptionalAttribute(AttrName::KeepDims, 1); + m_selectLastIndex = info.template GetOptionalAttribute(AttrName::SelectLastIndex, 0); + if (usingAxes) { + m_axes = info.GetOptionalAttributeVectorInt32(AttrName::Axes); + } + else { + int axis = info.template GetOptionalAttribute(AttrName::Axis, 0); + m_axes.push_back(axis); + } + std::vector inputShape = shape.GetInputTensorShape(0); + HandleNegativeAxes(/*inout*/ m_axes, gsl::narrow_cast(inputShape.size())); + AdjustAxesAndOutputShape(inputShape); } - std::vector inputShape = shape.GetInputTensorShape(0); - HandleNegativeAxes(/*inout*/ m_axes, gsl::narrow_cast(inputShape.size())); - AdjustAxesAndOutputShape(inputShape); - } - std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; - private: - void AdjustAxesAndOutputShape(const std::vector& inputShape); +private: + void AdjustAxesAndOutputShape(const std::vector& inputShape); - protected: - std::vector m_axes; - int m_keepDims = 0; - int m_selectLastIndex = 0; +protected: + std::vector m_axes; + int m_keepDims = 0; + int m_selectLastIndex = 0; }; class ArgMinArgMaxHelper : public ReduceHelperBase { - public: - // Info_t is used to obtain attributes which will be used for calculating the output shape later. - // Shape_t is used to obtain input shape which will be used for adjusting attribute value. - template - ArgMinArgMaxHelper(const Info_t& info, const Shape_t& shape) : ReduceHelperBase(info, shape, false) {} +public: + // Info_t is used to obtain attributes which will be used for calculating the output shape later. + // Shape_t is used to obtain input shape which will be used for adjusting attribute value. + template + ArgMinArgMaxHelper(const Info_t& info, const Shape_t& shape) : ReduceHelperBase(info, shape, false) {} }; class ReduceHelper : public ReduceHelperBase { - public: - // Info_t is used to obtain attributes which will be used for calculating the output shape later. - // Shape_t is used to obtain input shape which will be used for adjusting attribute value. - template - ReduceHelper(const Info_t& info, const Shape_t& shape) : ReduceHelperBase(info, shape, true) {} +public: + // Info_t is used to obtain attributes which will be used for calculating the output shape later. + // Shape_t is used to obtain input shape which will be used for adjusting attribute value. + template + ReduceHelper(const Info_t& info, const Shape_t& shape) : ReduceHelperBase(info, shape, true) {} }; class EinSumHelper @@ -842,19 +848,19 @@ protected: }; class MatMulHelperBase { - public: - // Info_t is used to obtain attributes which will be used for calculating the output shape later. - // Shape_t is used to obtain input shape which will be used for adjusting attribute value. - template - MatMulHelperBase(const Info_t& info, const Shape_t& shape, uint32_t aTensorIndex, uint32_t bTensorIndex) : - m_aTensorIndex(aTensorIndex), - m_bTensorIndex(bTensorIndex) - {} +public: + // Info_t is used to obtain attributes which will be used for calculating the output shape later. + // Shape_t is used to obtain input shape which will be used for adjusting attribute value. + template + MatMulHelperBase(const Info_t& info, const Shape_t& shape, uint32_t aTensorIndex, uint32_t bTensorIndex) : + m_aTensorIndex(aTensorIndex), + m_bTensorIndex(bTensorIndex) + {} - std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; - protected: - uint32_t m_aTensorIndex = 0; - uint32_t m_bTensorIndex = 1; + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; +protected: + uint32_t m_aTensorIndex = 0; + uint32_t m_bTensorIndex = 1; }; class MatMulHelper : public MatMulHelperBase @@ -873,230 +879,235 @@ public: class TopKHelper { - public: - // Info_t is used to obtain attributes which will be used for calculating the output shape later. - // Shape_t is used to obtain input shape which will be used for adjusting attribute value. - template - TopKHelper(const Info_t& info, const Shape_t& shape, uint32_t opsetVersion) { - int32_t k; - if (opsetVersion >= 10) { - MLOperatorTensor kTensor = info.GetConstantInputTensor(1); - k = gsl::narrow_cast(ReadScalarTensorCastToInt64(kTensor)); - } else { - k = info.template GetOptionalAttribute(AttrName::K, -1); +public: + // Info_t is used to obtain attributes which will be used for calculating the output shape later. + // Shape_t is used to obtain input shape which will be used for adjusting attribute value. + template + TopKHelper(const Info_t& info, const Shape_t& shape, uint32_t opsetVersion) { + int32_t k; + if (opsetVersion >= 10) { + MLOperatorTensor kTensor = info.GetConstantInputTensor(1); + k = gsl::narrow_cast(ReadScalarTensorCastToInt64(kTensor)); + } + else { + k = info.template GetOptionalAttribute(AttrName::K, -1); + } + ML_CHECK_VALID_ARGUMENT(k >= 0, "Attribute k is missing or negative."); + m_k = k; + + auto inputShape = shape.GetInputTensorShape(0); + int32_t axis = info.template GetOptionalAttribute(AttrName::Axis, -1); + m_axis = HandleNegativeAxis(axis, gsl::narrow_cast(inputShape.size())); } - ML_CHECK_VALID_ARGUMENT(k >= 0, "Attribute k is missing or negative."); - m_k = k; - auto inputShape = shape.GetInputTensorShape(0); - int32_t axis = info.template GetOptionalAttribute(AttrName::Axis, -1); - m_axis = HandleNegativeAxis(axis, gsl::narrow_cast(inputShape.size())); - } + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; - std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; - - protected: - uint32_t m_k; - uint32_t m_axis; +protected: + uint32_t m_k; + uint32_t m_axis; }; class RecurrentHelper { - public: - // Info_t is used to obtain attributes which will be used for calculating the output shape later. - // Shape_t is used to obtain input shape which will be used for adjusting attribute value. - template - RecurrentHelper(const Info_t& info, const Shape_t& shape) { - m_hiddenSize = info.template GetOptionalAttribute(AttrName::HiddenSize, 1); - } +public: + // Info_t is used to obtain attributes which will be used for calculating the output shape later. + // Shape_t is used to obtain input shape which will be used for adjusting attribute value. + template + RecurrentHelper(const Info_t& info, const Shape_t& shape) { + m_hiddenSize = info.template GetOptionalAttribute(AttrName::HiddenSize, 1); + } - std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; - protected: - int m_hiddenSize = 0; +protected: + int m_hiddenSize = 0; }; class ConcatHelper { - public: - void Initialize( - const MLOperatorAttributes& operatorAttributes, - gsl::span inputDimensions); +public: + void Initialize( + const MLOperatorAttributes& operatorAttributes, + gsl::span inputDimensions); - // Info_t is used to obtain attributes which will be used for calculating the output shape later. - // Shape_t is used to obtain input shape which will be used for adjusting attribute value. - template - ConcatHelper(const Info_t& info, const Shape_t& shape) { - Initialize(info, shape.GetInputTensorShape(0)); - } + // Info_t is used to obtain attributes which will be used for calculating the output shape later. + // Shape_t is used to obtain input shape which will be used for adjusting attribute value. + template + ConcatHelper(const Info_t& info, const Shape_t& shape) { + Initialize(info, shape.GetInputTensorShape(0)); + } - std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; - protected: - int m_axis; +protected: + int m_axis; }; class CropHelper { - public: - enum BorderDim { Left, - Top, - Right, - Bottom }; - enum ScaleDim { Height, - Width }; +public: + enum BorderDim { + Left, + Top, + Right, + Bottom + }; + enum ScaleDim { + Height, + Width + }; - void Initialize( - const MLOperatorAttributes& operatorAttributes, - gsl::span inputDimensions); + void Initialize( + const MLOperatorAttributes& operatorAttributes, + gsl::span inputDimensions); - // Info_t is used to obtain attributes which will be used for calculating the output shape later. - // Shape_t is used to obtain input shape which will be used for adjusting attribute value. - template - CropHelper(const Info_t& info, const Shape_t& shape) { - Initialize(info, shape.GetInputTensorShape(0)); - } + // Info_t is used to obtain attributes which will be used for calculating the output shape later. + // Shape_t is used to obtain input shape which will be used for adjusting attribute value. + template + CropHelper(const Info_t& info, const Shape_t& shape) { + Initialize(info, shape.GetInputTensorShape(0)); + } - std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; - protected: - uint32_t m_offsets[NchwDimensionCount]; - uint32_t m_sizes[NchwSpatialDimensionCount]; +protected: + uint32_t m_offsets[NchwDimensionCount]; + uint32_t m_sizes[NchwSpatialDimensionCount]; }; class DepthToSpaceHelper { - public: - // Info_t is used to obtain attributes which will be used for calculating the output shape later. - // Shape_t is used to obtain input shape which will be used for adjusting attribute value. - template - DepthToSpaceHelper(const Info_t& info, const Shape_t& shape) { - m_blockSize = info.template GetOptionalAttribute(AttrName::BlockSize, -1); - ML_CHECK_VALID_ARGUMENT(m_blockSize > 0, "Attribute blocksize is missing or equal to zero."); - } +public: + // Info_t is used to obtain attributes which will be used for calculating the output shape later. + // Shape_t is used to obtain input shape which will be used for adjusting attribute value. + template + DepthToSpaceHelper(const Info_t& info, const Shape_t& shape) { + m_blockSize = info.template GetOptionalAttribute(AttrName::BlockSize, -1); + ML_CHECK_VALID_ARGUMENT(m_blockSize > 0, "Attribute blocksize is missing or equal to zero."); + } - std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; - protected: - int32_t m_blockSize; +protected: + int32_t m_blockSize; }; class SpaceToDepthHelper { - public: - // Info_t is used to obtain attributes which will be used for calculating the output shape later. - // Shape_t is used to obtain input shape which will be used for adjusting attribute value. - template - SpaceToDepthHelper(const Info_t& info, const Shape_t& shape) { - m_blockSize = info.template GetOptionalAttribute(AttrName::BlockSize, -1); - ML_CHECK_VALID_ARGUMENT(m_blockSize > 0, "Attribute blocksize is missing or equal to zero."); - } +public: + // Info_t is used to obtain attributes which will be used for calculating the output shape later. + // Shape_t is used to obtain input shape which will be used for adjusting attribute value. + template + SpaceToDepthHelper(const Info_t& info, const Shape_t& shape) { + m_blockSize = info.template GetOptionalAttribute(AttrName::BlockSize, -1); + ML_CHECK_VALID_ARGUMENT(m_blockSize > 0, "Attribute blocksize is missing or equal to zero."); + } - std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; - protected: - int32_t m_blockSize; +protected: + int32_t m_blockSize; }; class FlattenHelper { - public: - void Initialize( - const MLOperatorAttributes& operatorAttributes, - gsl::span inputDimensions); +public: + void Initialize( + const MLOperatorAttributes& operatorAttributes, + gsl::span inputDimensions); - // Info_t is used to obtain attributes which will be used for calculating the output shape later. - // Shape_t is used to obtain input shape which will be used for adjusting attribute value. - template - FlattenHelper(const Info_t& info, const Shape_t& shape) { - Initialize(info, shape.GetInputTensorShape(0)); - } + // Info_t is used to obtain attributes which will be used for calculating the output shape later. + // Shape_t is used to obtain input shape which will be used for adjusting attribute value. + template + FlattenHelper(const Info_t& info, const Shape_t& shape) { + Initialize(info, shape.GetInputTensorShape(0)); + } - std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; - protected: - int m_axis = 1; +protected: + int m_axis = 1; }; class MultinomialHelper { - public: - // Info_t is used to obtain attributes which will be used for calculating the output shape later. - // Shape_t is used to obtain input shape which will be used for adjusting attribute value. - template - MultinomialHelper(const Info_t& info, const Shape_t& shape) {} +public: + // Info_t is used to obtain attributes which will be used for calculating the output shape later. + // Shape_t is used to obtain input shape which will be used for adjusting attribute value. + template + MultinomialHelper(const Info_t& info, const Shape_t& shape) {} - std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; }; class GatherHelper { - public: - void Initialize( - const MLOperatorAttributes& operatorAttributes, - gsl::span dataDimensions); +public: + void Initialize( + const MLOperatorAttributes& operatorAttributes, + gsl::span dataDimensions); - // Info_t is used to obtain attributes which will be used for calculating the output shape later. - // Shape_t is used to obtain input shape which will be used for adjusting attribute value. - template - GatherHelper(const Info_t& info, const Shape_t& shape) { - Initialize(info, shape.GetInputTensorShape(0)); - } + // Info_t is used to obtain attributes which will be used for calculating the output shape later. + // Shape_t is used to obtain input shape which will be used for adjusting attribute value. + template + GatherHelper(const Info_t& info, const Shape_t& shape) { + Initialize(info, shape.GetInputTensorShape(0)); + } - std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; - protected: - int m_axis = 0; +protected: + int m_axis = 0; }; class GatherNdHelper { - public: - // Info_t is used to obtain attributes which will be used for calculating the output shape later. - // Shape_t is used to obtain input shape which will be used for adjusting attribute value. - template - GatherNdHelper(const Info_t& info, const Shape_t& shape) { - m_batchCount = info.template GetOptionalAttribute(AttrName::BatchDimensions, 0); - } +public: + // Info_t is used to obtain attributes which will be used for calculating the output shape later. + // Shape_t is used to obtain input shape which will be used for adjusting attribute value. + template + GatherNdHelper(const Info_t& info, const Shape_t& shape) { + m_batchCount = info.template GetOptionalAttribute(AttrName::BatchDimensions, 0); + } - std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; protected: - int32_t m_batchCount; + int32_t m_batchCount; }; class PoolingHelperBase { - public: - // Info_t is used to obtain attributes which will be used for calculating the output shape later. - // Shape_t is used to obtain input shape which will be used for adjusting attribute value. - template - PoolingHelperBase( - const Info_t& info, - const Shape_t& shape, - bool useGlobalPooling) : m_kernel(useGlobalPooling - ? InitializeGlobalKernel(shape.GetInputTensorShape(0)) - : InitializeKernel(info, static_cast(shape.GetInputTensorShape(0).size()), gsl::span())) { - if (!useGlobalPooling) { - ResolveAutoPadding(m_kernel, shape.GetInputTensorShape(0)); +public: + // Info_t is used to obtain attributes which will be used for calculating the output shape later. + // Shape_t is used to obtain input shape which will be used for adjusting attribute value. + template + PoolingHelperBase( + const Info_t& info, + const Shape_t& shape, + bool useGlobalPooling) : m_kernel(useGlobalPooling + ? InitializeGlobalKernel(shape.GetInputTensorShape(0)) + : InitializeKernel(info, static_cast(shape.GetInputTensorShape(0).size()), gsl::span())) { + if (!useGlobalPooling) { + ResolveAutoPadding(m_kernel, shape.GetInputTensorShape(0)); + } } - } - std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; - protected: - KernelArgs m_kernel; +protected: + KernelArgs m_kernel; }; class UnpoolingHelper { public: - // Info_t is used to obtain attributes which will be used for calculating the output shape later. - // Shape_t is used to obtain input shape which will be used for adjusting attribute value. - template - UnpoolingHelper( - const Info_t& info, - const Shape_t& shape + // Info_t is used to obtain attributes which will be used for calculating the output shape later. + // Shape_t is used to obtain input shape which will be used for adjusting attribute value. + template + UnpoolingHelper( + const Info_t& info, + const Shape_t& shape ) - : m_inputShape(shape.GetInputTensorShape(0)), - m_kernel(InitializeKernel(info, static_cast(m_inputShape.size()), gsl::span())) - { - Initialize(); - } + : m_inputShape(shape.GetInputTensorShape(0)), + m_kernel(InitializeKernel(info, static_cast(m_inputShape.size()), gsl::span())) + { + Initialize(); + } - void Initialize(); + void Initialize(); - std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; protected: std::vector m_inputShape; @@ -1105,15 +1116,15 @@ protected: }; class GlobalPoolingHelper : public PoolingHelperBase { - public: - template - GlobalPoolingHelper(const Info_t& info, const Shape_t& shape) : PoolingHelperBase(info, shape, true) {} +public: + template + GlobalPoolingHelper(const Info_t& info, const Shape_t& shape) : PoolingHelperBase(info, shape, true) {} }; class PoolingHelper : public PoolingHelperBase { - public: - template - PoolingHelper(const Info_t& info, const Shape_t& shape) : PoolingHelperBase(info, shape, false) {} +public: + template + PoolingHelper(const Info_t& info, const Shape_t& shape) : PoolingHelperBase(info, shape, false) {} }; class RoiPoolingHelperBase @@ -1164,139 +1175,139 @@ public: }; class SqueezeHelper { - public: - void Initialize( - gsl::span axes, - gsl::span inputDimensions); +public: + void Initialize( + gsl::span axes, + gsl::span inputDimensions); - // Info_t is used to obtain attributes which will be used for calculating the output shape later. - // Shape_t is used to obtain input shape which will be used for adjusting attribute value. - template - SqueezeHelper(const Info_t& info, const Shape_t& shape) { - Initialize( - info.GetOptionalAttributeVectorInt32(AttrName::Axes), - shape.GetInputTensorShape(0)); - } + // Info_t is used to obtain attributes which will be used for calculating the output shape later. + // Shape_t is used to obtain input shape which will be used for adjusting attribute value. + template + SqueezeHelper(const Info_t& info, const Shape_t& shape) { + Initialize( + info.GetOptionalAttributeVectorInt32(AttrName::Axes), + shape.GetInputTensorShape(0)); + } - std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; - protected: - std::vector m_axes; +protected: + std::vector m_axes; }; class UnsqueezeHelper { - public: - void Initialize( - gsl::span axes, - gsl::span inputDimensions); +public: + void Initialize( + gsl::span axes, + gsl::span inputDimensions); - // Info_t is used to obtain attributes which will be used for calculating the output shape later. - // Shape_t is used to obtain input shape which will be used for adjusting attribute value. - template - UnsqueezeHelper(const Info_t& info, const Shape_t& shape) { - Initialize( - info.GetOptionalAttributeVectorInt32(AttrName::Axes), - shape.GetInputTensorShape(0)); - } + // Info_t is used to obtain attributes which will be used for calculating the output shape later. + // Shape_t is used to obtain input shape which will be used for adjusting attribute value. + template + UnsqueezeHelper(const Info_t& info, const Shape_t& shape) { + Initialize( + info.GetOptionalAttributeVectorInt32(AttrName::Axes), + shape.GetInputTensorShape(0)); + } - std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; - protected: - std::vector m_axes; +protected: + std::vector m_axes; }; template void CALLBACK ShapeInferenceFunction(IMLOperatorShapeInferenceContext* inference_context) { - MLShapeInferenceContext helperContext(inference_context); - T opHelper(helperContext, helperContext); + MLShapeInferenceContext helperContext(inference_context); + T opHelper(helperContext, helperContext); - // EdgeInfo to contain whether tensor, whether unused, and what shape is - std::vector outputShapes = opHelper.GetOutputShapes(helperContext); + // EdgeInfo to contain whether tensor, whether unused, and what shape is + std::vector outputShapes = opHelper.GetOutputShapes(helperContext); - for (uint32_t i = 0; i < outputShapes.size(); ++i) { - if (outputShapes[i].IsTensor() && !outputShapes[i].IsUnused()) { - helperContext.SetOutputTensorShape(i, outputShapes[i].GetShape()); + for (uint32_t i = 0; i < outputShapes.size(); ++i) { + if (outputShapes[i].IsTensor() && !outputShapes[i].IsUnused()) { + helperContext.SetOutputTensorShape(i, outputShapes[i].GetShape()); + } } - } } class ReshapeHelper { - public: - template - ReshapeHelper(const Info_t& info, const Shape_t& shape) { - ML_CHECK_VALID_ARGUMENT(info.GetInputCount() >= 2); - ML_CHECK_VALID_ARGUMENT(info.GetOutputCount() >= 1); +public: + template + ReshapeHelper(const Info_t& info, const Shape_t& shape) { + ML_CHECK_VALID_ARGUMENT(info.GetInputCount() >= 2); + ML_CHECK_VALID_ARGUMENT(info.GetOutputCount() >= 1); - // The 'shape' tensor is a 1D tensor holding the new shape to reshape to, - // and the first element of its own shape holds how many dimensions there - // will be for the output. - MLOperatorTensor shapeTensor = info.GetConstantInputTensor(1); - ReadCpuLocalTensorIntoInt32(shapeTensor, /*out*/ m_shapeDims); - } + // The 'shape' tensor is a 1D tensor holding the new shape to reshape to, + // and the first element of its own shape holds how many dimensions there + // will be for the output. + MLOperatorTensor shapeTensor = info.GetConstantInputTensor(1); + ReadCpuLocalTensorIntoInt32(shapeTensor, /*out*/ m_shapeDims); + } - std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; - protected: - std::vector m_shapeDims; +protected: + std::vector m_shapeDims; }; class ExpandHelper { - public: - template - ExpandHelper(const Info_t& info, const Shape_t& shape) { - } +public: + template + ExpandHelper(const Info_t& info, const Shape_t& shape) { + } - std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; - protected: +protected: }; class ConstantOfShapeHelper { - public: - template - ConstantOfShapeHelper(const Info_t& info, const Shape_t& shape) { - } +public: + template + ConstantOfShapeHelper(const Info_t& info, const Shape_t& shape) { + } - std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; - protected: +protected: }; class TileHelper { - public: - template - TileHelper(const Info_t& info, const Shape_t& shapeInfo) { - m_inputDimensions = shapeInfo.GetInputTensorShape(0); +public: + template + TileHelper(const Info_t& info, const Shape_t& shapeInfo) { + m_inputDimensions = shapeInfo.GetInputTensorShape(0); - // Read the repeats tensor. - const std::vector repeatsTensorDimensions = shapeInfo.GetInputTensorShape(1); - ML_CHECK_VALID_ARGUMENT(repeatsTensorDimensions.size() == 1, "Tile's repeats tensor must be 1D."); - const size_t dimCount = repeatsTensorDimensions[0]; + // Read the repeats tensor. + const std::vector repeatsTensorDimensions = shapeInfo.GetInputTensorShape(1); + ML_CHECK_VALID_ARGUMENT(repeatsTensorDimensions.size() == 1, "Tile's repeats tensor must be 1D."); + const size_t dimCount = repeatsTensorDimensions[0]; - MLOperatorTensor repeatsTensor = info.GetConstantInputTensor(1); - const int64_t* repeatsData = repeatsTensor.GetData(); - ML_CHECK_VALID_ARGUMENT(m_inputDimensions.size() == dimCount, "Tile's repeats tensor must be the same dimension count as the input tensor."); - ML_CHECK_VALID_ARGUMENT(repeatsTensor.IsCpuData(), "Tile's repeats tensor must be CPU Tensor."); + MLOperatorTensor repeatsTensor = info.GetConstantInputTensor(1); + const int64_t* repeatsData = repeatsTensor.GetData(); + ML_CHECK_VALID_ARGUMENT(m_inputDimensions.size() == dimCount, "Tile's repeats tensor must be the same dimension count as the input tensor."); + ML_CHECK_VALID_ARGUMENT(repeatsTensor.IsCpuData(), "Tile's repeats tensor must be CPU Tensor."); - for (size_t i = 0; i < dimCount; ++i) { - ML_CHECK_VALID_ARGUMENT(repeatsData[i] >= 0, "Repeat values should be >= 0."); - m_repeatsData.push_back(gsl::narrow_cast(repeatsData[i])); + for (size_t i = 0; i < dimCount; ++i) { + ML_CHECK_VALID_ARGUMENT(repeatsData[i] >= 0, "Repeat values should be >= 0."); + m_repeatsData.push_back(gsl::narrow_cast(repeatsData[i])); + } + + // Update the computed output shape accordingly, repeat every axis's length by the repeat count. + m_outputDimensions.assign(m_inputDimensions.begin(), m_inputDimensions.end()); + + for (size_t dimIndex = 0; dimIndex < dimCount; ++dimIndex) { + m_outputDimensions[dimIndex] *= m_repeatsData[dimIndex]; + } } - // Update the computed output shape accordingly, repeat every axis's length by the repeat count. - m_outputDimensions.assign(m_inputDimensions.begin(), m_inputDimensions.end()); + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; - for (size_t dimIndex = 0; dimIndex < dimCount; ++dimIndex) { - m_outputDimensions[dimIndex] *= m_repeatsData[dimIndex]; - } - } - - std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; - - protected: - std::vector m_repeatsData; - std::vector m_inputDimensions; - std::vector m_outputDimensions; +protected: + std::vector m_repeatsData; + std::vector m_inputDimensions; + std::vector m_outputDimensions; }; class ResizeHelper { @@ -1306,38 +1317,39 @@ class ResizeHelper { template ResizeHelper(const Info_t& info, const Shape_t& shape, uint32_t opsetVersion) { - m_inputDimensions = shape.GetInputTensorShape(0); - std::vector outputSizes; + m_inputDimensions = shape.GetInputTensorShape(0); + std::vector outputSizes; - if (opsetVersion >= 11) { - if (info.IsInputValid(1)) - { - MLOperatorTensor regionOfInterestTensor = info.GetConstantInputTensor(1); - ReadCpuLocalTensorIntoFloat32(regionOfInterestTensor, /*out*/ m_regionOfInterest); + if (opsetVersion >= 11) { + if (info.IsInputValid(1)) + { + MLOperatorTensor regionOfInterestTensor = info.GetConstantInputTensor(1); + ReadCpuLocalTensorIntoFloat32(regionOfInterestTensor, /*out*/ m_regionOfInterest); + } + if (info.IsInputValid(2)) + { + MLOperatorTensor scalesTensor = info.GetConstantInputTensor(2); + ReadCpuLocalTensorIntoFloat32(scalesTensor, /*out*/ m_scales); + } + if (info.IsInputValid(3)) + { + MLOperatorTensor outputSizesTensor = info.GetConstantInputTensor(3); + ReadCpuLocalTensorIntoInt32(outputSizesTensor, /*out*/ outputSizes); + } } - if (info.IsInputValid(2)) - { - MLOperatorTensor scalesTensor = info.GetConstantInputTensor(2); + else if (opsetVersion >= 9) { + // Read the scales from the 2nd tensor. + // Compatible with Upsample-9/Upsample-10 and Resize-10. + MLOperatorTensor scalesTensor = info.GetConstantInputTensor(1); ReadCpuLocalTensorIntoFloat32(scalesTensor, /*out*/ m_scales); } - if (info.IsInputValid(3)) + else { - MLOperatorTensor outputSizesTensor = info.GetConstantInputTensor(3); - ReadCpuLocalTensorIntoInt32(outputSizesTensor, /*out*/ outputSizes); + // From attribute, compatible with Upsample-7. + m_scales = info.template GetOptionalAttribute>(AttrName::Scales, std::vector()); } - } - else if (opsetVersion >= 9) { - // Read the scales from the 2nd tensor. - // Compatible with Upsample-9/Upsample-10 and Resize-10. - MLOperatorTensor scalesTensor = info.GetConstantInputTensor(1); - ReadCpuLocalTensorIntoFloat32(scalesTensor, /*out*/ m_scales); - } else - { - // From attribute, compatible with Upsample-7. - m_scales = info.template GetOptionalAttribute>(AttrName::Scales, std::vector()); - } - Initialize(outputSizes); + Initialize(outputSizes); } void Initialize(gsl::span outputSizes); @@ -1383,39 +1395,39 @@ protected: }; class OneHotHelper { - public: - template - OneHotHelper(const Info_t& info, const Shape_t& shapeInfo) { - ML_CHECK_VALID_ARGUMENT(info.GetInputCount() == 3); - ML_CHECK_VALID_ARGUMENT(info.GetOutputCount() == 1); +public: + template + OneHotHelper(const Info_t& info, const Shape_t& shapeInfo) { + ML_CHECK_VALID_ARGUMENT(info.GetInputCount() == 3); + ML_CHECK_VALID_ARGUMENT(info.GetOutputCount() == 1); - const std::vector inputDimensions = shapeInfo.GetInputTensorShape(0); - std::vector outputDimensions; + const std::vector inputDimensions = shapeInfo.GetInputTensorShape(0); + std::vector outputDimensions; - m_onnxAxis = info.template GetOptionalAttribute(AttrName::Axis, -1); + m_onnxAxis = info.template GetOptionalAttribute(AttrName::Axis, -1); - // Get 'depth' tensor, which is really a scalar for the output size along the given axis. - MLOperatorTensor shapeTensor = info.GetConstantInputTensor(1); + // Get 'depth' tensor, which is really a scalar for the output size along the given axis. + MLOperatorTensor shapeTensor = info.GetConstantInputTensor(1); - auto indicesShape = shapeInfo.GetInputTensorShape(0); - m_absoluteAxis = HandleNegativeAxis(m_onnxAxis, gsl::narrow_cast(indicesShape.size() + 1)); + auto indicesShape = shapeInfo.GetInputTensorShape(0); + m_absoluteAxis = HandleNegativeAxis(m_onnxAxis, gsl::narrow_cast(indicesShape.size() + 1)); - // The shape tensor ('depth') is a 0D tensor holding the size for the output tensor along the specified axis. - // It must be registered as OrtMemType::OrtMemTypeCPUInput for CPU read access. - const int64_t depth64 = ReadScalarTensorCastToInt64(shapeTensor); - ML_CHECK_VALID_ARGUMENT(depth64 > 0, "Negative or zero 'depth' values for OneHot are illegal."); - const uint32_t depth = gsl::narrow_cast(depth64); - m_outputDimensions.assign(indicesShape.begin(), indicesShape.end()); - m_outputDimensions.insert(m_outputDimensions.begin() + m_absoluteAxis, depth); - } + // The shape tensor ('depth') is a 0D tensor holding the size for the output tensor along the specified axis. + // It must be registered as OrtMemType::OrtMemTypeCPUInput for CPU read access. + const int64_t depth64 = ReadScalarTensorCastToInt64(shapeTensor); + ML_CHECK_VALID_ARGUMENT(depth64 > 0, "Negative or zero 'depth' values for OneHot are illegal."); + const uint32_t depth = gsl::narrow_cast(depth64); + m_outputDimensions.assign(indicesShape.begin(), indicesShape.end()); + m_outputDimensions.insert(m_outputDimensions.begin() + m_absoluteAxis, depth); + } - std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; - protected: - int32_t m_onnxAxis = 0; // Original ONNX attribute value, including negative value. - uint32_t m_absoluteAxis = 0; // Absolute index value. - std::vector m_indicesDimensions; - std::vector m_outputDimensions; +protected: + int32_t m_onnxAxis = 0; // Original ONNX attribute value, including negative value. + uint32_t m_absoluteAxis = 0; // Absolute index value. + std::vector m_indicesDimensions; + std::vector m_outputDimensions; }; using ShapeInferenceHelper_Conv = ConvHelper; @@ -1447,6 +1459,7 @@ using ShapeInferenceHelper_GatherElements = GetOutputShapeAsSpecificInputShapeHe using ShapeInferenceHelper_ScatterElements = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Scatter9 = ShapeInferenceHelper_ScatterElements; // Old deprecated alias for ScatterElements. using ShapeInferenceHelper_Scatter11 = ShapeInferenceHelper_ScatterElements; // Old deprecated alias for ScatterElements. +using ShapeInferenceHelper_Scatter13 = ShapeInferenceHelper_ScatterElements; // Old deprecated alias for ScatterElements. using ShapeInferenceHelper_GatherND = GatherNdHelper; using ShapeInferenceHelper_ScatterND = GetOutputShapeAsInputShapeHelper; @@ -1457,8 +1470,10 @@ using ShapeInferenceHelper_Concat = ConcatHelper; using ShapeInferenceHelper_Slice7 = VersionedOpsetHelper; using ShapeInferenceHelper_Slice10 = VersionedOpsetHelper; using ShapeInferenceHelper_Slice11 = VersionedOpsetHelper; // Note 11 and 10 are identical - no functional change. +using ShapeInferenceHelper_Slice13 = VersionedOpsetHelper; // Note 13 and 10 are identical - no functional change, just new types. using ShapeInferenceHelper_Pad7 = VersionedOpsetHelper; using ShapeInferenceHelper_Pad11 = VersionedOpsetHelper; +using ShapeInferenceHelper_Pad13 = VersionedOpsetHelper; using ShapeInferenceHelper_SpaceToDepth = SpaceToDepthHelper; using ShapeInferenceHelper_DepthToSpace = DepthToSpaceHelper; @@ -1485,6 +1500,7 @@ using ShapeInferenceHelper_Floor = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Clip7 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Clip11 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Clip12 = GetOutputShapeAsInputShapeHelper; +using ShapeInferenceHelper_Clip13 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Greater = GetBroadcastedOutputShapeHelper; using ShapeInferenceHelper_Less = GetBroadcastedOutputShapeHelper; using ShapeInferenceHelper_GreaterOrEqual = GetBroadcastedOutputShapeHelper; @@ -1512,7 +1528,7 @@ using ShapeInferenceHelper_Affine = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_QuantizeLinear = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_DequantizeLinear = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Sign = GetBroadcastedOutputShapeHelper; -using ShapeInferenceHelper_IsNan = GetBroadcastedOutputShapeHelper; +using ShapeInferenceHelper_IsNaN = GetBroadcastedOutputShapeHelper; using ShapeInferenceHelper_Erf = GetBroadcastedOutputShapeHelper; using ShapeInferenceHelper_Sinh = GetBroadcastedOutputShapeHelper; using ShapeInferenceHelper_Cosh = GetBroadcastedOutputShapeHelper; @@ -1546,6 +1562,7 @@ using ShapeInferenceHelper_ImageScaler = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Upsample7 = VersionedOpsetHelper; using ShapeInferenceHelper_Upsample9 = VersionedOpsetHelper; using ShapeInferenceHelper_Upsample10 = VersionedOpsetHelper; +using ShapeInferenceHelper_Upsample13 = VersionedOpsetHelper; using ShapeInferenceHelper_Sigmoid = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_HardSigmoid = GetOutputShapeAsInputShapeHelper; @@ -1571,6 +1588,8 @@ using ShapeInferenceHelper_Identity = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_MatMul = MatMulHelper; using ShapeInferenceHelper_MatMulInteger = MatMulHelper; using ShapeInferenceHelper_QLinearMatMul = QLinearMatMulHelper; +using ShapeInferenceHelper_QLinearAdd = GetBroadcastedOutputShapeHelper; +using ShapeInferenceHelper_DynamicQuantizeLinear = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Cast = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_MemcpyFromHost = GetOutputShapeAsInputShapeHelper; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorRegistration.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorRegistration.h index 9d5c248559..0185d466f9 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorRegistration.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorRegistration.h @@ -142,7 +142,7 @@ namespace OperatorHelper namespace OnnxOperatorSet9 { static const int sc_sinceVer_Sign = 9; - static const int sc_sinceVer_IsNan = 9; + static const int sc_sinceVer_IsNaN = 9; static const int sc_sinceVer_Sinh = 9; static const int sc_sinceVer_Cosh = 9; static const int sc_sinceVer_Asinh = 9; @@ -233,6 +233,7 @@ namespace OperatorHelper static const int sc_sinceVer_ReduceSumSquare = 11; static const int sc_sinceVer_Resize = 11; static const int sc_sinceVer_Round = 11; + static const int sc_sinceVer_DynamicQuantizeLinear = 11; static const int sc_sinceVer_Scan = 11; static const int sc_sinceVer_Scatter = 11; // Deprecated alias static const int sc_sinceVer_ScatterElements = 11; @@ -263,6 +264,73 @@ namespace OperatorHelper static const int sc_sinceVer_ReduceMin = 12; } // namespace OnnxOperatorSet12 + namespace OnnxOperatorSet13 + { + static const int sc_sinceVer_Abs = 13; + static const int sc_sinceVer_Add = 13; + static const int sc_sinceVer_ArgMax = 13; + static const int sc_sinceVer_ArgMin = 13; + static const int sc_sinceVer_Cast = 13; + static const int sc_sinceVer_Ceil = 13; + static const int sc_sinceVer_Clip = 13; + static const int sc_sinceVer_Concat = 13; + static const int sc_sinceVer_Constant = 13; + static const int sc_sinceVer_DepthToSpace = 13; + static const int sc_sinceVer_Div = 13; + static const int sc_sinceVer_Equal = 13; + static const int sc_sinceVer_Erf = 13; + static const int sc_sinceVer_Exp = 13; + static const int sc_sinceVer_Expand = 13; + static const int sc_sinceVer_Flatten = 13; + static const int sc_sinceVer_Floor = 13; + static const int sc_sinceVer_Gather = 13; + static const int sc_sinceVer_GatherElements = 13; + static const int sc_sinceVer_GatherND = 13; + static const int sc_sinceVer_Gemm = 13; + static const int sc_sinceVer_Greater = 13; + static const int sc_sinceVer_Identity = 13; + static const int sc_sinceVer_IsNaN = 13; + static const int sc_sinceVer_LRN = 13; + static const int sc_sinceVer_Less = 13; + static const int sc_sinceVer_Log = 13; + static const int sc_sinceVer_MatMul = 13; + static const int sc_sinceVer_Max = 13; + static const int sc_sinceVer_Mean = 13; + static const int sc_sinceVer_MeanVarianceNormalization = 13; + static const int sc_sinceVer_Min = 13; + static const int sc_sinceVer_Mod = 13; + static const int sc_sinceVer_Mul = 13; + static const int sc_sinceVer_Neg = 13; + static const int sc_sinceVer_Pad = 13; + static const int sc_sinceVer_Pow = 13; + static const int sc_sinceVer_Reciprocal = 13; + static const int sc_sinceVer_ReduceL1 = 13; + static const int sc_sinceVer_ReduceL2 = 13; + static const int sc_sinceVer_ReduceLogSum = 13; + static const int sc_sinceVer_ReduceLogSumExp = 13; + static const int sc_sinceVer_ReduceMax = 13; + static const int sc_sinceVer_ReduceMean = 13; + static const int sc_sinceVer_ReduceMin = 13; + static const int sc_sinceVer_ReduceProd = 13; + static const int sc_sinceVer_ReduceSumSquare = 13; + static const int sc_sinceVer_Relu = 13; + static const int sc_sinceVer_Reshape = 13; + static const int sc_sinceVer_Scatter = 13; + static const int sc_sinceVer_ScatterElements = 13; + static const int sc_sinceVer_ScatterND = 13; + static const int sc_sinceVer_Sigmoid = 13; + static const int sc_sinceVer_Sign = 13; + static const int sc_sinceVer_Slice = 13; + static const int sc_sinceVer_SpaceToDepth = 13; + static const int sc_sinceVer_Sqrt = 13; + static const int sc_sinceVer_Sub = 13; + static const int sc_sinceVer_Sum = 13; + static const int sc_sinceVer_Tanh = 13; + static const int sc_sinceVer_Tile = 13; + static const int sc_sinceVer_Transpose = 13; + static const int sc_sinceVer_Upsample = 13; + } // namespace OnnxOperatorSet13 + namespace MsftOperatorSet1 { static const int sc_sinceVer_FusedConv = 1; @@ -277,6 +345,7 @@ namespace OperatorHelper static const int sc_sinceVer_QuantizeLinear = 1; static const int sc_sinceVer_DequantizeLinear = 1; static const int sc_sinceVer_ConvTransposeWithDynamicPads = 1; + static const int sc_sinceVer_QLinearAdd = 1; } // namespace MsftOperatorSet1 } // namespace OperatorHelper diff --git a/winml/adapter/abi_custom_registry_impl.cpp b/winml/adapter/abi_custom_registry_impl.cpp index 3619053440..d87d50a9ee 100644 --- a/winml/adapter/abi_custom_registry_impl.cpp +++ b/winml/adapter/abi_custom_registry_impl.cpp @@ -58,6 +58,7 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistryImpl::RegisterOperatorKernel( bool supportedWith64BitTensorsVia32BitStrides, bool supportedWith64BitTensorsVia32BitStridesFromAnyEp, bool prefer64BitTensorsDirectly, + bool support64BitTensorsViaEmulation, _In_reads_(constantCpuInputCount) const uint32_t* requiredConstantCpuInputs, uint32_t constantCpuInputCount) const noexcept try { #ifdef LAYERING_DONE @@ -83,6 +84,7 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistryImpl::RegisterOperatorKernel( supportedWith64BitTensorsVia32BitStrides, supportedWith64BitTensorsVia32BitStridesFromAnyEp, prefer64BitTensorsDirectly, + support64BitTensorsViaEmulation, requiredConstantCpuInputs, constantCpuInputCount); } diff --git a/winml/adapter/abi_custom_registry_impl.h b/winml/adapter/abi_custom_registry_impl.h index c955c7e384..f24ddd4b02 100644 --- a/winml/adapter/abi_custom_registry_impl.h +++ b/winml/adapter/abi_custom_registry_impl.h @@ -31,6 +31,7 @@ class AbiCustomRegistryImpl : public AbiCustomRegistry { bool supports_64bit_directly = false, bool allows_64bit_via_strides = false, bool allows_64bit_via_strides_from_any_ep = false, + bool supports_64bit_tensors_via_emulation = false, _In_reads_(constant_cpu_input_count) const uint32_t* required_constant_cpu_inputs = nullptr, uint32_t constant_cpu_input_count = 0) const noexcept override;