Merge remote-tracking branch 'ado_wai_ort/DmlDev' into user/dwayner/DML1.8forORT1.10

This commit is contained in:
Dwayne Robinson 2021-11-19 05:35:00 -08:00
commit 289b1bdc86
28 changed files with 2286 additions and 1071 deletions

View file

@ -122,6 +122,10 @@ namespace Windows::AI::MachineLearning::Adapter
// Operator supports true 64-bit tensors directly, no strides needed.
// So fallback to strided 32-bit only occurs when the device lacks 64-bit support.
bool prefer64BitTensorsDirectly = false;
// The operator supports emulation for uint64/int64 even if the hardware doesn't
// support native uint64/int64 data types.
bool support64BitTensorsViaEmulation = false;
};
using InternalRegistrationInfoMap = std::unordered_map<onnxruntime::KernelDef*, std::shared_ptr<InternalRegistrationInfo>>;

View file

@ -345,6 +345,7 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
bool supportedWith64BitTensorsVia32BitStrides,
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp,
bool prefer64BitTensorsDirectly,
bool support64BitTensorsViaEmulation,
_In_reads_(constantCpuInputCount) const uint32_t* requiredConstantCpuInputs,
uint32_t constantCpuInputCount) const noexcept
{
@ -472,6 +473,7 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
regInfo->supportedWith64BitTensorsVia32BitStrides = supportedWith64BitTensorsVia32BitStrides;
regInfo->supportedWith64BitTensorsVia32BitStridesFromAnyEp = supportedWith64BitTensorsVia32BitStridesFromAnyEp;
regInfo->prefer64BitTensorsDirectly = prefer64BitTensorsDirectly;
regInfo->support64BitTensorsViaEmulation = support64BitTensorsViaEmulation;
// Only internal operators support usage in DML graphs
if (supportsGraph)
@ -546,7 +548,8 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
requiredConstantCpuInputs ||
supportedWith64BitTensorsVia32BitStrides ||
supportedWith64BitTensorsVia32BitStridesFromAnyEp ||
prefer64BitTensorsDirectly)
prefer64BitTensorsDirectly ||
support64BitTensorsViaEmulation)
{
ORT_THROW_HR(E_INVALIDARG);
}

View file

@ -44,6 +44,7 @@ class AbiCustomRegistry : public WRL::Base<IMLOperatorRegistry, IMLOperatorRegis
bool supportedWith64BitTensorsVia32BitStrides = false,
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = false,
bool prefer64BitTensorsDirectly = false,
bool support64BitTensorsViaEmulation = false,
_In_reads_(constantCpuInputCount) const uint32_t* requiredConstantCpuInputs = nullptr,
uint32_t constantCpuInputCount = 0) const noexcept override;

View file

@ -20,7 +20,7 @@ DML_TENSOR_DATA_TYPE GetDmlDataTypeFromMlDataTypeNoThrow(MLOperatorTensorDataTyp
case MLOperatorTensorDataType::String: return DML_TENSOR_DATA_TYPE_UNKNOWN;
case MLOperatorTensorDataType::Bool: return DML_TENSOR_DATA_TYPE_UINT8;
case MLOperatorTensorDataType::Float16: return DML_TENSOR_DATA_TYPE_FLOAT16;
case MLOperatorTensorDataType::Double: return DML_TENSOR_DATA_TYPE_UNKNOWN;
case MLOperatorTensorDataType::Double: return DML_TENSOR_DATA_TYPE_FLOAT64;
case MLOperatorTensorDataType::UInt32: return DML_TENSOR_DATA_TYPE_UINT32;
case MLOperatorTensorDataType::UInt64: return DML_TENSOR_DATA_TYPE_UINT64;
case MLOperatorTensorDataType::Complex64: return DML_TENSOR_DATA_TYPE_UNKNOWN;
@ -119,7 +119,7 @@ uint32_t GetSupportedDeviceDataTypeMask(IDMLDevice* dmlDevice)
uint32_t deviceTypeMask = 0u;
// Form the bitmask of all supported data types.
for (uint32_t i = 0; i <= DML_TENSOR_DATA_TYPE_INT8; ++i)
for (uint32_t i = 0; i <= DML_TENSOR_DATA_TYPE_INT64; ++i)
{
DML_FEATURE_QUERY_TENSOR_DATA_TYPE_SUPPORT dataTypeQuery = { static_cast<DML_TENSOR_DATA_TYPE>(i) };
DML_FEATURE_DATA_TENSOR_DATA_TYPE_SUPPORT dataTypeSupport = {};

View file

@ -96,4 +96,24 @@ namespace Dml
return minimumImpliedSizeInBytes;
}
// Writes `value` into the member of `*outputValue` selected by `dataType`,
// converting it through clamp_cast<Target, T> where Target is that member's
// element type.
//
// FLOAT16 and FLOAT32 both store into the Float32 member. Any data type not
// listed below trips the assert and leaves `*outputValue` unmodified.
// NOTE(review): presumably consumers of a FLOAT16 scalar read it from Float32 — confirm.
template <typename T>
void CastToClampedScalarUnion(DML_TENSOR_DATA_TYPE dataType, T value, DML_SCALAR_UNION* outputValue)
{
    switch (dataType)
    {
    // Unsigned integer members.
    case DML_TENSOR_DATA_TYPE_UINT8:
        outputValue->UInt8 = clamp_cast<uint8_t, T>(value);
        return;
    case DML_TENSOR_DATA_TYPE_UINT16:
        outputValue->UInt16 = clamp_cast<uint16_t, T>(value);
        return;
    case DML_TENSOR_DATA_TYPE_UINT32:
        outputValue->UInt32 = clamp_cast<uint32_t, T>(value);
        return;
    case DML_TENSOR_DATA_TYPE_UINT64:
        outputValue->UInt64 = clamp_cast<uint64_t, T>(value);
        return;

    // Signed integer members.
    case DML_TENSOR_DATA_TYPE_INT8:
        outputValue->Int8 = clamp_cast<int8_t, T>(value);
        return;
    case DML_TENSOR_DATA_TYPE_INT16:
        outputValue->Int16 = clamp_cast<int16_t, T>(value);
        return;
    case DML_TENSOR_DATA_TYPE_INT32:
        outputValue->Int32 = clamp_cast<int32_t, T>(value);
        return;
    case DML_TENSOR_DATA_TYPE_INT64:
        outputValue->Int64 = clamp_cast<int64_t, T>(value);
        return;

    // Floating-point members; half precision shares the Float32 slot.
    case DML_TENSOR_DATA_TYPE_FLOAT16:
    case DML_TENSOR_DATA_TYPE_FLOAT32:
        outputValue->Float32 = clamp_cast<float, T>(value);
        return;
    case DML_TENSOR_DATA_TYPE_FLOAT64:
        outputValue->Float64 = clamp_cast<double, T>(value);
        return;

    default:
        break;
    }

    // Reached only for data types without a case above.
    assert(false);
}
} // namespace Dml

View file

@ -24,7 +24,7 @@ struct EnumTraits<DML_TENSOR_TYPE>
template <>
struct EnumTraits<DML_OPERATOR_TYPE>
{
static constexpr auto ValueCount = 141;
static constexpr auto ValueCount = 153;
static constexpr size_t ActivationFunctionCount = 20;
};
@ -86,7 +86,7 @@ struct EnumTraits<DML_FEATURE>
template <>
struct EnumTraits<DML_FEATURE_LEVEL>
{
static constexpr auto ValueCount = 4;
static constexpr auto ValueCount = 8;
};
template <>
@ -225,6 +225,24 @@ struct OperatorDescTraits<DML_ELEMENT_WISE_CLIP_OPERATOR_DESC>
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_CLIP;
};
// Associates DML_ELEMENT_WISE_CLIP1_OPERATOR_DESC with its DML_OPERATOR_TYPE enumerator.
template <>
struct OperatorDescTraits<DML_ELEMENT_WISE_CLIP1_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_CLIP1;
};
// Associates DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_DESC with its DML_OPERATOR_TYPE enumerator.
template <>
struct OperatorDescTraits<DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD;
};
// Associates DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_DESC with its DML_OPERATOR_TYPE enumerator.
template <>
struct OperatorDescTraits<DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1;
};
template <>
struct OperatorDescTraits<DML_ELEMENT_WISE_COS_OPERATOR_DESC>
{
@ -363,6 +381,18 @@ struct OperatorDescTraits<DML_ELEMENT_WISE_SQRT_OPERATOR_DESC>
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_SQRT;
};
// Associates DML_ELEMENT_WISE_DIFFERENCE_SQUARE_OPERATOR_DESC with its DML_OPERATOR_TYPE enumerator.
template <>
struct OperatorDescTraits<DML_ELEMENT_WISE_DIFFERENCE_SQUARE_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE;
};
// Associates DML_ELEMENT_WISE_ATAN_YX_OPERATOR_DESC with its DML_OPERATOR_TYPE enumerator.
template <>
struct OperatorDescTraits<DML_ELEMENT_WISE_ATAN_YX_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_ATAN_YX;
};
template <>
struct OperatorDescTraits<DML_ELEMENT_WISE_SUBTRACT_OPERATOR_DESC>
{
@ -483,6 +513,12 @@ struct OperatorDescTraits<DML_PADDING_OPERATOR_DESC>
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_PADDING;
};
// Associates DML_PADDING1_OPERATOR_DESC with its DML_OPERATOR_TYPE enumerator.
template <>
struct OperatorDescTraits<DML_PADDING1_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_PADDING1;
};
template <>
struct OperatorDescTraits<DML_VALUE_SCALE_2D_OPERATOR_DESC>
{
@ -531,6 +567,18 @@ struct OperatorDescTraits<DML_BATCH_NORMALIZATION_OPERATOR_DESC>
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_BATCH_NORMALIZATION;
};
// Associates DML_BATCH_NORMALIZATION_GRAD_OPERATOR_DESC with its DML_OPERATOR_TYPE enumerator.
template <>
struct OperatorDescTraits<DML_BATCH_NORMALIZATION_GRAD_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_BATCH_NORMALIZATION_GRAD;
};
// Associates DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_DESC with its DML_OPERATOR_TYPE enumerator.
template <>
struct OperatorDescTraits<DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD;
};
template <>
struct OperatorDescTraits<DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_DESC>
{
@ -543,6 +591,12 @@ struct OperatorDescTraits<DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_DESC>
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION;
};
// Associates DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_DESC with its DML_OPERATOR_TYPE enumerator.
template <>
struct OperatorDescTraits<DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD;
};
template <>
struct OperatorDescTraits<DML_LP_NORMALIZATION_OPERATOR_DESC>
{
@ -579,6 +633,12 @@ struct OperatorDescTraits<DML_ELEMENT_WISE_IS_NAN_OPERATOR_DESC>
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_IS_NAN;
};
// Associates DML_ELEMENT_WISE_NEGATE_OPERATOR_DESC with its DML_OPERATOR_TYPE enumerator.
template <>
struct OperatorDescTraits<DML_ELEMENT_WISE_NEGATE_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_NEGATE;
};
template <>
struct OperatorDescTraits<DML_ELEMENT_WISE_ERF_OPERATOR_DESC>
{
@ -717,6 +777,12 @@ struct OperatorDescTraits<DML_CUMULATIVE_SUMMATION_OPERATOR_DESC>
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_CUMULATIVE_SUMMATION;
};
// Associates DML_CUMULATIVE_PRODUCT_OPERATOR_DESC with its DML_OPERATOR_TYPE enumerator.
template <>
struct OperatorDescTraits<DML_CUMULATIVE_PRODUCT_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_CUMULATIVE_PRODUCT;
};
template <>
struct OperatorDescTraits<DML_REVERSE_SUBSEQUENCES_OPERATOR_DESC>
{
@ -891,12 +957,42 @@ struct OperatorDescTraits<DML_ROI_ALIGN_OPERATOR_DESC>
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ROI_ALIGN;
};
// Associates DML_ROI_ALIGN1_OPERATOR_DESC with its DML_OPERATOR_TYPE enumerator.
template <>
struct OperatorDescTraits<DML_ROI_ALIGN1_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ROI_ALIGN1;
};
// Associates DML_GATHER_ND1_OPERATOR_DESC with its DML_OPERATOR_TYPE enumerator.
template <>
struct OperatorDescTraits<DML_GATHER_ND1_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_GATHER_ND1;
};
// Associates DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_DESC with its DML_OPERATOR_TYPE enumerator.
template <>
struct OperatorDescTraits<DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR;
};
// Associates DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_DESC with its DML_OPERATOR_TYPE enumerator.
template <>
struct OperatorDescTraits<DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD;
};
// Associates DML_ROI_ALIGN_GRAD_OPERATOR_DESC with its DML_OPERATOR_TYPE enumerator.
template <>
struct OperatorDescTraits<DML_ROI_ALIGN_GRAD_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ROI_ALIGN_GRAD;
};
// Associates DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_DESC with its DML_OPERATOR_TYPE enumerator.
template <>
struct OperatorDescTraits<DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_BATCH_NORMALIZATION_TRAINING;
};
template <>
struct OperatorDescTraits<DML_ACTIVATION_ELU_OPERATOR_DESC>
{
@ -1071,6 +1167,24 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_CLIP>
using DescType = DML_ELEMENT_WISE_CLIP_OPERATOR_DESC;
};
// Resolves DML_OPERATOR_ELEMENT_WISE_CLIP1 to its descriptor struct via DescType.
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_CLIP1>
{
using DescType = DML_ELEMENT_WISE_CLIP1_OPERATOR_DESC;
};
// Resolves DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD to its descriptor struct via DescType.
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD>
{
using DescType = DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_DESC;
};
// Resolves DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1 to its descriptor struct via DescType.
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1>
{
using DescType = DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_COS>
{
@ -1209,6 +1323,18 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_SQRT>
using DescType = DML_ELEMENT_WISE_SQRT_OPERATOR_DESC;
};
// Resolves DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE to its descriptor struct via DescType.
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE>
{
using DescType = DML_ELEMENT_WISE_DIFFERENCE_SQUARE_OPERATOR_DESC;
};
// Resolves DML_OPERATOR_ELEMENT_WISE_ATAN_YX to its descriptor struct via DescType.
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_ATAN_YX>
{
using DescType = DML_ELEMENT_WISE_ATAN_YX_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_SUBTRACT>
{
@ -1329,6 +1455,12 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_PADDING>
using DescType = DML_PADDING_OPERATOR_DESC;
};
// Resolves DML_OPERATOR_PADDING1 to its descriptor struct via DescType.
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_PADDING1>
{
using DescType = DML_PADDING1_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_VALUE_SCALE_2D>
{
@ -1377,6 +1509,18 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_BATCH_NORMALIZATION>
using DescType = DML_BATCH_NORMALIZATION_OPERATOR_DESC;
};
// Resolves DML_OPERATOR_BATCH_NORMALIZATION_GRAD to its descriptor struct via DescType.
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_BATCH_NORMALIZATION_GRAD>
{
using DescType = DML_BATCH_NORMALIZATION_GRAD_OPERATOR_DESC;
};
// Resolves DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD to its descriptor struct via DescType.
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD>
{
using DescType = DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION>
{
@ -1389,6 +1533,12 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_LOCAL_RESPONSE_NORMALI
using DescType = DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_DESC;
};
// Resolves DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD to its descriptor struct via DescType.
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD>
{
using DescType = DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_LP_NORMALIZATION>
{
@ -1425,6 +1575,12 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_IS_NAN>
using DescType = DML_ELEMENT_WISE_IS_NAN_OPERATOR_DESC;
};
// Resolves DML_OPERATOR_ELEMENT_WISE_NEGATE to its descriptor struct via DescType.
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_NEGATE>
{
using DescType = DML_ELEMENT_WISE_NEGATE_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_ERF>
{
@ -1563,6 +1719,12 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_CUMULATIVE_SUMMATION>
using DescType = DML_CUMULATIVE_SUMMATION_OPERATOR_DESC;
};
// Resolves DML_OPERATOR_CUMULATIVE_PRODUCT to its descriptor struct via DescType.
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_CUMULATIVE_PRODUCT>
{
using DescType = DML_CUMULATIVE_PRODUCT_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_REVERSE_SUBSEQUENCES>
{
@ -1737,12 +1899,42 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ROI_ALIGN>
using DescType = DML_ROI_ALIGN_OPERATOR_DESC;
};
// Resolves DML_OPERATOR_ROI_ALIGN1 to its descriptor struct via DescType.
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ROI_ALIGN1>
{
using DescType = DML_ROI_ALIGN1_OPERATOR_DESC;
};
// Resolves DML_OPERATOR_GATHER_ND1 to its descriptor struct via DescType.
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_GATHER_ND1>
{
using DescType = DML_GATHER_ND1_OPERATOR_DESC;
};
// Resolves DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR to its descriptor struct via DescType.
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR>
{
using DescType = DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_DESC;
};
// Resolves DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD to its descriptor struct via DescType.
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD>
{
using DescType = DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_DESC;
};
// Resolves DML_OPERATOR_ROI_ALIGN_GRAD to its descriptor struct via DescType.
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ROI_ALIGN_GRAD>
{
using DescType = DML_ROI_ALIGN_GRAD_OPERATOR_DESC;
};
// Resolves DML_OPERATOR_BATCH_NORMALIZATION_TRAINING to its descriptor struct via DescType.
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_BATCH_NORMALIZATION_TRAINING>
{
using DescType = DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_DESC;
};
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_ELU>
{
@ -1894,6 +2086,12 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_CEIL_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ELEMENT_WISE_CLIP:
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_CLIP_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ELEMENT_WISE_CLIP1:
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_CLIP1_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD:
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1:
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ELEMENT_WISE_COS:
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_COS_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ELEMENT_WISE_DIVIDE:
@ -1940,6 +2138,10 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_SIN_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ELEMENT_WISE_SQRT:
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_SQRT_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE:
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_DIFFERENCE_SQUARE_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ELEMENT_WISE_ATAN_YX:
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_ATAN_YX_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ELEMENT_WISE_SUBTRACT:
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_SUBTRACT_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ELEMENT_WISE_TAN:
@ -1980,6 +2182,8 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
return std::invoke(std::forward<Visitor>(visitor), DML_JOIN_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_PADDING:
return std::invoke(std::forward<Visitor>(visitor), DML_PADDING_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_PADDING1:
return std::invoke(std::forward<Visitor>(visitor), DML_PADDING1_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_VALUE_SCALE_2D:
return std::invoke(std::forward<Visitor>(visitor), DML_VALUE_SCALE_2D_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_UPSAMPLE_2D:
@ -1996,10 +2200,16 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
return std::invoke(std::forward<Visitor>(visitor), DML_TOP_K_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_BATCH_NORMALIZATION:
return std::invoke(std::forward<Visitor>(visitor), DML_BATCH_NORMALIZATION_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_BATCH_NORMALIZATION_GRAD:
return std::invoke(std::forward<Visitor>(visitor), DML_BATCH_NORMALIZATION_GRAD_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD:
return std::invoke(std::forward<Visitor>(visitor), DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION:
return std::invoke(std::forward<Visitor>(visitor), DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION:
return std::invoke(std::forward<Visitor>(visitor), DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD:
return std::invoke(std::forward<Visitor>(visitor), DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_LP_NORMALIZATION:
return std::invoke(std::forward<Visitor>(visitor), DML_LP_NORMALIZATION_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_RNN:
@ -2012,6 +2222,8 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_SIGN_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ELEMENT_WISE_IS_NAN:
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_IS_NAN_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ELEMENT_WISE_NEGATE:
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_NEGATE_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ELEMENT_WISE_ERF:
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_ERF_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ELEMENT_WISE_SINH:
@ -2058,6 +2270,8 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
return std::invoke(std::forward<Visitor>(visitor), DML_FILL_VALUE_SEQUENCE_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_CUMULATIVE_SUMMATION:
return std::invoke(std::forward<Visitor>(visitor), DML_CUMULATIVE_SUMMATION_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_CUMULATIVE_PRODUCT:
return std::invoke(std::forward<Visitor>(visitor), DML_CUMULATIVE_PRODUCT_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_REVERSE_SUBSEQUENCES:
return std::invoke(std::forward<Visitor>(visitor), DML_REVERSE_SUBSEQUENCES_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_GATHER_ELEMENTS:
@ -2116,8 +2330,18 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
return std::invoke(std::forward<Visitor>(visitor), DML_ADAM_OPTIMIZER_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ROI_ALIGN:
return std::invoke(std::forward<Visitor>(visitor), DML_ROI_ALIGN_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ROI_ALIGN1:
return std::invoke(std::forward<Visitor>(visitor), DML_ROI_ALIGN1_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_GATHER_ND1:
return std::invoke(std::forward<Visitor>(visitor), DML_GATHER_ND1_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR:
return std::invoke(std::forward<Visitor>(visitor), DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD:
return std::invoke(std::forward<Visitor>(visitor), DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ROI_ALIGN_GRAD:
return std::invoke(std::forward<Visitor>(visitor), DML_ROI_ALIGN_GRAD_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_BATCH_NORMALIZATION_TRAINING:
return std::invoke(std::forward<Visitor>(visitor), DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ACTIVATION_ELU:
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_ELU_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ACTIVATION_CELU:
@ -2180,6 +2404,9 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value)
case DML_OPERATOR_ELEMENT_WISE_ATAN: return "DML_OPERATOR_ELEMENT_WISE_ATAN";
case DML_OPERATOR_ELEMENT_WISE_CEIL: return "DML_OPERATOR_ELEMENT_WISE_CEIL";
case DML_OPERATOR_ELEMENT_WISE_CLIP: return "DML_OPERATOR_ELEMENT_WISE_CLIP";
case DML_OPERATOR_ELEMENT_WISE_CLIP1: return "DML_OPERATOR_ELEMENT_WISE_CLIP1";
case DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD: return "DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD";
case DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1: return "DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1";
case DML_OPERATOR_ELEMENT_WISE_COS: return "DML_OPERATOR_ELEMENT_WISE_COS";
case DML_OPERATOR_ELEMENT_WISE_DIVIDE: return "DML_OPERATOR_ELEMENT_WISE_DIVIDE";
case DML_OPERATOR_ELEMENT_WISE_EXP: return "DML_OPERATOR_ELEMENT_WISE_EXP";
@ -2203,6 +2430,8 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value)
case DML_OPERATOR_ELEMENT_WISE_RECIP: return "DML_OPERATOR_ELEMENT_WISE_RECIP";
case DML_OPERATOR_ELEMENT_WISE_SIN: return "DML_OPERATOR_ELEMENT_WISE_SIN";
case DML_OPERATOR_ELEMENT_WISE_SQRT: return "DML_OPERATOR_ELEMENT_WISE_SQRT";
case DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE: return "DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE";
case DML_OPERATOR_ELEMENT_WISE_ATAN_YX: return "DML_OPERATOR_ELEMENT_WISE_ATAN_YX";
case DML_OPERATOR_ELEMENT_WISE_SUBTRACT: return "DML_OPERATOR_ELEMENT_WISE_SUBTRACT";
case DML_OPERATOR_ELEMENT_WISE_TAN: return "DML_OPERATOR_ELEMENT_WISE_TAN";
case DML_OPERATOR_ELEMENT_WISE_THRESHOLD: return "DML_OPERATOR_ELEMENT_WISE_THRESHOLD";
@ -2223,6 +2452,7 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value)
case DML_OPERATOR_SPLIT: return "DML_OPERATOR_SPLIT";
case DML_OPERATOR_JOIN: return "DML_OPERATOR_JOIN";
case DML_OPERATOR_PADDING: return "DML_OPERATOR_PADDING";
case DML_OPERATOR_PADDING1: return "DML_OPERATOR_PADDING1";
case DML_OPERATOR_VALUE_SCALE_2D: return "DML_OPERATOR_VALUE_SCALE_2D";
case DML_OPERATOR_UPSAMPLE_2D: return "DML_OPERATOR_UPSAMPLE_2D";
case DML_OPERATOR_GATHER: return "DML_OPERATOR_GATHER";
@ -2231,14 +2461,18 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value)
case DML_OPERATOR_TILE: return "DML_OPERATOR_TILE";
case DML_OPERATOR_TOP_K: return "DML_OPERATOR_TOP_K";
case DML_OPERATOR_BATCH_NORMALIZATION: return "DML_OPERATOR_BATCH_NORMALIZATION";
case DML_OPERATOR_BATCH_NORMALIZATION_GRAD: return "DML_OPERATOR_BATCH_NORMALIZATION_GRAD";
case DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD: return "DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD";
case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION: return "DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION";
case DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION: return "DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION";
case DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD: return "DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD";
case DML_OPERATOR_LP_NORMALIZATION: return "DML_OPERATOR_LP_NORMALIZATION";
case DML_OPERATOR_RNN: return "DML_OPERATOR_RNN";
case DML_OPERATOR_LSTM: return "DML_OPERATOR_LSTM";
case DML_OPERATOR_GRU: return "DML_OPERATOR_GRU";
case DML_OPERATOR_ELEMENT_WISE_SIGN: return "DML_OPERATOR_ELEMENT_WISE_SIGN";
case DML_OPERATOR_ELEMENT_WISE_IS_NAN: return "DML_OPERATOR_ELEMENT_WISE_IS_NAN";
case DML_OPERATOR_ELEMENT_WISE_NEGATE: return "DML_OPERATOR_ELEMENT_WISE_NEGATE";
case DML_OPERATOR_ELEMENT_WISE_ERF: return "DML_OPERATOR_ELEMENT_WISE_ERF";
case DML_OPERATOR_ELEMENT_WISE_SINH: return "DML_OPERATOR_ELEMENT_WISE_SINH";
case DML_OPERATOR_ELEMENT_WISE_COSH: return "DML_OPERATOR_ELEMENT_WISE_COSH";
@ -2262,6 +2496,7 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value)
case DML_OPERATOR_FILL_VALUE_CONSTANT: return "DML_OPERATOR_FILL_VALUE_CONSTANT";
case DML_OPERATOR_FILL_VALUE_SEQUENCE: return "DML_OPERATOR_FILL_VALUE_SEQUENCE";
case DML_OPERATOR_CUMULATIVE_SUMMATION: return "DML_OPERATOR_CUMULATIVE_SUMMATION";
case DML_OPERATOR_CUMULATIVE_PRODUCT: return "DML_OPERATOR_CUMULATIVE_PRODUCT";
case DML_OPERATOR_REVERSE_SUBSEQUENCES: return "DML_OPERATOR_REVERSE_SUBSEQUENCES";
case DML_OPERATOR_GATHER_ELEMENTS: return "DML_OPERATOR_GATHER_ELEMENTS";
case DML_OPERATOR_GATHER_ND: return "DML_OPERATOR_GATHER_ND";
@ -2291,7 +2526,12 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value)
case DML_OPERATOR_SLICE_GRAD: return "DML_OPERATOR_SLICE_GRAD";
case DML_OPERATOR_ADAM_OPTIMIZER: return "DML_OPERATOR_ADAM_OPTIMIZER";
case DML_OPERATOR_ROI_ALIGN: return "DML_OPERATOR_ROI_ALIGN";
case DML_OPERATOR_ROI_ALIGN1: return "DML_OPERATOR_ROI_ALIGN1";
case DML_OPERATOR_GATHER_ND1: return "DML_OPERATOR_GATHER_ND1";
case DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR: return "DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR";
case DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD: return "DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD";
case DML_OPERATOR_ROI_ALIGN_GRAD: return "DML_OPERATOR_ROI_ALIGN_GRAD";
case DML_OPERATOR_BATCH_NORMALIZATION_TRAINING: return "DML_OPERATOR_BATCH_NORMALIZATION_TRAINING";
default:
assert(false);
return "<unknown>";

View file

@ -28,6 +28,7 @@ enum DML_SCHEMA_FIELD_TYPE
DML_SCHEMA_FIELD_TYPE_SCALE_BIAS,
DML_SCHEMA_FIELD_TYPE_SIZE_2D,
DML_SCHEMA_FIELD_TYPE_SCALAR_UNION,
DML_SCHEMA_FIELD_TYPE_BOOL,
};
enum DML_SCHEMA_OPERATOR_SUPPORT_FLAGS
@ -170,6 +171,56 @@ constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_CLIP_OPERATOR_SCHEMA {
DML_ELEMENT_WISE_CLIP_OPERATOR_SCHEMA_FIELDS,
};
// Field table for DML_OPERATOR_ELEMENT_WISE_CLIP1: one input tensor, one output tensor,
// then the ScaleBias, MinMaxDataType, Min and Max attributes. Min and Max are scalar
// unions rather than floats.
// NOTE(review): the trailing bool is presumably "is optional" (true only for ScaleBias) —
// confirm against the DML_SCHEMA_FIELD definition.
constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_CLIP1_OPERATOR_SCHEMA_FIELDS[6] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALE_BIAS, "ScaleBias", true },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "MinMaxDataType", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALAR_UNION, "Min", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALAR_UNION, "Max", false },
};
// Schema record for DML_OPERATOR_ELEMENT_WISE_CLIP1: name string, operator type,
// support flags (in-place execution), field count (6, matching the array above),
// and the field table.
constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_CLIP1_OPERATOR_SCHEMA {
"DML_OPERATOR_ELEMENT_WISE_CLIP1",
DML_OPERATOR_ELEMENT_WISE_CLIP1,
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION,
6,
DML_ELEMENT_WISE_CLIP1_OPERATOR_SCHEMA_FIELDS,
};
// Field table for DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD: two input tensors (InputTensor,
// InputGradientTensor), one output tensor (OutputGradientTensor), and float Min/Max attributes.
// NOTE(review): the trailing bool is presumably "is optional" (false for every field here) —
// confirm against the DML_SCHEMA_FIELD definition.
constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_SCHEMA_FIELDS[5] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputGradientTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputGradientTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Min", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Max", false },
};
// Schema record for DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD: name string, operator type,
// support flags (in-place execution), field count (5, matching the array above),
// and the field table.
constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_SCHEMA {
"DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD",
DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD,
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION,
5,
DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_SCHEMA_FIELDS,
};
// Field table for DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1: two input tensors (InputTensor,
// InputGradientTensor), one output tensor (OutputGradientTensor), a MinMaxDataType
// attribute, and Min/Max attributes stored as scalar unions.
// NOTE(review): the trailing bool is presumably "is optional" (false for every field here) —
// confirm against the DML_SCHEMA_FIELD definition.
constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_SCHEMA_FIELDS[6] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputGradientTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputGradientTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "MinMaxDataType", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALAR_UNION, "Min", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_SCALAR_UNION, "Max", false },
};
// Schema record for DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1: name string, operator type,
// support flags (in-place execution), field count (6, matching the array above),
// and the field table.
constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_SCHEMA {
"DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1",
DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1,
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION,
6,
DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_SCHEMA_FIELDS,
};
constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_COS_OPERATOR_SCHEMA_FIELDS[3] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
@ -493,6 +544,34 @@ constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_SQRT_OPERATOR_SCHEMA {
DML_ELEMENT_WISE_SQRT_OPERATOR_SCHEMA_FIELDS,
};
// Field table for DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE: two input tensors
// (ATensor, BTensor) and one output tensor; no attributes.
constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_DIFFERENCE_SQUARE_OPERATOR_SCHEMA_FIELDS[3] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
};
// Schema record for DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE: name string, operator type,
// support flags (in-place execution), field count (3, matching the array above),
// and the field table.
constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_DIFFERENCE_SQUARE_OPERATOR_SCHEMA {
"DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE",
DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE,
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION,
3,
DML_ELEMENT_WISE_DIFFERENCE_SQUARE_OPERATOR_SCHEMA_FIELDS,
};
// Field table for DML_OPERATOR_ELEMENT_WISE_ATAN_YX: two input tensors
// (ATensor, BTensor) and one output tensor; no attributes.
// NOTE(review): which of ATensor/BTensor is the Y operand is not visible here — confirm
// against the DirectML operator documentation.
constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_ATAN_YX_OPERATOR_SCHEMA_FIELDS[3] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
};
// Schema record for DML_OPERATOR_ELEMENT_WISE_ATAN_YX: name string, operator type,
// support flags (in-place execution), field count (3, matching the array above),
// and the field table.
constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_ATAN_YX_OPERATOR_SCHEMA {
"DML_OPERATOR_ELEMENT_WISE_ATAN_YX",
DML_OPERATOR_ELEMENT_WISE_ATAN_YX,
DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION,
3,
DML_ELEMENT_WISE_ATAN_YX_OPERATOR_SCHEMA_FIELDS,
};
constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_SUBTRACT_OPERATOR_SCHEMA_FIELDS[3] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false },
@ -828,6 +907,25 @@ constexpr DML_OPERATOR_SCHEMA DML_PADDING_OPERATOR_SCHEMA {
DML_PADDING_OPERATOR_SCHEMA_FIELDS,
};
// Field table for DML_OPERATOR_PADDING1: one input and one output tensor, then six attributes.
// The padding value is a scalar union, preceded by a UINT field naming its data type.
constexpr DML_SCHEMA_FIELD DML_PADDING1_OPERATOR_SCHEMA_FIELDS[8] {
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC,  "InputTensor",          false },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC,  "OutputTensor",         false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE,     DML_SCHEMA_FIELD_TYPE_UINT,         "PaddingMode",          false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE,     DML_SCHEMA_FIELD_TYPE_UINT,         "PaddingValueDataType", false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE,     DML_SCHEMA_FIELD_TYPE_SCALAR_UNION, "PaddingValue",         false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE,     DML_SCHEMA_FIELD_TYPE_UINT,         "DimensionCount",       false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE,     DML_SCHEMA_FIELD_TYPE_UINT_ARRAY,   "StartPadding",         false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE,     DML_SCHEMA_FIELD_TYPE_UINT_ARRAY,   "EndPadding",           false },
};
// Schema record for DML_OPERATOR_PADDING1; no support flags, 8-entry field table.
constexpr DML_OPERATOR_SCHEMA DML_PADDING1_OPERATOR_SCHEMA = {
    "DML_OPERATOR_PADDING1",                // operator name string
    DML_OPERATOR_PADDING1,                  // operator type enumerant
    DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,  // support flags
    8,                                      // entries in the field table
    DML_PADDING1_OPERATOR_SCHEMA_FIELDS,    // field table
};
constexpr DML_SCHEMA_FIELD DML_VALUE_SCALE_2D_OPERATOR_SCHEMA_FIELDS[5] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
@ -954,6 +1052,46 @@ constexpr DML_OPERATOR_SCHEMA DML_BATCH_NORMALIZATION_OPERATOR_SCHEMA {
DML_BATCH_NORMALIZATION_OPERATOR_SCHEMA_FIELDS,
};
// Field table for DML_OPERATOR_BATCH_NORMALIZATION_GRAD:
// five input tensors, three output gradient tensors, and the Epsilon float attribute
// (the only entry whose trailing flag is true).
constexpr DML_SCHEMA_FIELD DML_BATCH_NORMALIZATION_GRAD_OPERATOR_SCHEMA_FIELDS[9] {
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor",               false },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputGradientTensor",       false },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "MeanTensor",                false },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "VarianceTensor",            false },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ScaleTensor",               false },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputGradientTensor",      false },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleGradientTensor", false },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputBiasGradientTensor",  false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE,     DML_SCHEMA_FIELD_TYPE_FLOAT,       "Epsilon",                   true },
};
// Schema record for DML_OPERATOR_BATCH_NORMALIZATION_GRAD; no support flags, 9-entry field table.
constexpr DML_OPERATOR_SCHEMA DML_BATCH_NORMALIZATION_GRAD_OPERATOR_SCHEMA = {
    "DML_OPERATOR_BATCH_NORMALIZATION_GRAD",              // operator name string
    DML_OPERATOR_BATCH_NORMALIZATION_GRAD,                // operator type enumerant
    DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,                // support flags
    9,                                                    // entries in the field table
    DML_BATCH_NORMALIZATION_GRAD_OPERATOR_SCHEMA_FIELDS,  // field table
};
// Field table for DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD:
// five input tensors, three output gradient tensors, and the Epsilon float attribute
// (the only entry whose trailing flag is true).
constexpr DML_SCHEMA_FIELD DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA_FIELDS[9] {
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor",               false },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputGradientTensor",       false },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "MeanTensor",                false },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "VarianceTensor",            false },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ScaleTensor",               false },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputGradientTensor",      false },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleGradientTensor", false },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputBiasGradientTensor",  false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE,     DML_SCHEMA_FIELD_TYPE_FLOAT,       "Epsilon",                   true },
};
// Schema record for DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD; no support flags, 9-entry field table.
constexpr DML_OPERATOR_SCHEMA DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA = {
    "DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD",              // operator name string
    DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD,                // operator type enumerant
    DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,                         // support flags
    9,                                                             // entries in the field table
    DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA_FIELDS,  // field table
};
constexpr DML_SCHEMA_FIELD DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_SCHEMA_FIELDS[8] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ScaleTensor", true },
@ -991,6 +1129,25 @@ constexpr DML_OPERATOR_SCHEMA DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_SCHEMA {
DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_SCHEMA_FIELDS,
};
// Field table for DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD:
// two input tensors, one output gradient tensor, then five attributes
// (a bool, a UINT and three floats).
constexpr DML_SCHEMA_FIELD DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_SCHEMA_FIELDS[8] {
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor",          false },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputGradientTensor",  false },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputGradientTensor", false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE,     DML_SCHEMA_FIELD_TYPE_BOOL,        "CrossChannel",         false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE,     DML_SCHEMA_FIELD_TYPE_UINT,        "LocalSize",            false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE,     DML_SCHEMA_FIELD_TYPE_FLOAT,       "Alpha",                false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE,     DML_SCHEMA_FIELD_TYPE_FLOAT,       "Beta",                 false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE,     DML_SCHEMA_FIELD_TYPE_FLOAT,       "Bias",                 false },
};
// Schema record for DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD; no support flags, 8-entry field table.
constexpr DML_OPERATOR_SCHEMA DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_SCHEMA = {
    "DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD",              // operator name string
    DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD,                // operator type enumerant
    DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,                         // support flags
    8,                                                             // entries in the field table
    DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_SCHEMA_FIELDS,  // field table
};
constexpr DML_SCHEMA_FIELD DML_LP_NORMALIZATION_OPERATOR_SCHEMA_FIELDS[5] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
@ -1106,6 +1263,19 @@ constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_IS_NAN_OPERATOR_SCHEMA {
DML_ELEMENT_WISE_IS_NAN_OPERATOR_SCHEMA_FIELDS,
};
// Field table for DML_OPERATOR_ELEMENT_WISE_NEGATE: one input tensor, one output tensor.
constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_NEGATE_OPERATOR_SCHEMA_FIELDS[2] {
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor",  false },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
};
// Schema record for DML_OPERATOR_ELEMENT_WISE_NEGATE; no support flags, 2-entry field table.
constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_NEGATE_OPERATOR_SCHEMA = {
    "DML_OPERATOR_ELEMENT_WISE_NEGATE",              // operator name string
    DML_OPERATOR_ELEMENT_WISE_NEGATE,                // operator type enumerant
    DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,           // support flags
    2,                                               // entries in the field table
    DML_ELEMENT_WISE_NEGATE_OPERATOR_SCHEMA_FIELDS,  // field table
};
constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_ERF_OPERATOR_SCHEMA_FIELDS[3] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
@ -1426,8 +1596,8 @@ constexpr DML_SCHEMA_FIELD DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA_FIELDS[5] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "Axis", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "HasExclusiveSum", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "AxisDirection", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "HasExclusiveSum", false },
};
constexpr DML_OPERATOR_SCHEMA DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA {
@ -1438,6 +1608,22 @@ constexpr DML_OPERATOR_SCHEMA DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA {
DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA_FIELDS,
};
// Field table for DML_OPERATOR_CUMULATIVE_PRODUCT:
// one input tensor, one output tensor, then three UINT attributes
// in the order Axis, AxisDirection, HasExclusiveProduct.
constexpr DML_SCHEMA_FIELD DML_CUMULATIVE_PRODUCT_OPERATOR_SCHEMA_FIELDS[5] {
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor",         false },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor",        false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE,     DML_SCHEMA_FIELD_TYPE_UINT,        "Axis",                false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE,     DML_SCHEMA_FIELD_TYPE_UINT,        "AxisDirection",       false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE,     DML_SCHEMA_FIELD_TYPE_UINT,        "HasExclusiveProduct", false },
};
// Schema record for DML_OPERATOR_CUMULATIVE_PRODUCT.
// Carries the in-place-execution support flag and points at the 5-entry field table.
constexpr DML_OPERATOR_SCHEMA DML_CUMULATIVE_PRODUCT_OPERATOR_SCHEMA = {
    "DML_OPERATOR_CUMULATIVE_PRODUCT",                    // operator name string
    DML_OPERATOR_CUMULATIVE_PRODUCT,                      // operator type enumerant
    DML_SCHEMA_OPERATOR_SUPPORT_FLAG_IN_PLACE_EXECUTION,  // support flags
    5,                                                    // entries in the field table
    DML_CUMULATIVE_PRODUCT_OPERATOR_SCHEMA_FIELDS,        // field table
};
constexpr DML_SCHEMA_FIELD DML_REVERSE_SUBSEQUENCES_OPERATOR_SCHEMA_FIELDS[4] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "SequenceLengthsTensor", false },
@ -1448,7 +1634,7 @@ constexpr DML_SCHEMA_FIELD DML_REVERSE_SUBSEQUENCES_OPERATOR_SCHEMA_FIELDS[4] {
// Schema record for DML_OPERATOR_REVERSE_SUBSEQUENCES; 4-entry field table.
// The support flags are DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE: the superseded
// IN_PLACE_EXECUTION initializer that sat beside it has been dropped, so the
// record again has exactly one value per member (name, type, flags, count, fields).
constexpr DML_OPERATOR_SCHEMA DML_REVERSE_SUBSEQUENCES_OPERATOR_SCHEMA {
    "DML_OPERATOR_REVERSE_SUBSEQUENCES",
    DML_OPERATOR_REVERSE_SUBSEQUENCES,
    DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
    4,
    DML_REVERSE_SUBSEQUENCES_OPERATOR_SCHEMA_FIELDS,
};
@ -1809,17 +1995,23 @@ constexpr DML_OPERATOR_SCHEMA DML_AVERAGE_POOLING_GRAD_OPERATOR_SCHEMA {
DML_AVERAGE_POOLING_GRAD_OPERATOR_SCHEMA_FIELDS,
};
// Field table for DML_OPERATOR_MAX_POOLING_GRAD:
// two input tensors, one output gradient tensor, then six attributes
// (DimensionCount plus five UINT arrays describing the pooling window).
// Declared once with a bound of 9 to match the nine entries below; the stale
// three-entry declaration line that preceded it has been removed.
constexpr DML_SCHEMA_FIELD DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA_FIELDS[9] {
    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputGradientTensor", false },
    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputGradientTensor", false },
    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false },
    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false },
    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false },
    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false },
    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false },
    DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Dilations", false },
};
// Schema record for DML_OPERATOR_MAX_POOLING_GRAD; no support flags.
// The field count is 9, matching the nine-entry field table; the stale count
// of 3 that sat beside it has been removed so the record has one value per member.
constexpr DML_OPERATOR_SCHEMA DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA {
    "DML_OPERATOR_MAX_POOLING_GRAD",
    DML_OPERATOR_MAX_POOLING_GRAD,
    DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
    9,
    DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA_FIELDS,
};
@ -1932,6 +2124,31 @@ constexpr DML_OPERATOR_SCHEMA DML_ROI_ALIGN_OPERATOR_SCHEMA {
DML_ROI_ALIGN_OPERATOR_SCHEMA_FIELDS,
};
// Field table for DML_OPERATOR_ROI_ALIGN1:
// three input tensors, one output tensor, then ten UINT/float attributes.
constexpr DML_SCHEMA_FIELD DML_ROI_ALIGN1_OPERATOR_SCHEMA_FIELDS[14] {
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor",             false },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ROITensor",               false },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BatchIndicesTensor",      false },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor",            false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE,     DML_SCHEMA_FIELD_TYPE_UINT,        "ReductionFunction",       false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE,     DML_SCHEMA_FIELD_TYPE_UINT,        "InterpolationMode",       false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE,     DML_SCHEMA_FIELD_TYPE_FLOAT,       "SpatialScaleX",           false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE,     DML_SCHEMA_FIELD_TYPE_FLOAT,       "SpatialScaleY",           false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE,     DML_SCHEMA_FIELD_TYPE_FLOAT,       "InputPixelOffset",        false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE,     DML_SCHEMA_FIELD_TYPE_FLOAT,       "OutputPixelOffset",       false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE,     DML_SCHEMA_FIELD_TYPE_FLOAT,       "OutOfBoundsInputValue",   false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE,     DML_SCHEMA_FIELD_TYPE_UINT,        "MinimumSamplesPerOutput", false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE,     DML_SCHEMA_FIELD_TYPE_UINT,        "MaximumSamplesPerOutput", false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE,     DML_SCHEMA_FIELD_TYPE_UINT,        "AlignRegionsToCorners",   false },
};
// Schema record for DML_OPERATOR_ROI_ALIGN1; no support flags, 14-entry field table.
constexpr DML_OPERATOR_SCHEMA DML_ROI_ALIGN1_OPERATOR_SCHEMA = {
    "DML_OPERATOR_ROI_ALIGN1",              // operator name string
    DML_OPERATOR_ROI_ALIGN1,                // operator type enumerant
    DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,  // support flags
    14,                                     // entries in the field table
    DML_ROI_ALIGN1_OPERATOR_SCHEMA_FIELDS,  // field table
};
constexpr DML_SCHEMA_FIELD DML_GATHER_ND1_OPERATOR_SCHEMA_FIELDS[6] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "IndicesTensor", false },
@ -1949,6 +2166,87 @@ constexpr DML_OPERATOR_SCHEMA DML_GATHER_ND1_OPERATOR_SCHEMA {
DML_GATHER_ND1_OPERATOR_SCHEMA_FIELDS,
};
// Field table for DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR:
// one input tensor and three output tensors (data, scale, zero point).
constexpr DML_SCHEMA_FIELD DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA_FIELDS[4] {
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor",           false },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor",          false },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleTensor",     false },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputZeroPointTensor", false },
};
// Schema record for DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR; no support flags, 4-entry field table.
constexpr DML_OPERATOR_SCHEMA DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA = {
    "DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR",              // operator name string
    DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR,                // operator type enumerant
    DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,               // support flags
    4,                                                   // entries in the field table
    DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA_FIELDS,  // field table
};
// Field table for DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD:
// eight input tensors (data, scale and zero point for A, B and the output
// quantization) followed by one output tensor. The three zero-point inputs are
// the only entries whose trailing flag is true.
constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA_FIELDS[9] {
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor",               false },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AScaleTensor",          false },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AZeroPointTensor",      true },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor",               false },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BScaleTensor",          false },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BZeroPointTensor",      true },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleTensor",     false },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputZeroPointTensor", true },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor",          false },
};
// Schema record for DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD; no support flags, 9-entry field table.
constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA = {
    "DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD",              // operator name string
    DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD,                // operator type enumerant
    DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,                         // support flags
    9,                                                             // entries in the field table
    DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA_FIELDS,  // field table
};
// Field table for DML_OPERATOR_ROI_ALIGN_GRAD:
// four input tensors, two output gradient tensors, then nine UINT/float attributes.
// InputTensor and both output gradient tensors are the entries whose trailing flag is true.
constexpr DML_SCHEMA_FIELD DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA_FIELDS[15] {
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor",             true },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputGradientTensor",     false },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ROITensor",               false },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BatchIndicesTensor",      false },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputGradientTensor",    true },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputROIGradientTensor", true },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE,     DML_SCHEMA_FIELD_TYPE_UINT,        "ReductionFunction",       false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE,     DML_SCHEMA_FIELD_TYPE_UINT,        "InterpolationMode",       false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE,     DML_SCHEMA_FIELD_TYPE_FLOAT,       "SpatialScaleX",           false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE,     DML_SCHEMA_FIELD_TYPE_FLOAT,       "SpatialScaleY",           false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE,     DML_SCHEMA_FIELD_TYPE_FLOAT,       "InputPixelOffset",        false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE,     DML_SCHEMA_FIELD_TYPE_FLOAT,       "OutputPixelOffset",       false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE,     DML_SCHEMA_FIELD_TYPE_UINT,        "MinimumSamplesPerOutput", false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE,     DML_SCHEMA_FIELD_TYPE_UINT,        "MaximumSamplesPerOutput", false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE,     DML_SCHEMA_FIELD_TYPE_UINT,        "AlignRegionsToCorners",   false },
};
// Schema record for DML_OPERATOR_ROI_ALIGN_GRAD; no support flags, 15-entry field table.
constexpr DML_OPERATOR_SCHEMA DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA = {
    "DML_OPERATOR_ROI_ALIGN_GRAD",              // operator name string
    DML_OPERATOR_ROI_ALIGN_GRAD,                // operator type enumerant
    DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,      // support flags
    15,                                         // entries in the field table
    DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA_FIELDS,  // field table
};
// Field table for DML_OPERATOR_BATCH_NORMALIZATION_TRAINING:
// four input tensors, three output tensors, the Epsilon float attribute and a
// fused-activation operator descriptor. FusedAddTensor and FusedActivation are
// the entries whose trailing flag is true.
constexpr DML_SCHEMA_FIELD DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA_FIELDS[9] {
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC,   "InputTensor",          false },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC,   "ScaleTensor",          false },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC,   "BiasTensor",           false },
    { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR,  DML_SCHEMA_FIELD_TYPE_TENSOR_DESC,   "FusedAddTensor",       true },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC,   "OutputTensor",         false },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC,   "OutputMeanTensor",     false },
    { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC,   "OutputVarianceTensor", false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE,     DML_SCHEMA_FIELD_TYPE_FLOAT,         "Epsilon",              false },
    { DML_SCHEMA_FIELD_KIND_ATTRIBUTE,     DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC, "FusedActivation",      true },
};
// Schema record for DML_OPERATOR_BATCH_NORMALIZATION_TRAINING; no support flags, 9-entry field table.
constexpr DML_OPERATOR_SCHEMA DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA = {
    "DML_OPERATOR_BATCH_NORMALIZATION_TRAINING",              // operator name string
    DML_OPERATOR_BATCH_NORMALIZATION_TRAINING,                // operator type enumerant
    DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,                    // support flags
    9,                                                        // entries in the field table
    DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA_FIELDS,  // field table
};
constexpr DML_SCHEMA_FIELD DML_ACTIVATION_ELU_OPERATOR_SCHEMA_FIELDS[3] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },

View file

@ -73,6 +73,38 @@ inline std::vector<OperatorField> GetFields(const DML_ELEMENT_WISE_CLIP_OPERATOR
OperatorField(&DML_ELEMENT_WISE_CLIP_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<FLOAT>(desc.Max))),
};
}
inline std::vector<OperatorField> GetFields(const DML_ELEMENT_WISE_CLIP1_OPERATOR_DESC& desc)
{
return {
OperatorField(&DML_ELEMENT_WISE_CLIP1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
OperatorField(&DML_ELEMENT_WISE_CLIP1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
OperatorField(&DML_ELEMENT_WISE_CLIP1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_SCALE_BIAS*>(desc.ScaleBias))),
OperatorField(&DML_ELEMENT_WISE_CLIP1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<UINT>(desc.MinMaxDataType))),
OperatorField(&DML_ELEMENT_WISE_CLIP1_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<DML_SCALAR_UNION>(desc.Min))),
OperatorField(&DML_ELEMENT_WISE_CLIP1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<DML_SCALAR_UNION>(desc.Max))),
};
}
inline std::vector<OperatorField> GetFields(const DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_DESC& desc)
{
return {
OperatorField(&DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
OperatorField(&DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputGradientTensor))),
OperatorField(&DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputGradientTensor))),
OperatorField(&DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<FLOAT>(desc.Min))),
OperatorField(&DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<FLOAT>(desc.Max))),
};
}
inline std::vector<OperatorField> GetFields(const DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_DESC& desc)
{
return {
OperatorField(&DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
OperatorField(&DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputGradientTensor))),
OperatorField(&DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputGradientTensor))),
OperatorField(&DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<UINT>(desc.MinMaxDataType))),
OperatorField(&DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<DML_SCALAR_UNION>(desc.Min))),
OperatorField(&DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<DML_SCALAR_UNION>(desc.Max))),
};
}
inline std::vector<OperatorField> GetFields(const DML_ELEMENT_WISE_COS_OPERATOR_DESC& desc)
{
return {
@ -258,6 +290,22 @@ inline std::vector<OperatorField> GetFields(const DML_ELEMENT_WISE_SQRT_OPERATOR
OperatorField(&DML_ELEMENT_WISE_SQRT_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_SCALE_BIAS*>(desc.ScaleBias))),
};
}
// Flattens a DML_ELEMENT_WISE_DIFFERENCE_SQUARE_OPERATOR_DESC into a list of OperatorField
// values: ATensor, BTensor, OutputTensor, each paired with the schema field at the same index.
inline std::vector<OperatorField> GetFields(const DML_ELEMENT_WISE_DIFFERENCE_SQUARE_OPERATOR_DESC& desc)
{
    const auto* const schemaFields = DML_ELEMENT_WISE_DIFFERENCE_SQUARE_OPERATOR_SCHEMA.Fields;
    return {
        OperatorField(&schemaFields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.ATensor))),
        OperatorField(&schemaFields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BTensor))),
        OperatorField(&schemaFields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
    };
}
// Flattens a DML_ELEMENT_WISE_ATAN_YX_OPERATOR_DESC into a list of OperatorField values:
// ATensor, BTensor, OutputTensor, each paired with the schema field at the same index.
inline std::vector<OperatorField> GetFields(const DML_ELEMENT_WISE_ATAN_YX_OPERATOR_DESC& desc)
{
    const auto* const schemaFields = DML_ELEMENT_WISE_ATAN_YX_OPERATOR_SCHEMA.Fields;
    return {
        OperatorField(&schemaFields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.ATensor))),
        OperatorField(&schemaFields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BTensor))),
        OperatorField(&schemaFields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
    };
}
inline std::vector<OperatorField> GetFields(const DML_ELEMENT_WISE_SUBTRACT_OPERATOR_DESC& desc)
{
return {
@ -473,6 +521,19 @@ inline std::vector<OperatorField> GetFields(const DML_PADDING_OPERATOR_DESC& des
OperatorField(&DML_PADDING_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<const UINT*>(desc.EndPadding), desc.DimensionCount)),
};
}
inline std::vector<OperatorField> GetFields(const DML_PADDING1_OPERATOR_DESC& desc)
{
return {
OperatorField(&DML_PADDING1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
OperatorField(&DML_PADDING1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
OperatorField(&DML_PADDING1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<UINT>(desc.PaddingMode))),
OperatorField(&DML_PADDING1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<UINT>(desc.PaddingValueDataType))),
OperatorField(&DML_PADDING1_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<DML_SCALAR_UNION>(desc.PaddingValue))),
OperatorField(&DML_PADDING1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<UINT>(desc.DimensionCount))),
OperatorField(&DML_PADDING1_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<const UINT*>(desc.StartPadding), desc.DimensionCount)),
OperatorField(&DML_PADDING1_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<const UINT*>(desc.EndPadding), desc.DimensionCount)),
};
}
inline std::vector<OperatorField> GetFields(const DML_VALUE_SCALE_2D_OPERATOR_DESC& desc)
{
return {
@ -551,6 +612,34 @@ inline std::vector<OperatorField> GetFields(const DML_BATCH_NORMALIZATION_OPERAT
OperatorField(&DML_BATCH_NORMALIZATION_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast<const DML_OPERATOR_DESC*>(desc.FusedActivation))),
};
}
inline std::vector<OperatorField> GetFields(const DML_BATCH_NORMALIZATION_GRAD_OPERATOR_DESC& desc)
{
return {
OperatorField(&DML_BATCH_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
OperatorField(&DML_BATCH_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputGradientTensor))),
OperatorField(&DML_BATCH_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.MeanTensor))),
OperatorField(&DML_BATCH_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.VarianceTensor))),
OperatorField(&DML_BATCH_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.ScaleTensor))),
OperatorField(&DML_BATCH_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputGradientTensor))),
OperatorField(&DML_BATCH_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputScaleGradientTensor))),
OperatorField(&DML_BATCH_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputBiasGradientTensor))),
OperatorField(&DML_BATCH_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast<FLOAT>(desc.Epsilon))),
};
}
inline std::vector<OperatorField> GetFields(const DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_DESC& desc)
{
return {
OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputGradientTensor))),
OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.MeanTensor))),
OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.VarianceTensor))),
OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.ScaleTensor))),
OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputGradientTensor))),
OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputScaleGradientTensor))),
OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputBiasGradientTensor))),
OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast<FLOAT>(desc.Epsilon))),
};
}
inline std::vector<OperatorField> GetFields(const DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_DESC& desc)
{
return {
@ -576,6 +665,19 @@ inline std::vector<OperatorField> GetFields(const DML_LOCAL_RESPONSE_NORMALIZATI
OperatorField(&DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<FLOAT>(desc.Bias))),
};
}
inline std::vector<OperatorField> GetFields(const DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_DESC& desc)
{
return {
OperatorField(&DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
OperatorField(&DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputGradientTensor))),
OperatorField(&DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputGradientTensor))),
OperatorField(&DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<bool>(desc.CrossChannel))),
OperatorField(&DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<UINT>(desc.LocalSize))),
OperatorField(&DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<FLOAT>(desc.Alpha))),
OperatorField(&DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<FLOAT>(desc.Beta))),
OperatorField(&DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<FLOAT>(desc.Bias))),
};
}
inline std::vector<OperatorField> GetFields(const DML_LP_NORMALIZATION_OPERATOR_DESC& desc)
{
return {
@ -655,6 +757,13 @@ inline std::vector<OperatorField> GetFields(const DML_ELEMENT_WISE_IS_NAN_OPERAT
OperatorField(&DML_ELEMENT_WISE_IS_NAN_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
};
}
// Flattens a DML_ELEMENT_WISE_NEGATE_OPERATOR_DESC into a list of OperatorField values:
// InputTensor then OutputTensor, each paired with the schema field at the same index.
inline std::vector<OperatorField> GetFields(const DML_ELEMENT_WISE_NEGATE_OPERATOR_DESC& desc)
{
    const auto* const schemaFields = DML_ELEMENT_WISE_NEGATE_OPERATOR_SCHEMA.Fields;
    return {
        OperatorField(&schemaFields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
        OperatorField(&schemaFields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
    };
}
inline std::vector<OperatorField> GetFields(const DML_ELEMENT_WISE_ERF_OPERATOR_DESC& desc)
{
return {
@ -845,8 +954,18 @@ inline std::vector<OperatorField> GetFields(const DML_CUMULATIVE_SUMMATION_OPERA
OperatorField(&DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
OperatorField(&DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
OperatorField(&DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<UINT>(desc.Axis))),
OperatorField(&DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<UINT>(desc.HasExclusiveSum))),
OperatorField(&DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<UINT>(desc.AxisDirection))),
OperatorField(&DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<UINT>(desc.AxisDirection))),
OperatorField(&DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<UINT>(desc.HasExclusiveSum))),
};
}
inline std::vector<OperatorField> GetFields(const DML_CUMULATIVE_PRODUCT_OPERATOR_DESC& desc)
{
return {
OperatorField(&DML_CUMULATIVE_PRODUCT_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
OperatorField(&DML_CUMULATIVE_PRODUCT_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
OperatorField(&DML_CUMULATIVE_PRODUCT_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<UINT>(desc.Axis))),
OperatorField(&DML_CUMULATIVE_PRODUCT_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<UINT>(desc.AxisDirection))),
OperatorField(&DML_CUMULATIVE_PRODUCT_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<UINT>(desc.HasExclusiveProduct))),
};
}
inline std::vector<OperatorField> GetFields(const DML_REVERSE_SUBSEQUENCES_OPERATOR_DESC& desc)
@ -1094,6 +1213,12 @@ inline std::vector<OperatorField> GetFields(const DML_MAX_POOLING_GRAD_OPERATOR_
OperatorField(&DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
OperatorField(&DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputGradientTensor))),
OperatorField(&DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputGradientTensor))),
OperatorField(&DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<UINT>(desc.DimensionCount))),
OperatorField(&DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<const UINT*>(desc.Strides), desc.DimensionCount)),
OperatorField(&DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<const UINT*>(desc.WindowSize), desc.DimensionCount)),
OperatorField(&DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<const UINT*>(desc.StartPadding), desc.DimensionCount)),
OperatorField(&DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<const UINT*>(desc.EndPadding), desc.DimensionCount)),
OperatorField(&DML_MAX_POOLING_GRAD_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast<const UINT*>(desc.Dilations), desc.DimensionCount)),
};
}
inline std::vector<OperatorField> GetFields(const DML_RANDOM_GENERATOR_OPERATOR_DESC& desc)
@ -1169,6 +1294,25 @@ inline std::vector<OperatorField> GetFields(const DML_ROI_ALIGN_OPERATOR_DESC& d
OperatorField(&DML_ROI_ALIGN_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast<UINT>(desc.MaximumSamplesPerOutput))),
};
}
inline std::vector<OperatorField> GetFields(const DML_ROI_ALIGN1_OPERATOR_DESC& desc)
{
return {
OperatorField(&DML_ROI_ALIGN1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
OperatorField(&DML_ROI_ALIGN1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.ROITensor))),
OperatorField(&DML_ROI_ALIGN1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BatchIndicesTensor))),
OperatorField(&DML_ROI_ALIGN1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
OperatorField(&DML_ROI_ALIGN1_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<UINT>(desc.ReductionFunction))),
OperatorField(&DML_ROI_ALIGN1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<UINT>(desc.InterpolationMode))),
OperatorField(&DML_ROI_ALIGN1_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<FLOAT>(desc.SpatialScaleX))),
OperatorField(&DML_ROI_ALIGN1_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<FLOAT>(desc.SpatialScaleY))),
OperatorField(&DML_ROI_ALIGN1_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast<FLOAT>(desc.InputPixelOffset))),
OperatorField(&DML_ROI_ALIGN1_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast<FLOAT>(desc.OutputPixelOffset))),
OperatorField(&DML_ROI_ALIGN1_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast<FLOAT>(desc.OutOfBoundsInputValue))),
OperatorField(&DML_ROI_ALIGN1_OPERATOR_SCHEMA.Fields[11], ToOperatorFieldType(static_cast<UINT>(desc.MinimumSamplesPerOutput))),
OperatorField(&DML_ROI_ALIGN1_OPERATOR_SCHEMA.Fields[12], ToOperatorFieldType(static_cast<UINT>(desc.MaximumSamplesPerOutput))),
OperatorField(&DML_ROI_ALIGN1_OPERATOR_SCHEMA.Fields[13], ToOperatorFieldType(static_cast<UINT>(desc.AlignRegionsToCorners))),
};
}
inline std::vector<OperatorField> GetFields(const DML_GATHER_ND1_OPERATOR_DESC& desc)
{
return {
@ -1180,6 +1324,63 @@ inline std::vector<OperatorField> GetFields(const DML_GATHER_ND1_OPERATOR_DESC&
OperatorField(&DML_GATHER_ND1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<UINT>(desc.BatchDimensionCount))),
};
}
inline std::vector<OperatorField> GetFields(const DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_DESC& desc)
{
return {
OperatorField(&DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
OperatorField(&DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
OperatorField(&DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputScaleTensor))),
OperatorField(&DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputZeroPointTensor))),
};
}
inline std::vector<OperatorField> GetFields(const DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_DESC& desc)
{
return {
OperatorField(&DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.ATensor))),
OperatorField(&DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.AScaleTensor))),
OperatorField(&DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.AZeroPointTensor))),
OperatorField(&DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BTensor))),
OperatorField(&DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BScaleTensor))),
OperatorField(&DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BZeroPointTensor))),
OperatorField(&DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputScaleTensor))),
OperatorField(&DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputZeroPointTensor))),
OperatorField(&DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
};
}
inline std::vector<OperatorField> GetFields(const DML_ROI_ALIGN_GRAD_OPERATOR_DESC& desc)
{
return {
OperatorField(&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
OperatorField(&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputGradientTensor))),
OperatorField(&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.ROITensor))),
OperatorField(&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BatchIndicesTensor))),
OperatorField(&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputGradientTensor))),
OperatorField(&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputROIGradientTensor))),
OperatorField(&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<UINT>(desc.ReductionFunction))),
OperatorField(&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<UINT>(desc.InterpolationMode))),
OperatorField(&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast<FLOAT>(desc.SpatialScaleX))),
OperatorField(&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast<FLOAT>(desc.SpatialScaleY))),
OperatorField(&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast<FLOAT>(desc.InputPixelOffset))),
OperatorField(&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA.Fields[11], ToOperatorFieldType(static_cast<FLOAT>(desc.OutputPixelOffset))),
OperatorField(&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA.Fields[12], ToOperatorFieldType(static_cast<UINT>(desc.MinimumSamplesPerOutput))),
OperatorField(&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA.Fields[13], ToOperatorFieldType(static_cast<UINT>(desc.MaximumSamplesPerOutput))),
OperatorField(&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA.Fields[14], ToOperatorFieldType(static_cast<UINT>(desc.AlignRegionsToCorners))),
};
}
inline std::vector<OperatorField> GetFields(const DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_DESC& desc)
{
return {
OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.InputTensor))),
OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.ScaleTensor))),
OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.BiasTensor))),
OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.FusedAddTensor))),
OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputTensor))),
OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputMeanTensor))),
OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast<const DML_TENSOR_DESC*>(desc.OutputVarianceTensor))),
OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast<FLOAT>(desc.Epsilon))),
OperatorField(&DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast<const DML_OPERATOR_DESC*>(desc.FusedActivation))),
};
}
inline std::vector<OperatorField> GetFields(const DML_ACTIVATION_ELU_OPERATOR_DESC& desc)
{
return {
@ -1351,6 +1552,9 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType)
case DML_OPERATOR_ELEMENT_WISE_ATAN: return DML_ELEMENT_WISE_ATAN_OPERATOR_SCHEMA;
case DML_OPERATOR_ELEMENT_WISE_CEIL: return DML_ELEMENT_WISE_CEIL_OPERATOR_SCHEMA;
case DML_OPERATOR_ELEMENT_WISE_CLIP: return DML_ELEMENT_WISE_CLIP_OPERATOR_SCHEMA;
case DML_OPERATOR_ELEMENT_WISE_CLIP1: return DML_ELEMENT_WISE_CLIP1_OPERATOR_SCHEMA;
case DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD: return DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_SCHEMA;
case DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1: return DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_SCHEMA;
case DML_OPERATOR_ELEMENT_WISE_COS: return DML_ELEMENT_WISE_COS_OPERATOR_SCHEMA;
case DML_OPERATOR_ELEMENT_WISE_DIVIDE: return DML_ELEMENT_WISE_DIVIDE_OPERATOR_SCHEMA;
case DML_OPERATOR_ELEMENT_WISE_EXP: return DML_ELEMENT_WISE_EXP_OPERATOR_SCHEMA;
@ -1374,6 +1578,8 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType)
case DML_OPERATOR_ELEMENT_WISE_RECIP: return DML_ELEMENT_WISE_RECIP_OPERATOR_SCHEMA;
case DML_OPERATOR_ELEMENT_WISE_SIN: return DML_ELEMENT_WISE_SIN_OPERATOR_SCHEMA;
case DML_OPERATOR_ELEMENT_WISE_SQRT: return DML_ELEMENT_WISE_SQRT_OPERATOR_SCHEMA;
case DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE: return DML_ELEMENT_WISE_DIFFERENCE_SQUARE_OPERATOR_SCHEMA;
case DML_OPERATOR_ELEMENT_WISE_ATAN_YX: return DML_ELEMENT_WISE_ATAN_YX_OPERATOR_SCHEMA;
case DML_OPERATOR_ELEMENT_WISE_SUBTRACT: return DML_ELEMENT_WISE_SUBTRACT_OPERATOR_SCHEMA;
case DML_OPERATOR_ELEMENT_WISE_TAN: return DML_ELEMENT_WISE_TAN_OPERATOR_SCHEMA;
case DML_OPERATOR_ELEMENT_WISE_THRESHOLD: return DML_ELEMENT_WISE_THRESHOLD_OPERATOR_SCHEMA;
@ -1394,6 +1600,7 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType)
case DML_OPERATOR_SPLIT: return DML_SPLIT_OPERATOR_SCHEMA;
case DML_OPERATOR_JOIN: return DML_JOIN_OPERATOR_SCHEMA;
case DML_OPERATOR_PADDING: return DML_PADDING_OPERATOR_SCHEMA;
case DML_OPERATOR_PADDING1: return DML_PADDING1_OPERATOR_SCHEMA;
case DML_OPERATOR_VALUE_SCALE_2D: return DML_VALUE_SCALE_2D_OPERATOR_SCHEMA;
case DML_OPERATOR_UPSAMPLE_2D: return DML_UPSAMPLE_2D_OPERATOR_SCHEMA;
case DML_OPERATOR_GATHER: return DML_GATHER_OPERATOR_SCHEMA;
@ -1402,14 +1609,18 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType)
case DML_OPERATOR_TILE: return DML_TILE_OPERATOR_SCHEMA;
case DML_OPERATOR_TOP_K: return DML_TOP_K_OPERATOR_SCHEMA;
case DML_OPERATOR_BATCH_NORMALIZATION: return DML_BATCH_NORMALIZATION_OPERATOR_SCHEMA;
case DML_OPERATOR_BATCH_NORMALIZATION_GRAD: return DML_BATCH_NORMALIZATION_GRAD_OPERATOR_SCHEMA;
case DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD: return DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA;
case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION: return DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_SCHEMA;
case DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION: return DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_SCHEMA;
case DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD: return DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_SCHEMA;
case DML_OPERATOR_LP_NORMALIZATION: return DML_LP_NORMALIZATION_OPERATOR_SCHEMA;
case DML_OPERATOR_RNN: return DML_RNN_OPERATOR_SCHEMA;
case DML_OPERATOR_LSTM: return DML_LSTM_OPERATOR_SCHEMA;
case DML_OPERATOR_GRU: return DML_GRU_OPERATOR_SCHEMA;
case DML_OPERATOR_ELEMENT_WISE_SIGN: return DML_ELEMENT_WISE_SIGN_OPERATOR_SCHEMA;
case DML_OPERATOR_ELEMENT_WISE_IS_NAN: return DML_ELEMENT_WISE_IS_NAN_OPERATOR_SCHEMA;
case DML_OPERATOR_ELEMENT_WISE_NEGATE: return DML_ELEMENT_WISE_NEGATE_OPERATOR_SCHEMA;
case DML_OPERATOR_ELEMENT_WISE_ERF: return DML_ELEMENT_WISE_ERF_OPERATOR_SCHEMA;
case DML_OPERATOR_ELEMENT_WISE_SINH: return DML_ELEMENT_WISE_SINH_OPERATOR_SCHEMA;
case DML_OPERATOR_ELEMENT_WISE_COSH: return DML_ELEMENT_WISE_COSH_OPERATOR_SCHEMA;
@ -1433,6 +1644,7 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType)
case DML_OPERATOR_FILL_VALUE_CONSTANT: return DML_FILL_VALUE_CONSTANT_OPERATOR_SCHEMA;
case DML_OPERATOR_FILL_VALUE_SEQUENCE: return DML_FILL_VALUE_SEQUENCE_OPERATOR_SCHEMA;
case DML_OPERATOR_CUMULATIVE_SUMMATION: return DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA;
case DML_OPERATOR_CUMULATIVE_PRODUCT: return DML_CUMULATIVE_PRODUCT_OPERATOR_SCHEMA;
case DML_OPERATOR_REVERSE_SUBSEQUENCES: return DML_REVERSE_SUBSEQUENCES_OPERATOR_SCHEMA;
case DML_OPERATOR_GATHER_ELEMENTS: return DML_GATHER_ELEMENTS_OPERATOR_SCHEMA;
case DML_OPERATOR_GATHER_ND: return DML_GATHER_ND_OPERATOR_SCHEMA;
@ -1462,7 +1674,12 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType)
case DML_OPERATOR_SLICE_GRAD: return DML_SLICE_GRAD_OPERATOR_SCHEMA;
case DML_OPERATOR_ADAM_OPTIMIZER: return DML_ADAM_OPTIMIZER_OPERATOR_SCHEMA;
case DML_OPERATOR_ROI_ALIGN: return DML_ROI_ALIGN_OPERATOR_SCHEMA;
case DML_OPERATOR_ROI_ALIGN1: return DML_ROI_ALIGN1_OPERATOR_SCHEMA;
case DML_OPERATOR_GATHER_ND1: return DML_GATHER_ND1_OPERATOR_SCHEMA;
case DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR: return DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA;
case DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD: return DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA;
case DML_OPERATOR_ROI_ALIGN_GRAD: return DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA;
case DML_OPERATOR_BATCH_NORMALIZATION_TRAINING: return DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA;
case DML_OPERATOR_ACTIVATION_ELU: return DML_ACTIVATION_ELU_OPERATOR_SCHEMA;
case DML_OPERATOR_ACTIVATION_CELU: return DML_ACTIVATION_CELU_OPERATOR_SCHEMA;
case DML_OPERATOR_ACTIVATION_HARDMAX: return DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA;
@ -1528,6 +1745,18 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
return AbstractOperatorDesc(
&DML_ELEMENT_WISE_CLIP_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_ELEMENT_WISE_CLIP_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_ELEMENT_WISE_CLIP1:
return AbstractOperatorDesc(
&DML_ELEMENT_WISE_CLIP1_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_ELEMENT_WISE_CLIP1_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD:
return AbstractOperatorDesc(
&DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_ELEMENT_WISE_CLIP_GRAD_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1:
return AbstractOperatorDesc(
&DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_ELEMENT_WISE_CLIP_GRAD1_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_ELEMENT_WISE_COS:
return AbstractOperatorDesc(
&DML_ELEMENT_WISE_COS_OPERATOR_SCHEMA,
@ -1620,6 +1849,14 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
return AbstractOperatorDesc(
&DML_ELEMENT_WISE_SQRT_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_ELEMENT_WISE_SQRT_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE:
return AbstractOperatorDesc(
&DML_ELEMENT_WISE_DIFFERENCE_SQUARE_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_ELEMENT_WISE_DIFFERENCE_SQUARE_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_ELEMENT_WISE_ATAN_YX:
return AbstractOperatorDesc(
&DML_ELEMENT_WISE_ATAN_YX_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_ELEMENT_WISE_ATAN_YX_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_ELEMENT_WISE_SUBTRACT:
return AbstractOperatorDesc(
&DML_ELEMENT_WISE_SUBTRACT_OPERATOR_SCHEMA,
@ -1700,6 +1937,10 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
return AbstractOperatorDesc(
&DML_PADDING_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_PADDING_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_PADDING1:
return AbstractOperatorDesc(
&DML_PADDING1_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_PADDING1_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_VALUE_SCALE_2D:
return AbstractOperatorDesc(
&DML_VALUE_SCALE_2D_OPERATOR_SCHEMA,
@ -1732,6 +1973,14 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
return AbstractOperatorDesc(
&DML_BATCH_NORMALIZATION_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_BATCH_NORMALIZATION_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_BATCH_NORMALIZATION_GRAD:
return AbstractOperatorDesc(
&DML_BATCH_NORMALIZATION_GRAD_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_BATCH_NORMALIZATION_GRAD_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD:
return AbstractOperatorDesc(
&DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION:
return AbstractOperatorDesc(
&DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_SCHEMA,
@ -1740,6 +1989,10 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
return AbstractOperatorDesc(
&DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD:
return AbstractOperatorDesc(
&DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_LOCAL_RESPONSE_NORMALIZATION_GRAD_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_LP_NORMALIZATION:
return AbstractOperatorDesc(
&DML_LP_NORMALIZATION_OPERATOR_SCHEMA,
@ -1764,6 +2017,10 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
return AbstractOperatorDesc(
&DML_ELEMENT_WISE_IS_NAN_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_ELEMENT_WISE_IS_NAN_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_ELEMENT_WISE_NEGATE:
return AbstractOperatorDesc(
&DML_ELEMENT_WISE_NEGATE_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_ELEMENT_WISE_NEGATE_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_ELEMENT_WISE_ERF:
return AbstractOperatorDesc(
&DML_ELEMENT_WISE_ERF_OPERATOR_SCHEMA,
@ -1856,6 +2113,10 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
return AbstractOperatorDesc(
&DML_CUMULATIVE_SUMMATION_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_CUMULATIVE_SUMMATION_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_CUMULATIVE_PRODUCT:
return AbstractOperatorDesc(
&DML_CUMULATIVE_PRODUCT_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_CUMULATIVE_PRODUCT_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_REVERSE_SUBSEQUENCES:
return AbstractOperatorDesc(
&DML_REVERSE_SUBSEQUENCES_OPERATOR_SCHEMA,
@ -1972,10 +2233,30 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
return AbstractOperatorDesc(
&DML_ROI_ALIGN_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_ROI_ALIGN_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_ROI_ALIGN1:
return AbstractOperatorDesc(
&DML_ROI_ALIGN1_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_ROI_ALIGN1_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_GATHER_ND1:
return AbstractOperatorDesc(
&DML_GATHER_ND1_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_GATHER_ND1_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR:
return AbstractOperatorDesc(
&DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD:
return AbstractOperatorDesc(
&DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_ROI_ALIGN_GRAD:
return AbstractOperatorDesc(
&DML_ROI_ALIGN_GRAD_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_ROI_ALIGN_GRAD_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_BATCH_NORMALIZATION_TRAINING:
return AbstractOperatorDesc(
&DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_DESC*>(opDesc.Desc)));
case DML_OPERATOR_ACTIVATION_ELU:
return AbstractOperatorDesc(
&DML_ACTIVATION_ELU_OPERATOR_SCHEMA,

View file

@ -139,29 +139,6 @@ namespace Dml::GraphDescBuilder
&graphNodeInfo
);
// Determine the number of valid inputs and outputs of this node. The graph currently supports operators
// with unused inputs and outputs only at the end of each list.
uint32_t validOpInputCount = 0;
uint32_t validOpOutputCount = 0;
for (uint32_t i = 0; i < graphNodeInfo.kernelInputIndices.size(); ++i)
{
if (graphNodeInfo.kernelInputIndices[i] != std::numeric_limits<uint32_t>::max())
{
assert(i - validOpInputCount == 0);
++validOpInputCount;
}
}
for (uint32_t i = 0; i < graphNodeInfo.kernelOutputIndices.size(); ++i)
{
if (graphNodeInfo.kernelOutputIndices[i] != std::numeric_limits<uint32_t>::max())
{
assert(i - validOpOutputCount == 0);
++validOpOutputCount;
}
}
uint32_t nodeIndex = gsl::narrow_cast<uint32_t>(graphNodes.size());
AbstractOperatorDesc opDesc = *graphNodeInfo.desc; // Make a copy
@ -171,8 +148,13 @@ namespace Dml::GraphDescBuilder
std::vector<DmlBufferTensorDesc*> outputTensorDescs = opDesc.GetOutputTensors();
// Set connections of the new node
for (uint32_t inputIndex = 0; inputIndex < validOpInputCount; ++inputIndex)
for (uint32_t inputIndex = 0; inputIndex < graphNodeInfo.kernelInputIndices.size(); ++inputIndex)
{
if (graphNodeInfo.kernelInputIndices[inputIndex] == std::numeric_limits<uint32_t>::max())
{
continue;
}
uint32_t kernelInputIndex = graphNodeInfo.kernelInputIndices[inputIndex];
const onnxruntime::NodeArg* arg = node.InputDefs()[kernelInputIndex];
@ -224,8 +206,13 @@ namespace Dml::GraphDescBuilder
// Store the new node for lookup when downstream nodes consume it.
for (uint32_t outputIndex = 0; outputIndex < validOpOutputCount; ++outputIndex)
for (uint32_t outputIndex = 0; outputIndex < graphNodeInfo.kernelOutputIndices.size(); ++outputIndex)
{
if (graphNodeInfo.kernelOutputIndices[outputIndex] == std::numeric_limits<uint32_t>::max())
{
continue;
}
uint32_t kernelOutputIndex = graphNodeInfo.kernelOutputIndices[outputIndex];
const onnxruntime::NodeArg* arg = node.OutputDefs()[kernelOutputIndex];
if (arg->Exists())

View file

@ -139,7 +139,11 @@ namespace Dml
}
};
bool NodeArgSupportedInGraph(const onnxruntime::NodeArg* arg, uint32_t supportedDeviceDataTypeMask)
bool NodeArgSupportedInGraph(
const onnxruntime::NodeArg* arg,
bool supports64BitTensorsViaEmulation,
uint32_t supportedDeviceDataTypeMask
)
{
if (arg->Exists())
{
@ -154,8 +158,11 @@ namespace Dml
MLOperatorTensorDataType mlDataType = ToMLTensorDataType(static_cast<onnx::TensorProto_DataType>(tensorType.elem_type()));
if (mlDataType == MLOperatorTensorDataType::UInt64 ||
mlDataType == MLOperatorTensorDataType::Int64)
// Do not include operators in the graph if tensor types are unsupported,
// except cases that are always supported via emulation.
if ((mlDataType == MLOperatorTensorDataType::UInt64 ||
mlDataType == MLOperatorTensorDataType::Int64) &&
!supports64BitTensorsViaEmulation)
{
constexpr uint32_t deviceDataTypeMask64bit = (1 << DML_TENSOR_DATA_TYPE_UINT64) | (1 << DML_TENSOR_DATA_TYPE_INT64);
if ((supportedDeviceDataTypeMask & deviceDataTypeMask64bit) != deviceDataTypeMask64bit)
@ -181,6 +188,7 @@ namespace Dml
if (!isConstantCpuInput &&
!NodeArgSupportedInGraph(
node.InputDefs()[i],
registration.support64BitTensorsViaEmulation,
supportedDeviceDataTypeMask
))
{
@ -192,6 +200,7 @@ namespace Dml
{
if (!NodeArgSupportedInGraph(
arg,
registration.support64BitTensorsViaEmulation,
supportedDeviceDataTypeMask
))
{
@ -234,6 +243,7 @@ namespace Dml
ORT_THROW_HR_IF(E_INVALIDARG, allow64BitInputThroughStrides && !nodeNameToPartitionMap);
bool prefer64BitTensorsDirectly = false;
bool support64BitTensorsViaEmulation = false;
bool supportedWith64BitTensorsVia32BitStrides = false;
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = false;
std::vector<onnxruntime::NodeArg const*> constantCpuInputs;
@ -244,9 +254,10 @@ namespace Dml
// to 32-bit tensors via strides. If the caller passes allow64BitInputThroughStrides = false
// in this particular call, then the operator-specific flags do not matter as the caller has
// disabled 64-bit support.
prefer64BitTensorsDirectly = regInfo->prefer64BitTensorsDirectly;
support64BitTensorsViaEmulation = regInfo->support64BitTensorsViaEmulation;
if (allow64BitInputThroughStrides)
{
prefer64BitTensorsDirectly = regInfo->prefer64BitTensorsDirectly;
supportedWith64BitTensorsVia32BitStridesFromAnyEp = regInfo->supportedWith64BitTensorsVia32BitStridesFromAnyEp;
supportedWith64BitTensorsVia32BitStrides = regInfo->supportedWith64BitTensorsVia32BitStrides | supportedWith64BitTensorsVia32BitStridesFromAnyEp;
}
@ -320,31 +331,44 @@ namespace Dml
// operator, graph input or initializer, it's not safe to assume the input
// can be represented with 32 bits.
//
bool isDataTypeSupported = (1 << dmlElementType) & supportedDeviceDataTypeMask;
bool is64BitIntType = (dmlElementType == DML_TENSOR_DATA_TYPE_UINT64 || dmlElementType == DML_TENSOR_DATA_TYPE_INT64);
bool needsFallbackTo32Bit = !prefer64BitTensorsDirectly || !((1 << dmlElementType) & supportedDeviceDataTypeMask);
if (is64BitIntType && supportedWith64BitTensorsVia32BitStrides && needsFallbackTo32Bit)
if (is64BitIntType)
{
dmlElementType = Remap64bitDmlDataTypeTo32bit(dmlElementType);
if (isInput && !supportedWith64BitTensorsVia32BitStridesFromAnyEp)
if (support64BitTensorsViaEmulation)
{
// Look up the input partition. If it's a graph input or initializer it will be missing
// from the partition map.
const std::string& argName = nodeArg.Name();
// Consider it supported regardless of hardware support.
isDataTypeSupported = true;
}
else if (prefer64BitTensorsDirectly && isDataTypeSupported)
{
// Operator supports native int64/uint64 tensors.
}
else if (supportedWith64BitTensorsVia32BitStrides || supportedWith64BitTensorsVia32BitStridesFromAnyEp)
{
dmlElementType = Remap64bitDmlDataTypeTo32bit(dmlElementType);
isDataTypeSupported = (1 << dmlElementType) & supportedDeviceDataTypeMask;
// If input tensor's data comes from the output of a different execution provider,
// consider it unsafe to apply fallback to.
auto partitionIter = nodeNameToPartitionMap->find(argName);
if (partitionIter == nodeNameToPartitionMap->end() || !partitionIter->second->IsDmlPartition())
if (isInput && !supportedWith64BitTensorsVia32BitStridesFromAnyEp)
{
nodeContainsSupportedDataTypes = false;
return;
// Look up the input partition. If it's a graph input or initializer it will be missing
// from the partition map.
const std::string& argName = nodeArg.Name();
// If input tensor's data comes from the output of a different execution provider,
// consider it unsafe to apply fallback to.
auto partitionIter = nodeNameToPartitionMap->find(argName);
if (partitionIter == nodeNameToPartitionMap->end() || !partitionIter->second->IsDmlPartition())
{
nodeContainsSupportedDataTypes = false;
return;
}
}
}
}
// Reject node if the data type is unsupported by the device.
if (!((1 << dmlElementType) & supportedDeviceDataTypeMask))
if (!isDataTypeSupported)
{
nodeContainsSupportedDataTypes = false;
return;

View file

@ -22,6 +22,13 @@ public:
std::vector<std::optional<uint32_t>> outputIndices = { 0 };
Initialize(kernelCreationContext, inputIndices, outputIndices);
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
DML_FILL_VALUE_CONSTANT_OPERATOR_DESC operatorDesc = {};
operatorDesc.OutputTensor = outputDescs.data();
operatorDesc.ValueDataType = this->m_outputTensorDescs.front().GetDmlDataType();
// operatorDesc.Value already zeroed.
// Read the tensor attribute for the output fill pattern.
if (kernelCreationContext.HasAttribute(AttrName::Value, MLOperatorAttributeTypeTensor))
{
@ -40,15 +47,14 @@ public:
ML_CHECK_VALID_ARGUMENT(elementCount == 1); // Expect exactly one element.
const size_t rawDataByteSize = GetByteSizeFromMlDataType(wrappedValueTensor.GetTensorDataType());
const std::byte* rawData = static_cast<const std::byte*>(valueTensor->GetData());
valueBytes.assign(rawData, rawData + rawDataByteSize);
memcpy(operatorDesc.Value.Bytes, rawData, std::min(rawDataByteSize, sizeof(operatorDesc.Value.Bytes)));
}
// Else valueBytes is empty, and the default fill pattern is 0.
}
void Compute(const MLOperatorKernelContext& kernelContext) override
{
std::vector<IMLOperatorTensor*> outputTensors = GetOutputTensorsForExecute(kernelContext);
ORT_THROW_IF_FAILED(m_executionProvider->FillTensorWithPattern(outputTensors.front(), valueBytes));
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_FILL_VALUE_CONSTANT, &operatorDesc };
SetDmlOperatorDesc(opDesc, kernelCreationContext);
}
private:

View file

@ -7,30 +7,45 @@ namespace Dml
{
class DmlOperatorDynamicQuantizeLinear : public DmlOperator
{
{
enum DmlInputTensors {
IN_A,
};
enum DmlOutputTensors {
OUT_Y,
OUT_Y_SCALE,
OUT_Y_ZERO_POINT
};
public:
using Self = DmlOperatorDynamicQuantizeLinear;
DmlOperatorDynamicQuantizeLinear(const MLOperatorKernelCreationContext& kernelCreationContext)
: DmlOperator(kernelCreationContext)
{
#if 0 // TODO:NickFe - https://github.com/onnx/onnx/blob/master/docs/Operators.md#dynamicquantizelinear
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 1);
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 3);
DmlOperator::Initialize(kernelCreationContext, inputIndices, outputIndices);
DmlOperator::Initialize(kernelCreationContext);
m_inputTensorDescs[IN_A] = CreateTensorDescFromInput(kernelCreationContext, 0/*A OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned);
m_outputTensorDescs[OUT_Y] = CreateTensorDescFromOutput(kernelCreationContext, 0/*Y OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned);
m_outputTensorDescs[OUT_Y_SCALE] = CreateTensorDescFromOutput(kernelCreationContext, 1/*Y Scale OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned);
m_outputTensorDescs[OUT_Y_ZERO_POINT] = CreateTensorDescFromOutput(kernelCreationContext, 2/*Y Zero point OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned);
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
DML_PLACEHOLDER_OPERATOR_DESC operatorDesc = {};
operatorDesc.IndicesTensor = &inputDescs[0];
operatorDesc.ValuesTensor = &inputDescs[1];
operatorDesc.OutputTensor = outputDescs.data();
operatorDesc.Axis = dmlAxis;
DML_DYNAMIC_QUANTIZE_LINEAR_OPERATOR_DESC operatorDesc = {};
operatorDesc.InputTensor = &inputDescs[IN_A];
operatorDesc.OutputTensor = &outputDescs[OUT_Y];
operatorDesc.OutputScaleTensor = &outputDescs[OUT_Y_SCALE];
operatorDesc.OutputZeroPointTensor = &outputDescs[OUT_Y_ZERO_POINT];
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_PLACEHOLDER, &operatorDesc };
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR, &operatorDesc };
SetDmlOperatorDesc(opDesc, kernelCreationContext);
#endif
}
};

View file

@ -442,28 +442,34 @@ public:
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
float minValue = -FLT_MAX;
float maxValue = FLT_MAX;
if (kernelInfo.IsInputValid(1))
{
minValue = static_cast<float>(ReadScalarTensorCastToFloat64(kernelInfo.GetConstantInputTensor(1)));
}
if (kernelInfo.IsInputValid(2)) {
maxValue = static_cast<float>(ReadScalarTensorCastToFloat64(kernelInfo.GetConstantInputTensor(2)));
}
DML_ELEMENT_WISE_CLIP_OPERATOR_DESC opDesc = {};
DML_ELEMENT_WISE_CLIP1_OPERATOR_DESC opDesc = {};
opDesc.InputTensor = inputDescs.data();
opDesc.OutputTensor = outputDescs.data();
opDesc.Min = minValue;
opDesc.Max = maxValue;
// MinMaxDataType always equals the input tensor's data type.
// It is assigned here unconditionally because the field must be set
// even if execution does not go through the conditional logic below
// in some corner test cases.
// The same applies to the Min and Max values.
opDesc.MinMaxDataType = this->m_inputTensorDescs[0].GetDmlDataType();
CastToClampedScalarUnion<double>(opDesc.MinMaxDataType, -DBL_MAX, /*out*/&opDesc.Min);
CastToClampedScalarUnion<double>(opDesc.MinMaxDataType, DBL_MAX, /*out*/&opDesc.Max);
SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_CLIP, &opDesc}, kernelInfo);
if (kernelInfo.IsInputValid(1))
{
ReadScalarTensorData(kernelInfo.GetConstantInputTensor(1), /*out*/ &opDesc.Min.Bytes, sizeof(opDesc.Min.Bytes));
}
if (kernelInfo.IsInputValid(2))
{
ReadScalarTensorData(kernelInfo.GetConstantInputTensor(2), /*out*/ &opDesc.Max.Bytes, sizeof(opDesc.Max.Bytes));
}
SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_CLIP1, &opDesc}, kernelInfo);
}
};
// Same operator signature as 11. Only difference is new type support
using DmlOperatorElementwiseClip12 = DmlOperatorElementwiseClip11;
using DmlOperatorElementwiseClip13 = DmlOperatorElementwiseClip11;
class DmlOperatorElementwisePow : public DmlOperator
{
@ -692,7 +698,7 @@ DML_OP_DEFINE_CREATION_FUNCTION(Ceil, DmlOperatorElementwiseUnary<DM
DML_OP_DEFINE_CREATION_FUNCTION(Floor, DmlOperatorElementwiseUnary<DML_ELEMENT_WISE_FLOOR_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(Not, DmlOperatorElementwiseUnary<DML_ELEMENT_WISE_LOGICAL_NOT_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(Sign, DmlOperatorElementwiseUnary<DML_ELEMENT_WISE_SIGN_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(IsNan, DmlOperatorElementwiseUnary<DML_ELEMENT_WISE_IS_NAN_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(IsNaN, DmlOperatorElementwiseUnary<DML_ELEMENT_WISE_IS_NAN_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(Sinh, DmlOperatorElementwiseUnary<DML_ELEMENT_WISE_SINH_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(Cosh, DmlOperatorElementwiseUnary<DML_ELEMENT_WISE_COSH_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(Asinh, DmlOperatorElementwiseUnary<DML_ELEMENT_WISE_ASINH_OPERATOR_DESC>);
@ -724,6 +730,7 @@ DML_OP_DEFINE_CREATION_FUNCTION(Mean, DmlOperatorElementwiseMean);
DML_OP_DEFINE_CREATION_FUNCTION(Clip7, DmlOperatorElementwiseClip7);
DML_OP_DEFINE_CREATION_FUNCTION(Clip11, DmlOperatorElementwiseClip11);
DML_OP_DEFINE_CREATION_FUNCTION(Clip12, DmlOperatorElementwiseClip12);
DML_OP_DEFINE_CREATION_FUNCTION(Clip13, DmlOperatorElementwiseClip13);
DML_OP_DEFINE_CREATION_FUNCTION(Pow, DmlOperatorElementwisePow);
DML_OP_DEFINE_CREATION_FUNCTION(QuantizeLinear, DmlOperatorElementwiseQLinear<DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_DESC>);
DML_OP_DEFINE_CREATION_FUNCTION(DequantizeLinear, DmlOperatorElementwiseQLinear<DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_DESC>);

View file

@ -22,16 +22,11 @@ public:
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
DML_SCALE_BIAS scaleBias = {};
scaleBias.Scale = -1.0f;
scaleBias.Bias = 0.0f;
DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC opDesc = {};
DML_ELEMENT_WISE_NEGATE_OPERATOR_DESC opDesc = {};
opDesc.InputTensor = inputDescs.data();
opDesc.OutputTensor = outputDescs.data();
opDesc.ScaleBias = &scaleBias;
SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_IDENTITY, &opDesc}, kernelInfo);
SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_NEGATE, &opDesc}, kernelInfo);
}
};

View file

@ -56,34 +56,40 @@ public:
ML_INVALID_ARGUMENT("Unknown Pad mode attribute.");
}
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
DML_PADDING1_OPERATOR_DESC paddingDesc = {};
paddingDesc.InputTensor = inputDescs.data();
paddingDesc.OutputTensor = outputDescs.data();
paddingDesc.PaddingMode = mode;
paddingDesc.DimensionCount = gsl::narrow_cast<uint32_t>(m_startPadding.size());
paddingDesc.StartPadding = m_startPadding.data();
paddingDesc.EndPadding = m_endPadding.data();
// PaddingValueDataType always equals the input tensor's data type.
// It is assigned here unconditionally because the field must be set
// even if execution does not go through the conditional logic below
// in some corner test cases (e.g. opsetVersion >= 11 but no valid input at index 2).
// The same applies to PaddingValue.
paddingDesc.PaddingValueDataType = this->m_inputTensorDescs[0].GetDmlDataType();
CastToClampedScalarUnion<float>(paddingDesc.PaddingValueDataType, 0.0f, /*out*/&paddingDesc.PaddingValue);
// Read the constant value which can come from an attribute or tensor.
float value = 0.0f;
if (opsetVersion >= 11)
{
if (kernelInfo.IsInputValid(2))
{
auto valueTensor = kernelInfo.GetConstantInputTensor(2);
value = static_cast<float>(ReadScalarTensorCastToFloat64(valueTensor));
MLOperatorTensor constantPaddingValueTensor = kernelInfo.GetConstantInputTensor(2);
ReadScalarTensorData(constantPaddingValueTensor, /*out*/ &paddingDesc.PaddingValue.Bytes, sizeof(paddingDesc.PaddingValue.Bytes));
}
}
else
{
value = kernelInfo.GetOptionalAttribute<float>(AttrName::Value, 0.0f);
auto value = kernelInfo.GetOptionalAttribute<float>(AttrName::Value, 0.0f);
CastToClampedScalarUnion<float>(paddingDesc.PaddingValueDataType, value, /*out*/&paddingDesc.PaddingValue);
}
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
DML_PADDING_OPERATOR_DESC paddingDesc = {};
paddingDesc.InputTensor = inputDescs.data();
paddingDesc.OutputTensor = outputDescs.data();
paddingDesc.PaddingMode = mode;
paddingDesc.PaddingValue = value;
paddingDesc.DimensionCount = gsl::narrow_cast<uint32_t>(m_startPadding.size());
paddingDesc.StartPadding = m_startPadding.data();
paddingDesc.EndPadding = m_endPadding.data();
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_PADDING, &paddingDesc };
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_PADDING1, &paddingDesc };
SetDmlOperatorDesc(opDesc, kernelInfo);
}
@ -91,5 +97,6 @@ public:
DML_OP_DEFINE_CREATION_FUNCTION(Pad7, VersionedKernel<DmlOperatorPadding, 7>);
DML_OP_DEFINE_CREATION_FUNCTION(Pad11, VersionedKernel<DmlOperatorPadding, 11>);
DML_OP_DEFINE_CREATION_FUNCTION(Pad13, VersionedKernel<DmlOperatorPadding, 13>);
} // namespace Dml

View file

@ -0,0 +1,69 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
class DmlOperatorQLinearAdd : public DmlOperator
{
enum InputTensors {
IN_A,
IN_A_SCALE,
IN_A_ZERO_POINT,
IN_B,
IN_B_SCALE,
IN_B_ZERO_POINT,
IN_C_SCALE,
IN_C_ZERO_POINT
};
public:
DmlOperatorQLinearAdd(const MLOperatorKernelCreationContext& kernelInfo)
: DmlOperator(kernelInfo)
{
DmlOperator::Initialize(kernelInfo);
std::vector<DimensionType> outputShape = kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0);
uint32_t dmlDimSize = m_inputTensorDescs[0].GetDimensionCount();
// Initialize the input descriptions with broadcasting
m_inputTensorDescs[IN_A] = CreateTensorDescFromInput(kernelInfo, 0/*A OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, outputShape);
m_inputTensorDescs[IN_B] = CreateTensorDescFromInput(kernelInfo, 3/*B OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, outputShape);
m_inputTensorDescs[IN_A_SCALE] = CreateTensorDescFromInput(kernelInfo, 1/*A Scale OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, std::nullopt, dmlDimSize);
m_inputTensorDescs[IN_A_ZERO_POINT] = CreateTensorDescFromInput(kernelInfo, 2/*A Zero point OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, std::nullopt, dmlDimSize);
m_inputTensorDescs[IN_B_SCALE] = CreateTensorDescFromInput(kernelInfo, 4/*B Scale OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, std::nullopt, dmlDimSize);
m_inputTensorDescs[IN_B_ZERO_POINT] = CreateTensorDescFromInput(kernelInfo, 5/*B Zero point OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, std::nullopt, dmlDimSize);
m_inputTensorDescs[IN_C_SCALE] = CreateTensorDescFromInput(kernelInfo, 6/*C Zero point OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, std::nullopt, dmlDimSize);
m_inputTensorDescs[IN_C_ZERO_POINT] = CreateTensorDescFromInput(kernelInfo, 7/*C Zero point OnnxIndex*/, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, std::nullopt, dmlDimSize);
// Initialize the output description while overriding the shape
m_outputTensorDescs[0] = CreateTensorDescFromOutput(kernelInfo, 0, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, outputShape);
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
DML_ELEMENT_WISE_QUANTIZED_LINEAR_ADD_OPERATOR_DESC AddDesc = {};
AddDesc.ATensor = &inputDescs[IN_A];
AddDesc.AScaleTensor = &inputDescs[IN_A_SCALE];
AddDesc.AZeroPointTensor = inputDescs[IN_A_ZERO_POINT].Desc != nullptr ? &inputDescs[IN_A_ZERO_POINT] : nullptr;
AddDesc.BTensor = &inputDescs[IN_B];
AddDesc.BScaleTensor = &inputDescs[IN_B_SCALE];
AddDesc.BZeroPointTensor = inputDescs[IN_B_ZERO_POINT].Desc != nullptr ? &inputDescs[IN_B_ZERO_POINT] : nullptr;
AddDesc.OutputScaleTensor = &inputDescs[IN_C_SCALE];
AddDesc.OutputZeroPointTensor = inputDescs[IN_C_ZERO_POINT].Desc != nullptr ? &inputDescs[IN_C_ZERO_POINT] : nullptr;
AddDesc.OutputTensor = &outputDescs[0];
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD, &AddDesc };
SetDmlOperatorDesc(opDesc, kernelInfo);
}
};
DML_OP_DEFINE_CREATION_FUNCTION(QLinearAdd, DmlOperatorQLinearAdd);
} // namespace Dml

View file

@ -341,5 +341,6 @@ DML_OP_DEFINE_CREATION_FUNCTION(Resize11, VersionedKernel<DmlOperatorResize, 11>
DML_OP_DEFINE_CREATION_FUNCTION(Upsample7, VersionedKernel<DmlOperatorResize, 7>);
DML_OP_DEFINE_CREATION_FUNCTION(Upsample9, VersionedKernel<DmlOperatorResize, 9>);
DML_OP_DEFINE_CREATION_FUNCTION(Upsample10, VersionedKernel<DmlOperatorResize, 10>);
DML_OP_DEFINE_CREATION_FUNCTION(Upsample13, VersionedKernel<DmlOperatorResize, 13>);
} // namespace Dml

View file

@ -113,6 +113,7 @@ public:
DML_OP_DEFINE_CREATION_FUNCTION(Scatter9, DmlOperatorScatter);
DML_OP_DEFINE_CREATION_FUNCTION(Scatter11, DmlOperatorScatter);
DML_OP_DEFINE_CREATION_FUNCTION(Scatter13, DmlOperatorScatter);
DML_OP_DEFINE_CREATION_FUNCTION(ScatterElements, DmlOperatorScatter);
DML_OP_DEFINE_CREATION_FUNCTION(ScatterND, DmlOperatorScatterNd);

View file

@ -55,4 +55,5 @@ void CALLBACK QuerySlice(IMLOperatorSupportQueryContextPrivate* context, bool* i
DML_OP_DEFINE_CREATION_FUNCTION(Slice7, VersionedKernel<DmlOperatorSlice, 7> );
DML_OP_DEFINE_CREATION_FUNCTION(Slice10, VersionedKernel<DmlOperatorSlice, 10>);
DML_OP_DEFINE_CREATION_FUNCTION(Slice11, VersionedKernel<DmlOperatorSlice, 11>);
DML_OP_DEFINE_CREATION_FUNCTION(Slice13, VersionedKernel<DmlOperatorSlice, 13>);
} // namespace Dml

View file

@ -35,14 +35,18 @@ enum class SupportedTensorDataTypes : uint32_t
Complex64 = 1<<14,
Complex128 = 1<<15,
Ints8to32 = UInt8|Int8|UInt16|Int16|UInt32|Int32,
Int32to64 = UInt32|Int32|UInt64|Int64,
Float16to32 = Float16|Float32, // Float64 is not supported by DirectML.
NumericDefault = Ints8to32|Float16to32,
Ints32to64 = UInt32|Int32|UInt64|Int64,
Ints8to64 = UInt8|Int8|UInt16|Int16|UInt32|Int32|UInt64|Int64,
UInt8to64 = UInt8|UInt16|UInt32|UInt64,
Float16to32 = Float16|Float32,
Float16to64 = Float16|Float32|Float64,
NumericDefault = Ints8to32|Float16to32, // Only simple numbers, not bool, complex, or string.
Scalars8to32 = UInt8|Int8|UInt16|Int16|UInt32|Int32|Float16to32|Bool,
AllScalars = UInt8|Int8|UInt16|Int16|UInt32|Int32|UInt64|Int64|Float16to32|Bool,
AllScalars = UInt8|Int8|UInt16|Int16|UInt32|Int32|UInt64|Int64|Float16|Float32|Float64|Bool,
Ints8Bit = UInt8|Int8,
Ints16Bit = UInt16|Int16,
Ints32Bit = UInt32|Int32,
Ints64Bit = UInt64|Int64,
All = static_cast<uint32_t>(-1),
};
DEFINE_ENUM_FLAG_OPERATORS(Dml::SupportedTensorDataTypes);
@ -54,6 +58,7 @@ enum class DmlGraphSupport : uint32_t
SupportedWith64BitTensorsVia32BitStrides = 2, // Supports them via 32-bit tensors and doubled strides.
SupportedWith64BitTensorsVia32BitStridesFromAnyEp = 4, // Supports input from any execution provider (otherwise only inputs from other DML nodes)
Prefer64BitTensorsDirectly = 8, // Natively supports 64-bit tensors. So avoid strided 32-bit unless the device lacks support.
Support64BitTensorsViaEmulation = 16,// supports int/uint64 tensors via emulation of 32-bit types.
};
DEFINE_ENUM_FLAG_OPERATORS(DmlGraphSupport);
@ -109,8 +114,10 @@ DML_OP_EXTERN_CREATION_FUNCTION(Concat);
DML_OP_EXTERN_CREATION_FUNCTION(Slice7);
DML_OP_EXTERN_CREATION_FUNCTION(Slice10);
DML_OP_EXTERN_CREATION_FUNCTION(Slice11);
DML_OP_EXTERN_CREATION_FUNCTION(Slice13);
DML_OP_EXTERN_CREATION_FUNCTION(Pad7);
DML_OP_EXTERN_CREATION_FUNCTION(Pad11);
DML_OP_EXTERN_CREATION_FUNCTION(Pad13);
DML_OP_EXTERN_CREATION_FUNCTION(SpaceToDepth);
DML_OP_EXTERN_CREATION_FUNCTION(DepthToSpace);
DML_OP_EXTERN_CREATION_FUNCTION(Sqrt);
@ -124,6 +131,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(Floor);
DML_OP_EXTERN_CREATION_FUNCTION(Clip7);
DML_OP_EXTERN_CREATION_FUNCTION(Clip11);
DML_OP_EXTERN_CREATION_FUNCTION(Clip12);
DML_OP_EXTERN_CREATION_FUNCTION(Clip13);
DML_OP_EXTERN_CREATION_FUNCTION(Greater);
DML_OP_EXTERN_CREATION_FUNCTION(Less);
DML_OP_EXTERN_CREATION_FUNCTION(GreaterOrEqual);
@ -161,6 +169,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(ImageScaler);
DML_OP_EXTERN_CREATION_FUNCTION(Upsample7);
DML_OP_EXTERN_CREATION_FUNCTION(Upsample9);
DML_OP_EXTERN_CREATION_FUNCTION(Upsample10);
DML_OP_EXTERN_CREATION_FUNCTION(Upsample13);
DML_OP_EXTERN_CREATION_FUNCTION(Sigmoid);
DML_OP_EXTERN_CREATION_FUNCTION(HardSigmoid);
DML_OP_EXTERN_CREATION_FUNCTION(Tanh);
@ -206,7 +215,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(FusedSum);
DML_OP_EXTERN_CREATION_FUNCTION(QuantizeLinear);
DML_OP_EXTERN_CREATION_FUNCTION(DequantizeLinear);
DML_OP_EXTERN_CREATION_FUNCTION(Sign);
DML_OP_EXTERN_CREATION_FUNCTION(IsNan);
DML_OP_EXTERN_CREATION_FUNCTION(IsNaN);
DML_OP_EXTERN_CREATION_FUNCTION(Sinh);
DML_OP_EXTERN_CREATION_FUNCTION(Cosh);
DML_OP_EXTERN_CREATION_FUNCTION(Tanh);
@ -221,6 +230,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(EyeLike);
DML_OP_EXTERN_CREATION_FUNCTION(MaxUnpool);
DML_OP_EXTERN_CREATION_FUNCTION(Scatter9);
DML_OP_EXTERN_CREATION_FUNCTION(Scatter11);
DML_OP_EXTERN_CREATION_FUNCTION(Scatter13);
DML_OP_EXTERN_CREATION_FUNCTION(Resize10);
DML_OP_EXTERN_CREATION_FUNCTION(Resize11);
DML_OP_EXTERN_CREATION_FUNCTION(ConstantOfShape);
@ -235,6 +245,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(ReverseSequence);
DML_OP_EXTERN_CREATION_FUNCTION(Round);
DML_OP_EXTERN_CREATION_FUNCTION(ScatterElements);
DML_OP_EXTERN_CREATION_FUNCTION(ScatterND);
DML_OP_EXTERN_CREATION_FUNCTION(QLinearAdd);
DML_OP_EXTERN_CREATION_FUNCTION(QLinearConv);
DML_OP_EXTERN_CREATION_FUNCTION(QLinearMatMul);
DML_OP_EXTERN_CREATION_FUNCTION(DynamicQuantizeLinear);
@ -251,6 +262,7 @@ constexpr static std::array<const char*, 2> typeNameListTwo = { "T1", "T2" };
constexpr static std::array<const char*, 3> typeNameListThree = { "T1", "T2", "T3" };
constexpr static std::array<const char*, 4> typeNameListFour = { "T1", "T2", "T3", "T4" };
constexpr static std::array<const char*, 2> typeNameListTopK = { "T", "I" };
constexpr static std::array<const char*, 2> typeNameListMaxPool = { "T", "I" };
constexpr static std::array<const char*, 2> typeNameListLogicalComparison = { "T", "T1" };
constexpr static std::array<const char*, 2> typeNameListPow12 = {"T", "T1"};
constexpr static std::array<const char*, 2> typeNameListConstantOfShape = { "T1", "T2" };
@ -259,51 +271,65 @@ constexpr static std::array<const char*, 1> typeNameListScatterGatherND = { "T"
constexpr static std::array<const char*, 2> typeNameListSlice10 = { "T", "Tind" };
constexpr static std::array<const char*, 2> typeNameListWhere = { "B", "T" };
constexpr static std::array<const char*, 1> typeNameListEyeLike = { "T2" };
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListAll = {SupportedTensorDataTypes::All};
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat32 = {SupportedTensorDataTypes::Float32};
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat16to32 = {SupportedTensorDataTypes::Float16to32};
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat16to32Int8 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit };
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat16to32Int32 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::UInt32};
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat16to32Int8to32 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit | SupportedTensorDataTypes::Ints16Bit | SupportedTensorDataTypes::Ints32Bit};
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListInt8to32 = {SupportedTensorDataTypes::Ints8to32};
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListInt32to64AndFloat16to32 = {SupportedTensorDataTypes::Int32to64|SupportedTensorDataTypes::Float16to32};
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat16to32Ints8 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit };
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat16to32Ints32 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::UInt32};
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat16to32Ints8to32 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit | SupportedTensorDataTypes::Ints16Bit | SupportedTensorDataTypes::Ints32Bit};
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat16to32Ints8to64 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit | SupportedTensorDataTypes::Ints16Bit | SupportedTensorDataTypes::Ints32Bit | SupportedTensorDataTypes::Ints64Bit};
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat16to32Ints32to64 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints32Bit | SupportedTensorDataTypes::Ints64Bit};
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListUInt8to64 = {SupportedTensorDataTypes::UInt8to64};
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListNumericDefault = { SupportedTensorDataTypes::NumericDefault };
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListAllScalars = { SupportedTensorDataTypes::AllScalars };
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListBool = {SupportedTensorDataTypes::Bool};
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListPow12 = {SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::NumericDefault};
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListTopK = {SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Int64};
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListTopK = {SupportedTensorDataTypes::NumericDefault | SupportedTensorDataTypes::Ints64Bit, SupportedTensorDataTypes::Int64};
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListMaxPool = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit, SupportedTensorDataTypes::Int64};
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListMaxUnpool = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int64};
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListIndices = { SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64 };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListCast = { SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::AllScalars };
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListScalars8to32 = { SupportedTensorDataTypes::Scalars8to32 };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListScatterGather = { SupportedTensorDataTypes::Scalars8to32, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 };
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListScatterGatherND = { SupportedTensorDataTypes::Scalars8to32 };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListSlice10 = { SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListScatterGather = { SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 };
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListScatterGatherND = { SupportedTensorDataTypes::AllScalars };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListSlice10 = { SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListQuantizeLinear = { SupportedTensorDataTypes::Float32 | SupportedTensorDataTypes::Int32, SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListDequantizeLinear = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 | SupportedTensorDataTypes::Int32 };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListQuantize = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::UInt8 };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListIsNan = { SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Bool };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListIsInf = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::Bool };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListConstantOfShape = { SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64, SupportedTensorDataTypes::Float16to32 };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListWhere = { SupportedTensorDataTypes::Bool, SupportedTensorDataTypes::AllScalars };
constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListOneHot = /* indices, depth, values */ { SupportedTensorDataTypes::Int32to64, SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Scalars8to32 };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListConstantOfShape = { SupportedTensorDataTypes::Int64, SupportedTensorDataTypes::AllScalars };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListWhere = { SupportedTensorDataTypes::Bool, SupportedTensorDataTypes::Scalars8to32 };
constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListOneHot = /* indices, depth, values */ { SupportedTensorDataTypes::Ints32to64, SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::AllScalars };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListLogicalComparison7 = /* A&B,C */ { SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Bool };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListLogicalComparison9 = /* A&B,C */ { SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Bool };
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListSigned = { SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int16 | SupportedTensorDataTypes::Int8 };
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListRange = {SupportedTensorDataTypes::Int16|SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Float32};
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListLogicalComparison9 = /* A&B,C */ { SupportedTensorDataTypes::Float16to32|SupportedTensorDataTypes::Ints8to64, SupportedTensorDataTypes::Bool };
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListSigned = { SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int64 | SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int16 | SupportedTensorDataTypes::Int8 };
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListRange = {SupportedTensorDataTypes::Int16|SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64|SupportedTensorDataTypes::Float32};
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListResize11 = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Float16to32 /* ROI read by CPU */};
constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListInteger = {SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int32 };
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListInteger8 = {SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8 };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListRoiAlign = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64 };
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListArgMinMax = {SupportedTensorDataTypes::Float16to32|SupportedTensorDataTypes::Ints8to64};
constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListQLinearMatMul = {
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8
};
constexpr static std::array<SupportedTensorDataTypes, 4> supportedTypeListQLinearConv = {
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
SupportedTensorDataTypes::Int32
};
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListDynamicQuantizeLinear = {
SupportedTensorDataTypes::Float32,
SupportedTensorDataTypes::UInt8,
};
template<typename... Args>
constexpr auto requiredConstantCpuInputs(Args... args)
{
@ -336,249 +362,313 @@ constexpr auto requiredConstantCpuInputs(Args... args)
constexpr static OperatorRegistrationInformation operatorRegistrationInformationTable[] =
{
/// Domain/Type, Ver, Name, TypeNames, Types, Graph Support, Required const CPU inputs,
/// Input count required for graph support,
/// Support query function
/// Domain/Type, Ver, Name, TypeNames, Types, Graph Support, Required const CPU inputs,
/// Input count required for graph support,
/// Support query function
// Deep Learning Standard Layers
{REG_INFO( 7, Conv, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, Conv, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, ConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, ConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 10, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, GlobalAveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO( 8, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
{REG_INFO( 10, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
{REG_INFO( 11, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
{REG_INFO( 12, MaxPool, typeNameListDefault, supportedTypeListFloat16to32Int8, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
{REG_INFO( 7, Conv, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, Conv, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, ConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, ConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 10, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, GlobalAveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO( 8, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
{REG_INFO( 10, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
{REG_INFO( 11, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
{REG_INFO( 12, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
{REG_INFO( 7, GlobalMaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, GlobalLpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, MaxRoiPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_VER( 10, RoiAlign, typeNameListTwo, supportedTypeListRoiAlign, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
{REG_INFO( 7, InstanceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, // v9 just removes 'spatial' attribute.
{REG_INFO( 7, LRN, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, MeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, MeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, LpNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, RNN, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::NotSupported)},
{REG_INFO( 7, GRU, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::NotSupported)},
{REG_INFO( 7, LSTM, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::NotSupported)},
{REG_INFO_MS( 1, ConvTransposeWithDynamicPads, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(2))},
{REG_INFO( 7, GlobalMaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, GlobalLpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, MaxRoiPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_VER( 10, RoiAlign, typeNameListTwo, supportedTypeListRoiAlign, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 7, InstanceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, BatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, // v9 just removes 'spatial' attribute.
{REG_INFO( 7, LRN, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 13, LRN, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, MeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, MeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 13, MeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, LpNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, RNN, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::NotSupported)},
{REG_INFO( 7, GRU, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::NotSupported)},
{REG_INFO( 7, LSTM, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::NotSupported)},
{REG_INFO_MS( 1, ConvTransposeWithDynamicPads, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(2))},
// Data Reorganization Layers
{REG_INFO( 7, Split, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO( 11, Split, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, // Adds negative axis.
{REG_INFO( 7, Transpose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO( 7, Concat, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO( 11, Concat, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, // Adds negative axis.
{REG_INFO_VER( 7, Slice, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO_VER( 10, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)}, // Adds negative axes.
{REG_INFO_VER( 11, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)},
{REG_INFO_VER( 7, Pad, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO_VER( 11, Pad, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1, 2) /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728
{REG_INFO( 7, SpaceToDepth, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO( 7, DepthToSpace, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO( 11, DepthToSpace, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO( 7, Tile, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))},
{REG_INFO( 8, Expand, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))},
{REG_INFO( 9, ConstantOfShape, typeNameListConstantOfShape, supportedTypeListConstantOfShape, DmlGraphSupport::NotSupported, requiredConstantCpuInputs(0))},
{REG_INFO( 7, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
{REG_INFO( 11, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
{REG_INFO( 11, GatherElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
{REG_INFO( 11, GatherND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
{REG_INFO( 12, GatherND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
{REG_INFO_VER( 9, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
{REG_INFO_VER( 11, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
{REG_INFO( 11, ScatterElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
{REG_INFO( 11, ScatterND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
{REG_INFO( 9, EyeLike, typeNameListEyeLike, supportedTypeListScalars8to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Split, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 11, Split, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, // Adds negative axis.
{REG_INFO( 7, Transpose, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 13, Transpose, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 7, Concat, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 11, Concat, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, // Adds negative axis.
{REG_INFO( 13, Concat, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)}, // Adds negative axis.
{REG_INFO_VER( 7, Slice, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO_VER( 10, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)}, // Adds negative axes.
{REG_INFO_VER( 11, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)},
{REG_INFO_VER( 13, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)},
{REG_INFO_VER( 7, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO_VER( 11, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1, 2) /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728
{REG_INFO_VER( 13, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1, 2) /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728
{REG_INFO( 7, SpaceToDepth, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 13, SpaceToDepth, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 7, DepthToSpace, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 11, DepthToSpace, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 13, DepthToSpace, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 7, Tile, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1))},
{REG_INFO( 13, Tile, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1))},
{REG_INFO( 8, Expand, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1))},
{REG_INFO( 13, Expand, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1))},
{REG_INFO( 9, ConstantOfShape, typeNameListConstantOfShape, supportedTypeListConstantOfShape, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(0))},
{REG_INFO( 7, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 11, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 13, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 11, GatherElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 13, GatherElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 11, GatherND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 12, GatherND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 13, GatherND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO_VER( 9, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO_VER( 11, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO_VER( 13, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 11, ScatterElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 13, ScatterElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 11, ScatterND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 13, ScatterND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 9, EyeLike, typeNameListEyeLike, supportedTypeListScalars8to32, DmlGraphSupport::Supported)},
// Data reorganization that merely changes the dimensions while keeping the data identical.
{REG_INFO_ID( 7, Identity, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO_ID( 7, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO_ID( 9, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO_ID( 11, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO_ID( 7, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO_ID( 11, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO_ID( 7, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO_ID( 11, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO_ID( 7, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))},
{REG_INFO_ID( 7, Identity, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO_ID( 13, Identity, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO_ID( 7, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO_ID( 9, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO_ID( 11, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO_ID( 13, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO_ID( 7, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO_ID( 11, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO_ID( 7, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO_ID( 11, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO_ID( 7, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1))},
{REG_INFO_ID( 13, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1))},
// Elementwise
{REG_INFO( 7, Sqrt, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Reciprocal, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Pow, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 12, Pow, typeNameListPow12, supportedTypeListPow12, DmlGraphSupport::Supported)},
{REG_INFO( 7, Exp, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Log, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Abs, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)},
{REG_INFO( 7, Ceil, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Floor, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_VER( 7, Clip, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_VER( 11, Clip, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1,2))},
{REG_INFO_VER( 12, Clip, typeNameListDefault, supportedTypeListFloat16to32Int8to32,DmlGraphSupport::Supported, requiredConstantCpuInputs(1,2))},
{REG_INFO( 7, Add, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Sub, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Mul, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Div, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Sum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 8, Sum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 7, Mean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 8, Mean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 7, Max, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 8, Max, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 12, Max, typeNameListDefault, supportedTypeListFloat16to32Int8to32,DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 7, Min, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 8, Min, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 12, Min, typeNameListDefault, supportedTypeListFloat16to32Int8to32,DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 7, Cos, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Sin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Tan, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Acos, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Asin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Atan, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Affine, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 10, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear, DmlGraphSupport::Supported)},
{REG_INFO( 10, DequantizeLinear, typeNameListTwo, supportedTypeListDequantizeLinear, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, QuantizeLinear, typeNameListTwo, supportedTypeListQuantize, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, DequantizeLinear, typeNameListTwo, supportedTypeListQuantize, DmlGraphSupport::Supported)},
{REG_INFO( 9, Sign, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)},
{REG_INFO( 9, IsNan, typeNameListTwo, supportedTypeListIsNan, DmlGraphSupport::Supported)},
{REG_INFO( 9, Sinh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, Cosh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, Asinh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, Acosh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, Atanh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, Erf, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, Where, typeNameListWhere, supportedTypeListWhere, DmlGraphSupport::Supported)},
{REG_INFO( 7, ReduceSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, ReduceSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_VER( 12, Einsum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryEinSum )},
{REG_INFO( 7, ReduceMean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, ReduceMean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, ReduceProd, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, ReduceProd, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, ReduceLogSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, ReduceLogSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, ReduceLogSumExp, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, ReduceLogSumExp, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, ReduceSumSquare, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, ReduceSumSquare, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, ReduceL1, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, ReduceL1, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, ReduceL2, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, ReduceL2, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 12, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32Int8, DmlGraphSupport::Supported)},
{REG_INFO( 7, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 12, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32Int8, DmlGraphSupport::Supported)},
{REG_INFO( 7, ArgMax, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO( 11, ArgMax, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO( 12, ArgMax, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO( 7, ArgMin, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO( 11, ArgMin, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO( 12, ArgMin, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO( 7, Gemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, Gemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, Gemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Neg, typeNameListDefault, supportedTypeListSigned, DmlGraphSupport::Supported)},
{REG_INFO( 7, Greater, typeNameListLogicalComparison, supportedTypeListLogicalComparison7,DmlGraphSupport::Supported)},
{REG_INFO( 9, Greater, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmlGraphSupport::Supported)},
{REG_INFO( 7, Less, typeNameListLogicalComparison, supportedTypeListLogicalComparison7,DmlGraphSupport::Supported)},
{REG_INFO( 9, Less, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmlGraphSupport::Supported)},
{REG_INFO( 12, GreaterOrEqual, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmlGraphSupport::Supported)},
{REG_INFO( 12, LessOrEqual, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmlGraphSupport::Supported)},
{REG_INFO( 7, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison7,DmlGraphSupport::Supported)},
{REG_INFO( 11, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison9,DmlGraphSupport::Supported)},
{REG_INFO( 7, Not, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)},
{REG_INFO( 7, And, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)},
{REG_INFO( 7, Or, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)},
{REG_INFO( 7, Xor, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)},
{REG_INFO( 7, Sqrt, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 13, Sqrt, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Reciprocal, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 13, Reciprocal, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Pow, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 12, Pow, typeNameListPow12, supportedTypeListPow12, DmlGraphSupport::Supported)},
{REG_INFO( 13, Pow, typeNameListPow12, supportedTypeListPow12, DmlGraphSupport::Supported)},
{REG_INFO( 7, Exp, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 13, Exp, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Log, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 13, Log, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Abs, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)},
{REG_INFO( 13, Abs, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)},
{REG_INFO( 7, Ceil, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 13, Ceil, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Floor, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 13, Floor, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_VER( 7, Clip, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_VER( 11, Clip, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1,2))},
{REG_INFO_VER( 12, Clip, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1,2))},
{REG_INFO_VER( 13, Clip, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1,2))},
{REG_INFO( 7, Add, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 13, Add, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 7, Sub, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 13, Sub, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 7, Mul, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 13, Mul, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 7, Div, typeNameListDefault, supportedTypeListFloat16to32Ints32, DmlGraphSupport::Supported)},
{REG_INFO( 13, Div, typeNameListDefault, supportedTypeListFloat16to32Ints32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Sum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 8, Sum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 13, Sum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 7, Mean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 8, Mean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 13, Mean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 7, Max, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 8, Max, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 12, Max, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(), 2)},
{REG_INFO( 13, Max, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(), 2)},
{REG_INFO( 7, Min, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 8, Min, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 12, Min, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(), 2)},
{REG_INFO( 13, Min, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(), 2)},
{REG_INFO( 7, Cos, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Sin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Tan, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Acos, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Asin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Atan, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Affine, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 10, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear, DmlGraphSupport::Supported)},
{REG_INFO( 10, DequantizeLinear, typeNameListTwo, supportedTypeListDequantizeLinear, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, QuantizeLinear, typeNameListTwo, supportedTypeListQuantize, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, DequantizeLinear, typeNameListTwo, supportedTypeListQuantize, DmlGraphSupport::Supported)},
{REG_INFO( 9, Sign, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)},
{REG_INFO( 13, Sign, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)},
{REG_INFO( 9, IsNaN, typeNameListTwo, supportedTypeListIsNan, DmlGraphSupport::Supported)},
{REG_INFO( 13, IsNaN, typeNameListTwo, supportedTypeListIsNan, DmlGraphSupport::Supported)},
{REG_INFO( 9, Sinh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, Cosh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, Asinh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, Acosh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, Atanh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, Erf, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 13, Erf, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, Where, typeNameListWhere, supportedTypeListWhere, DmlGraphSupport::Supported)},
{REG_INFO( 7, ReduceSum, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported)},
{REG_INFO( 11, ReduceSum, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported)},
{REG_INFO_VER( 12, Einsum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryEinSum )},
{REG_INFO( 7, ReduceMean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, ReduceMean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 13, ReduceMean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, ReduceProd, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported)},
{REG_INFO( 11, ReduceProd, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported)},
{REG_INFO( 13, ReduceProd, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported)},
{REG_INFO( 7, ReduceLogSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, ReduceLogSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 13, ReduceLogSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, ReduceLogSumExp, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, ReduceLogSumExp, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 13, ReduceLogSumExp, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, ReduceSumSquare, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported)},
{REG_INFO( 11, ReduceSumSquare, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported)},
{REG_INFO( 13, ReduceSumSquare, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported)},
{REG_INFO( 7, ReduceL1, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported)},
{REG_INFO( 11, ReduceL1, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported)},
{REG_INFO( 13, ReduceL1, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported)},
{REG_INFO( 7, ReduceL2, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, ReduceL2, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 13, ReduceL2, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 12, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 13, ReduceMax, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 7, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 12, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 13, ReduceMin, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 7, ArgMax, typeNameListDefault, supportedTypeListArgMinMax, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 11, ArgMax, typeNameListDefault, supportedTypeListArgMinMax, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 12, ArgMax, typeNameListDefault, supportedTypeListArgMinMax, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 13, ArgMax, typeNameListDefault, supportedTypeListArgMinMax, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 7, ArgMin, typeNameListDefault, supportedTypeListArgMinMax, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 11, ArgMin, typeNameListDefault, supportedTypeListArgMinMax, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 12, ArgMin, typeNameListDefault, supportedTypeListArgMinMax, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 13, ArgMin, typeNameListDefault, supportedTypeListArgMinMax, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 7, Gemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, Gemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, Gemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 13, Gemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Neg, typeNameListDefault, supportedTypeListSigned, DmlGraphSupport::Supported)},
{REG_INFO( 13, Neg, typeNameListDefault, supportedTypeListSigned, DmlGraphSupport::Supported)},
{REG_INFO( 7, Greater, typeNameListLogicalComparison, supportedTypeListLogicalComparison7, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 9, Greater, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 13, Greater, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 7, Less, typeNameListLogicalComparison, supportedTypeListLogicalComparison7, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 9, Less, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 13, Less, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 12, GreaterOrEqual, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 12, LessOrEqual, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 7, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison7, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 11, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 13, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 7, Not, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)},
{REG_INFO( 7, And, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)},
{REG_INFO( 7, Or, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)},
{REG_INFO( 7, Xor, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)},
// Imaging Operators
{REG_INFO( 7, Crop, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, ImageScaler, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_VER( 7, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_VER( 9, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)},
{REG_INFO_VER( 10, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)},
{REG_INFO_VER( 10, Resize, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)},
{REG_INFO_VER( 11, Resize, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*roi, scales, sizes*/, std::nullopt, QueryResize)},
{REG_INFO( 7, Crop, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, ImageScaler, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_VER( 7, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_VER( 9, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)},
{REG_INFO_VER( 10, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)},
{REG_INFO_VER( 13, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)},
{REG_INFO_VER( 10, Resize, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)},
{REG_INFO_VER( 11, Resize, typeNameListTwo, supportedTypeListResize11, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*roi, scales, sizes*/, std::nullopt, QueryResize)},
// Activation Functions
{REG_INFO( 7, Sigmoid, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, HardSigmoid, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Tanh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, ScaledTanh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Relu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, LeakyRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, PRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, PRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, ThresholdedRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 10, ThresholdedRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Elu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 12, Celu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Selu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Softmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, Softmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, LogSoftmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, LogSoftmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Hardmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, Hardmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Softsign, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Softplus, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, ParametricSoftplus, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Dropout, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, Shrink, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)},
{REG_INFO( 7, Sigmoid, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 13, Sigmoid, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, HardSigmoid, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Tanh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 13, Tanh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, ScaledTanh, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Relu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 13, Relu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, LeakyRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, PRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, PRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, ThresholdedRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 10, ThresholdedRelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Elu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 12, Celu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Selu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Softmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, Softmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, LogSoftmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, LogSoftmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Hardmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, Hardmax, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Softsign, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Softplus, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, ParametricSoftplus, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Dropout, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, Shrink, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)},
// Uncategorized
{REG_INFO( 7, MatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, MatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)},
{REG_INFO( 9, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)},
{REG_INFO( 7, MatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, MatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 13, MatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 9, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 13, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 7, MemcpyFromHost, typeNameListDefault, supportedTypeListAll)},
{REG_INFO( 7, MemcpyToHost, typeNameListDefault, supportedTypeListAll)},
{REG_INFO_VER( 7, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO_VER( 10, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))},
{REG_INFO_VER( 11, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))},
{REG_INFO( 9, OneHot, typeNameListThree, supportedTypeListOneHot, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp, requiredConstantCpuInputs(1))},
{REG_INFO( 11, OneHot, typeNameListThree, supportedTypeListOneHot, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp, requiredConstantCpuInputs(1))},
{REG_INFO_VER( 7, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO_VER( 10, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1))},
{REG_INFO_VER( 11, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1))},
{REG_INFO( 9, OneHot, typeNameListThree, supportedTypeListOneHot, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1))},
{REG_INFO( 11, OneHot, typeNameListThree, supportedTypeListOneHot, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(1))},
// Fused operators
{REG_INFO_MSDML(1, FusedConv, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MSDML(1, FusedConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MSDML(1, FusedInstanceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MSDML(1, FusedBatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MSDML(1, FusedMeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MSDML(1, FusedGemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MSDML(1, FusedMatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MSDML(1, FusedAdd, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MSDML(1, FusedSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO_MSDML(1, FusedConv, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MSDML(1, FusedConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MSDML(1, FusedInstanceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MSDML(1, FusedBatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MSDML(1, FusedMeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MSDML(1, FusedGemm, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MSDML(1, FusedMatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MSDML(1, FusedAdd, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MSDML(1, FusedSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 10, IsInf, typeNameListTwo, supportedTypeListIsInf, DmlGraphSupport::Supported)},
{REG_INFO( 10, Mod, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)},
{REG_INFO( 11, BitShift, typeNameListDefault, supportedTypeListInt8to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, Round, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 10, ReverseSequence, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
{REG_INFO( 11, CumSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
{REG_INFO( 11, Range, typeNameListDefault, supportedTypeListRange, DmlGraphSupport::Supported, requiredConstantCpuInputs(0,1,2))},
{REG_INFO( 10, IsInf, typeNameListTwo, supportedTypeListIsInf, DmlGraphSupport::Supported)},
{REG_INFO( 10, Mod, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)},
{REG_INFO( 13, Mod, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)},
{REG_INFO( 11, BitShift, typeNameListDefault, supportedTypeListUInt8to64, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 11, Round, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 10, ReverseSequence, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly)},
{REG_INFO( 11, CumSum, typeNameListDefault, supportedTypeListFloat16to32Ints32to64, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
{REG_INFO( 11, Range, typeNameListDefault, supportedTypeListRange, DmlGraphSupport::Supported, requiredConstantCpuInputs(0,1,2))},
{REG_INFO( 9, MaxUnpool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp, requiredConstantCpuInputs(2))},
{REG_INFO( 11, MaxUnpool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp, requiredConstantCpuInputs(2))}, // 11 is identical to 9.
{REG_INFO( 10, QLinearConv, typeNameListFour, supportedTypeListQLinearConv, DmlGraphSupport::NotSupported)},
{REG_INFO( 10, QLinearMatMul, typeNameListThree, supportedTypeListQLinearMatMul, DmlGraphSupport::NotSupported)},
{REG_INFO( 10, MatMulInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::NotSupported)},
{REG_INFO( 10, ConvInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::NotSupported)},
{REG_INFO( 9, MaxUnpool, typeNameListTwo, supportedTypeListMaxUnpool, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(2))},
{REG_INFO( 11, MaxUnpool, typeNameListTwo, supportedTypeListMaxUnpool, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly, requiredConstantCpuInputs(2))}, // 11 is identical to 9.
{REG_INFO_MS( 1, QLinearAdd, typeNameListDefault, supportedTypeListInteger8, DmlGraphSupport::Supported)},
{REG_INFO( 10, QLinearConv, typeNameListFour, supportedTypeListQLinearConv, DmlGraphSupport::Supported)},
{REG_INFO( 10, QLinearMatMul, typeNameListThree, supportedTypeListQLinearMatMul, DmlGraphSupport::Supported)},
{REG_INFO( 10, MatMulInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)},
{REG_INFO( 10, ConvInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)},
{REG_INFO( 11, DynamicQuantizeLinear, typeNameListTwo, supportedTypeListDynamicQuantizeLinear, DmlGraphSupport::Supported)},
};
template<typename T>
@ -607,6 +697,7 @@ void RegisterDmlOperators(IMLOperatorRegistry* registry)
// The graph must be configured with operators from only the legacy DML API, or only the new DML API
bool kernelSupportsGraph = !bool(information.dmlGraphSupport & DmlGraphSupport::NotSupported);
bool prefer64BitTensorsDirectly = bool(information.dmlGraphSupport & DmlGraphSupport::Prefer64BitTensorsDirectly);
bool support64BitTensorsViaEmulation = bool(information.dmlGraphSupport & DmlGraphSupport::Support64BitTensorsViaEmulation);
bool supportedWith64BitTensorsVia32BitStrides = bool(information.dmlGraphSupport & DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides);
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = bool(information.dmlGraphSupport & DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp);
@ -692,6 +783,7 @@ void RegisterDmlOperators(IMLOperatorRegistry* registry)
supportedWith64BitTensorsVia32BitStrides,
supportedWith64BitTensorsVia32BitStridesFromAnyEp,
prefer64BitTensorsDirectly,
support64BitTensorsViaEmulation,
information.requiredConstantCpuInputs.first.data(),
static_cast<uint32_t>(information.requiredConstantCpuInputs.second)
));

View file

@ -151,15 +151,20 @@ namespace Dml
OperatorInfo{ "InstanceNormalization", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_InstanceNormalization },
OperatorInfo{ "MeanVarianceNormalization", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_MeanVarianceNormalization },
OperatorInfo{ "MeanVarianceNormalization", onnxruntime::kOnnxDomain, OnnxOperatorSet9::sc_sinceVer_MeanVarianceNormalization },
OperatorInfo{ "MeanVarianceNormalization", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_MeanVarianceNormalization },
OperatorInfo{ "Gemm", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_Gemm },
OperatorInfo{ "Gemm", onnxruntime::kOnnxDomain, OnnxOperatorSet9::sc_sinceVer_Gemm },
OperatorInfo{ "Gemm", onnxruntime::kOnnxDomain, OnnxOperatorSet11::sc_sinceVer_Gemm },
OperatorInfo{ "Gemm", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_Gemm },
OperatorInfo{ "MatMul", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_MatMul },
OperatorInfo{ "MatMul", onnxruntime::kOnnxDomain, OnnxOperatorSet9::sc_sinceVer_MatMul },
OperatorInfo{ "MatMul", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_MatMul },
// The filter for activation functions maps to what DML's fused op internally fuses at the shader level.
OperatorInfo{ "Add", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_Add, {"Relu", "LeakyRelu"} },
OperatorInfo{ "Add", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_Add, {"Relu", "LeakyRelu"} },
OperatorInfo{ "Sum", onnxruntime::kOnnxDomain, OnnxOperatorSet8::sc_sinceVer_Sum, {"Relu", "LeakyRelu"}, 2 },
OperatorInfo{ "Sum", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_Sum, {"Relu", "LeakyRelu"}, 2 },
};
// Not all activations can be fused - only simple elementwise activations (i.e. activation functions which
@ -167,10 +172,13 @@ namespace Dml
static const OperatorInfo c_activationOps[] =
{
OperatorInfo{ "Sigmoid", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_Sigmoid },
OperatorInfo{ "Sigmoid", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_Sigmoid },
OperatorInfo{ "HardSigmoid", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_HardSigmoid },
OperatorInfo{ "Tanh", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_Tanh },
OperatorInfo{ "Tanh", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_Tanh },
OperatorInfo{ "ScaledTanh", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_ScaledTanh },
OperatorInfo{ "Relu", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_Relu },
OperatorInfo{ "Relu", onnxruntime::kOnnxDomain, OnnxOperatorSet13::sc_sinceVer_Relu },
OperatorInfo{ "LeakyRelu", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_LeakyRelu },
OperatorInfo{ "PRelu", onnxruntime::kOnnxDomain, OnnxOperatorSet7::sc_sinceVer_PRelu },
OperatorInfo{ "PRelu", onnxruntime::kOnnxDomain, OnnxOperatorSet9::sc_sinceVer_PRelu },

View file

@ -4,27 +4,47 @@
#pragma once
// Throws E_INVALIDARG (via ORT_THROW_HR) when the condition `x` evaluates to false.
// Any additional arguments (typically a descriptive message) are accepted so call
// sites can document the failure, but they are not used by the expansion.
#define ML_CHECK_VALID_ARGUMENT(x, ...)\
    {\
        if ((x) == false)\
        {\
            ORT_THROW_HR(E_INVALIDARG);\
        }\
    }
// Unconditionally throws E_INVALIDARG (via ORT_THROW_HR). The `msg` argument documents
// the failure at the call site but is not used by the expansion.
// The final line deliberately has no trailing line-continuation, so the macro cannot
// absorb whatever line follows it.
#define ML_INVALID_ARGUMENT(msg)\
    ORT_THROW_HR(E_INVALIDARG);
// Throws (via ORT_THROW_HR) when `hr` is a failure code according to FAILED().
// Any additional arguments are accepted but not used by the expansion.
// NOTE(review): the thrown code is always E_INVALIDARG rather than `hr` itself, so the
// original failure code is not propagated - confirm this is intended.
#define ML_CHECK_HRESULT(hr, ...)\
    {\
        if (FAILED(hr))\
        {\
            ORT_THROW_HR(E_INVALIDARG);\
        }\
    }
namespace OperatorHelper
{
// Clamp the input value to the representable range of the output type before casting,
// so that out-of-range inputs saturate instead of wrapping around.
//
// e.g. int32(300) would yield uint8(255) rather than uint8(44).
//      float32(-42) would yield uint32(0) rather than a huge positive number.
//      float32(3e9) would yield int32(INT32_MAX).
//
// When converting from floating point to an integer type, NaN yields 0.
template<typename OutputType, typename InputType> OutputType clamp_cast(InputType input)
{
    using InputLimits = std::numeric_limits<InputType>;
    using OutputLimits = std::numeric_limits<OutputType>;

    if constexpr (std::is_floating_point_v<InputType> && std::is_integral_v<OutputType>)
    {
        // Integer limits such as INT32_MAX are generally not exactly representable in the
        // floating point type (they round up to the next power of two), so clamping to
        // the converted limit and then casting could still be out of range, which is
        // undefined behavior. Compare against the converted limits instead and return the
        // integer limits directly.
        if (input != input) // NaN compares unequal to itself.
        {
            return static_cast<OutputType>(0);
        }
        if (input >= static_cast<InputType>(OutputLimits::max()))
        {
            return OutputLimits::max();
        }
        if (input <= static_cast<InputType>(OutputLimits::lowest()))
        {
            return OutputLimits::lowest();
        }
        return static_cast<OutputType>(input);
    }
    else
    {
        // Determine the larger type to decide which numeric limits to clamp to.
        // max_exponent is nonzero only for floating point types; digits covers integers.
        constexpr int inputMaxDigits = std::max(InputLimits::max_exponent, InputLimits::digits);
        constexpr int outputMaxDigits = std::max(OutputLimits::max_exponent, OutputLimits::digits);
        constexpr bool isEitherTypeUnsigned = std::is_unsigned_v<InputType> || std::is_unsigned_v<OutputType>;
        constexpr bool isOutputTypeLarger = outputMaxDigits > inputMaxDigits;

        // If either side is unsigned, no negative value survives the conversion.
        InputType lowestValue = isEitherTypeUnsigned ? static_cast<InputType>(0) :
                                isOutputTypeLarger ? InputLimits::lowest() :
                                static_cast<InputType>(OutputLimits::lowest());
        InputType highestValue = isOutputTypeLarger ? InputLimits::max() :
                                 static_cast<InputType>(OutputLimits::max());

        return static_cast<OutputType>(std::clamp<InputType>(input, lowestValue, highestValue));
    }
}
enum TensorAxis { N, C, H, W, DoNotCoerce = INT_MAX, LeftAligned = INT_MAX, RightAligned = INT_MIN, NoPlacementAdjustment = 0 };
enum BroadcastMode { NoBroadcast, UnidirectionalBroadcast, MultidirectionalBroadcast };

View file

@ -114,6 +114,7 @@ IMLOperatorRegistryPrivate : public IUnknown
bool supportedWith64BitTensorsVia32BitStrides = false,
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = false,
bool prefer64BitTensorsDirectly = false,
bool support64BitTensorsViaEmulation = false,
_In_reads_(constantCpuInputCount) const uint32_t* constantCpuInputs = nullptr,
uint32_t constantCpuInputCount = 0
) const noexcept PURE;

View file

@ -32,87 +32,6 @@ namespace OperatorHelper
}
}
void ReadCpuLocalTensorIntoInt32(
const MLOperatorTensor& tensor,
std::vector<int32_t>& result
)
{
result.clear();
ML_CHECK_VALID_ARGUMENT(tensor.IsCpuData(), "Tensor must be CPU Tensor.");
const std::vector<uint32_t>& tensorDimensions = tensor.GetShape();
const uint32_t elementCount = ComputeElementCountFromDimensions(tensorDimensions);
switch (tensor.GetTensorDataType())
{
case MLOperatorTensorDataType::Int32:
{
const int32_t* data = tensor.GetData<int32_t>();
result.assign(data, data + elementCount);
}
break;
case MLOperatorTensorDataType::Int64:
{
const int64_t* data = tensor.GetData<int64_t>();
result.reserve(elementCount);
// Use clamped cast rather than static_cast/narrow_cast,
// because it's not uncommon for a model to specify a
// 64-bit INTMAX constant as a sentinel value to mean
// the largest possible value (even though the actual
// dimension values come nowhere close to that, far
// less than 32-bit INTMAX).
for (auto d : gsl::make_span(data, data + elementCount))
{
result.push_back(clamp_cast<int32_t>(d));
}
}
break;
default:
ML_INVALID_ARGUMENT("Expecting CPU local tensor of type int32 or int64.");
break;
}
}
// Copies every element of a CPU float32 tensor into `result`.
//
// tensor - must report IsCpuData() and have element type Float; otherwise
//          E_INVALIDARG is thrown.
// result - cleared first, then filled with one entry per tensor element.
void ReadCpuLocalTensorIntoFloat32(
    const MLOperatorTensor& tensor,
    std::vector<float>& result
    )
{
    result.clear();
    ML_CHECK_VALID_ARGUMENT(tensor.IsCpuData(), "Tensor must be CPU Tensor.");

    const std::vector<uint32_t>& shape = tensor.GetShape();
    const uint32_t valueCount = ComputeElementCountFromDimensions(shape);

    if (tensor.GetTensorDataType() == MLOperatorTensorDataType::Float)
    {
        const float* const values = tensor.GetData<float>();
        result.assign(values, values + valueCount);
    }
    else
    {
        ML_INVALID_ARGUMENT("Expecting CPU local tensor of type float32.");
    }
}
// Converts 64-bit dimensions to DimensionType, replacing the previous contents of
// `outputDimensions`. Each value is first clamped to [INT32_MIN, INT32_MAX] and then
// narrowed to uint32_t.
// NOTE(review): negative inputs pass the clamp and presumably wrap to large unsigned
// values in the narrowing cast - confirm that is intended.
void DowncastDimensions(gsl::span<const int64_t> inputDimensions, std::vector<DimensionType>& outputDimensions)
{
    outputDimensions.clear();
    outputDimensions.reserve(inputDimensions.size());

    for (const int64_t dimension : inputDimensions)
    {
        const int64_t clampedDimension = std::clamp<int64_t>(dimension, INT32_MIN, INT32_MAX);
        outputDimensions.push_back(gsl::narrow_cast<uint32_t>(clampedDimension));
    }
}
float CastFloat16ToFloat32(uint16_t input)
{
// Promote float16m10e5s1 to float32m23e8s1.
@ -201,6 +120,130 @@ namespace OperatorHelper
}
#pragma warning(pop)
void ReadCpuLocalTensorIntoInt32(
const MLOperatorTensor& tensor,
std::vector<int32_t>& result
)
{
result.clear();
ML_CHECK_VALID_ARGUMENT(tensor.IsCpuData(), "Tensor must be CPU Tensor.");
const std::vector<uint32_t>& tensorDimensions = tensor.GetShape();
const uint32_t elementCount = ComputeElementCountFromDimensions(tensorDimensions);
switch (tensor.GetTensorDataType())
{
case MLOperatorTensorDataType::Int32:
{
const int32_t* data = tensor.GetData<int32_t>();
result.assign(data, data + elementCount);
}
break;
case MLOperatorTensorDataType::Int64:
{
const int64_t* data = tensor.GetData<int64_t>();
result.reserve(elementCount);
// Use clamped cast rather than static_cast/narrow_cast,
// because it's not uncommon for a model to specify a
// 64-bit INTMAX constant as a sentinel value to mean
// the largest possible value (even though the actual
// dimension values come nowhere close to that, far
// less than 32-bit INTMAX).
for (auto d : gsl::make_span(data, data + elementCount))
{
result.push_back(clamp_cast<int32_t>(d));
}
}
break;
default:
ML_INVALID_ARGUMENT("Expecting CPU local tensor of type int32 or int64.");
break;
}
}
// Reads every element of a CPU tensor into `result`, converting each value to float32.
//
// tensor - must report IsCpuData(); otherwise E_INVALIDARG is thrown.
// result - cleared, resized to the tensor's element count, then overwritten.
//
// Accepted element types: Float16 (widened through CastFloat16ToFloat32), Float, Double,
// Int32, UInt32, Int64 and UInt64 (all converted with static_cast<float>, so values
// that float32 cannot represent exactly lose precision). Any other element type throws
// E_INVALIDARG.
void ReadCpuLocalTensorIntoFloat32(
    const MLOperatorTensor& tensor,
    std::vector<float>& result
    )
{
    result.clear();
    ML_CHECK_VALID_ARGUMENT(tensor.IsCpuData(), "Tensor must be CPU Tensor.");

    const std::vector<uint32_t>& tensorDimensions = tensor.GetShape();
    const uint32_t elementCount = ComputeElementCountFromDimensions(tensorDimensions);
    result.resize(elementCount);

    // Shared conversion for every element type that a plain static_cast handles.
    auto convertElements = [&result, elementCount](const auto* data)
    {
        std::transform(data, data + elementCount, result.begin(), [](auto v) { return static_cast<float>(v); });
    };

    switch (tensor.GetTensorDataType())
    {
    case MLOperatorTensorDataType::Float16:
        {
            // Float16 has no native C++ type here, so read the raw 16-bit patterns.
            const uint16_t* data = reinterpret_cast<const uint16_t*>(tensor.GetByteData());
            std::transform(data, data + elementCount, result.begin(), CastFloat16ToFloat32);
        }
        break;

    case MLOperatorTensorDataType::/*Float32*/Float:
        convertElements(tensor.GetData<float>());
        break;

    case MLOperatorTensorDataType::/*Float64*/Double:
        convertElements(tensor.GetData<double>());
        break;

    case MLOperatorTensorDataType::Int32:
        convertElements(tensor.GetData<int32_t>());
        break;

    case MLOperatorTensorDataType::UInt32:
        convertElements(tensor.GetData<uint32_t>());
        break;

    case MLOperatorTensorDataType::Int64:
        convertElements(tensor.GetData<int64_t>());
        break;

    case MLOperatorTensorDataType::UInt64:
        convertElements(tensor.GetData<uint64_t>());
        break;

    default:
        ML_INVALID_ARGUMENT("Expecting CPU local tensor of type float16, float32, float64, int32, uint32, int64, or uint64.");
        break;
    }
}
// Converts 64-bit dimensions to DimensionType, replacing the previous contents of
// `outputDimensions`. Each value is first clamped to [INT32_MIN, INT32_MAX] and then
// narrowed to uint32_t.
// NOTE(review): negative inputs pass the clamp and presumably wrap to large unsigned
// values in the narrowing cast - confirm that is intended.
void DowncastDimensions(gsl::span<const int64_t> inputDimensions, std::vector<DimensionType>& outputDimensions)
{
    outputDimensions.clear();
    outputDimensions.reserve(inputDimensions.size());

    for (const int64_t dimension : inputDimensions)
    {
        const int64_t clampedDimension = std::clamp<int64_t>(dimension, INT32_MIN, INT32_MAX);
        outputDimensions.push_back(gsl::narrow_cast<uint32_t>(clampedDimension));
    }
}
int64_t IsFloatDataType(MLOperatorTensorDataType tensorDataType)
{
switch (tensorDataType)
@ -1089,7 +1132,7 @@ namespace OperatorHelper
ML_CHECK_VALID_ARGUMENT(inputCount + 1 == m_components.size(), "Mismatch between input tensor count and string equation component count.");
ML_CHECK_VALID_ARGUMENT(outputCount == 1, "EinSum expects exactly 1 output tensor.");
std::vector<uint32_t> labelSizes(m_labelIndices.size(), static_cast<uint32_t>(INT_MIN));
std::vector<uint32_t> labelSizes(m_labelIndices.size(), UINT_MAX);
// Read every input tensor, comparing labels to ensure consistent sizes from the equation parsed earlier.
for (uint32_t i = 0; i < inputCount; ++i)
@ -1110,7 +1153,7 @@ namespace OperatorHelper
uint32_t labelIndex = labelIndices[j];
assert(labelIndex < labelSizes.size());
if (labelSizes[labelIndex] == INT_MIN)
if (labelSizes[labelIndex] == UINT_MAX)
{
labelSizes[labelIndex] = dimensionSize;
}

View file

@ -142,7 +142,7 @@ namespace OperatorHelper
namespace OnnxOperatorSet9
{
static const int sc_sinceVer_Sign = 9;
static const int sc_sinceVer_IsNan = 9;
static const int sc_sinceVer_IsNaN = 9;
static const int sc_sinceVer_Sinh = 9;
static const int sc_sinceVer_Cosh = 9;
static const int sc_sinceVer_Asinh = 9;
@ -233,6 +233,7 @@ namespace OperatorHelper
static const int sc_sinceVer_ReduceSumSquare = 11;
static const int sc_sinceVer_Resize = 11;
static const int sc_sinceVer_Round = 11;
static const int sc_sinceVer_DynamicQuantizeLinear = 11;
static const int sc_sinceVer_Scan = 11;
static const int sc_sinceVer_Scatter = 11; // Deprecated alias
static const int sc_sinceVer_ScatterElements = 11;
@ -263,6 +264,73 @@ namespace OperatorHelper
static const int sc_sinceVer_ReduceMin = 12;
} // namespace OnnxOperatorSet12
// Version constants for operators listed under ONNX operator set 13.
// Every constant in this namespace has the value 13; the names follow the
// sc_sinceVer_<OperatorName> pattern used by the sibling OnnxOperatorSetN namespaces.
namespace OnnxOperatorSet13
{
    static constexpr int sc_sinceVer_Abs = 13;
    static constexpr int sc_sinceVer_Add = 13;
    static constexpr int sc_sinceVer_ArgMax = 13;
    static constexpr int sc_sinceVer_ArgMin = 13;
    static constexpr int sc_sinceVer_Cast = 13;
    static constexpr int sc_sinceVer_Ceil = 13;
    static constexpr int sc_sinceVer_Clip = 13;
    static constexpr int sc_sinceVer_Concat = 13;
    static constexpr int sc_sinceVer_Constant = 13;
    static constexpr int sc_sinceVer_DepthToSpace = 13;
    static constexpr int sc_sinceVer_Div = 13;
    static constexpr int sc_sinceVer_Equal = 13;
    static constexpr int sc_sinceVer_Erf = 13;
    static constexpr int sc_sinceVer_Exp = 13;
    static constexpr int sc_sinceVer_Expand = 13;
    static constexpr int sc_sinceVer_Flatten = 13;
    static constexpr int sc_sinceVer_Floor = 13;
    static constexpr int sc_sinceVer_Gather = 13;
    static constexpr int sc_sinceVer_GatherElements = 13;
    static constexpr int sc_sinceVer_GatherND = 13;
    static constexpr int sc_sinceVer_Gemm = 13;
    static constexpr int sc_sinceVer_Greater = 13;
    static constexpr int sc_sinceVer_Identity = 13;
    static constexpr int sc_sinceVer_IsNaN = 13;
    static constexpr int sc_sinceVer_LRN = 13;
    static constexpr int sc_sinceVer_Less = 13;
    static constexpr int sc_sinceVer_Log = 13;
    static constexpr int sc_sinceVer_MatMul = 13;
    static constexpr int sc_sinceVer_Max = 13;
    static constexpr int sc_sinceVer_Mean = 13;
    static constexpr int sc_sinceVer_MeanVarianceNormalization = 13;
    static constexpr int sc_sinceVer_Min = 13;
    static constexpr int sc_sinceVer_Mod = 13;
    static constexpr int sc_sinceVer_Mul = 13;
    static constexpr int sc_sinceVer_Neg = 13;
    static constexpr int sc_sinceVer_Pad = 13;
    static constexpr int sc_sinceVer_Pow = 13;
    static constexpr int sc_sinceVer_Reciprocal = 13;
    static constexpr int sc_sinceVer_ReduceL1 = 13;
    static constexpr int sc_sinceVer_ReduceL2 = 13;
    static constexpr int sc_sinceVer_ReduceLogSum = 13;
    static constexpr int sc_sinceVer_ReduceLogSumExp = 13;
    static constexpr int sc_sinceVer_ReduceMax = 13;
    static constexpr int sc_sinceVer_ReduceMean = 13;
    static constexpr int sc_sinceVer_ReduceMin = 13;
    static constexpr int sc_sinceVer_ReduceProd = 13;
    static constexpr int sc_sinceVer_ReduceSumSquare = 13;
    static constexpr int sc_sinceVer_Relu = 13;
    static constexpr int sc_sinceVer_Reshape = 13;
    static constexpr int sc_sinceVer_Scatter = 13;
    static constexpr int sc_sinceVer_ScatterElements = 13;
    static constexpr int sc_sinceVer_ScatterND = 13;
    static constexpr int sc_sinceVer_Sigmoid = 13;
    static constexpr int sc_sinceVer_Sign = 13;
    static constexpr int sc_sinceVer_Slice = 13;
    static constexpr int sc_sinceVer_SpaceToDepth = 13;
    static constexpr int sc_sinceVer_Sqrt = 13;
    static constexpr int sc_sinceVer_Sub = 13;
    static constexpr int sc_sinceVer_Sum = 13;
    static constexpr int sc_sinceVer_Tanh = 13;
    static constexpr int sc_sinceVer_Tile = 13;
    static constexpr int sc_sinceVer_Transpose = 13;
    static constexpr int sc_sinceVer_Upsample = 13;
} // namespace OnnxOperatorSet13
namespace MsftOperatorSet1
{
static const int sc_sinceVer_FusedConv = 1;
@ -277,6 +345,7 @@ namespace OperatorHelper
static const int sc_sinceVer_QuantizeLinear = 1;
static const int sc_sinceVer_DequantizeLinear = 1;
static const int sc_sinceVer_ConvTransposeWithDynamicPads = 1;
static const int sc_sinceVer_QLinearAdd = 1;
} // namespace MsftOperatorSet1
} // namespace OperatorHelper

View file

@ -58,6 +58,7 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistryImpl::RegisterOperatorKernel(
bool supportedWith64BitTensorsVia32BitStrides,
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp,
bool prefer64BitTensorsDirectly,
bool support64BitTensorsViaEmulation,
_In_reads_(constantCpuInputCount) const uint32_t* requiredConstantCpuInputs,
uint32_t constantCpuInputCount) const noexcept try {
#ifdef LAYERING_DONE
@ -83,6 +84,7 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistryImpl::RegisterOperatorKernel(
supportedWith64BitTensorsVia32BitStrides,
supportedWith64BitTensorsVia32BitStridesFromAnyEp,
prefer64BitTensorsDirectly,
support64BitTensorsViaEmulation,
requiredConstantCpuInputs,
constantCpuInputCount);
}

View file

@ -31,6 +31,7 @@ class AbiCustomRegistryImpl : public AbiCustomRegistry {
bool supports_64bit_directly = false,
bool allows_64bit_via_strides = false,
bool allows_64bit_via_strides_from_any_ep = false,
bool supports_64bit_tensors_via_emulation = false,
_In_reads_(constant_cpu_input_count) const uint32_t* required_constant_cpu_inputs = nullptr,
uint32_t constant_cpu_input_count = 0) const noexcept override;